Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import argparse
import functools
import time

import matplotlib.pyplot as plt
import numpy as np

# Import CuPy for GPU acceleration (if available)
try:
    import cupy as cp
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False
# Helper Functions
def generate_cosine_wave(freq, sample_rate, duration, amplitude=1.0, phase=0):
    """Generate a 1D cosine wave.

    Args:
        freq: Tone frequency in Hz.
        sample_rate: Samples per second.
        duration: Signal length in seconds.
        amplitude: Peak amplitude of the cosine.
        phase: Phase offset in radians.

    Returns:
        ``(t, signal)``: sample times starting at 0 (end point excluded) and the
        cosine values at those times, both 1D arrays of equal length.
    """
    n_exact = sample_rate * duration
    # Snap to the nearest integer when the product is only off by float
    # round-off (e.g. 100 * 0.29 == 28.999999999999996); otherwise keep the
    # original truncating behaviour for genuinely fractional sample counts.
    n_samples = int(round(n_exact))
    if not np.isclose(n_samples, n_exact, rtol=1e-9, atol=1e-9):
        n_samples = int(n_exact)
    t = np.linspace(0, duration, n_samples, endpoint=False)
    signal = amplitude * np.cos(2 * np.pi * freq * t + phase)
    return t, signal
def generate_2d_cosine(fx, fy, size):
    """Return a ``size`` x ``size`` array summing a horizontal and a vertical cosine.

    Element ``[row, col]`` equals
    ``cos(2*pi*fx*col/size) + cos(2*pi*fy*row/size)``, i.e. ``fx`` and ``fy``
    are cycles across the image width and height respectively.
    """
    idx = np.arange(size)
    horizontal = np.cos(2 * np.pi * fx * idx / size)  # varies along columns
    vertical = np.cos(2 * np.pi * fy * idx / size)    # varies along rows
    # Broadcasting a row vector against a column vector builds the full grid.
    return horizontal[np.newaxis, :] + vertical[:, np.newaxis]
def plot_and_save_fft_result(freq, fft_result, title="FFT Result", filename="fft_result.png", show=True):
    """Plot the magnitude of an FFT against frequency, optionally saving and showing it.

    Args:
        freq: Frequency of each bin in Hz, aligned element-wise with *fft_result*
            (e.g. the output of ``np.fft.fftfreq``).
        fft_result: FFT coefficients; their absolute value is plotted.
        title: Plot title.
        filename: Path the figure is saved to; pass ``None`` to skip saving.
        show: Whether to call ``plt.show()``; set ``False`` for headless runs.
    """
    freq = np.asarray(freq)
    magnitude = np.abs(np.asarray(fft_result))
    # fftfreq order lists non-negative bins first and negative bins last.
    # Sorting keeps the line plot monotonic in frequency instead of drawing a
    # stray segment from the highest positive bin back to the most negative one.
    order = np.argsort(freq, kind="stable")
    fig = plt.figure(figsize=(10, 4))
    plt.plot(freq[order], magnitude[order])
    plt.title(title)
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Amplitude")
    if filename is not None:
        fig.savefig(filename)
    if show:
        plt.show()
    # Release the figure so later plots start on a fresh canvas and figures do
    # not accumulate when show() does not block (e.g. a headless Agg backend).
    plt.close(fig)
# GPU FFT using CuPy
def gpu_fft_1d(signal):
    """Run a 1D FFT on the GPU with CuPy and return the spectrum as a NumPy array.

    Raises:
        RuntimeError: if the optional ``cupy`` import at module load failed
            (``GPU_AVAILABLE`` is false).
    """
    if GPU_AVAILABLE:
        # Transfer to the device, transform there, then bring the result back.
        device_spectrum = cp.fft.fft(cp.array(signal))
        return cp.asnumpy(device_spectrum)
    raise RuntimeError("CuPy is not available. Install CuPy for GPU support.")
def time_logger(func):
    """Decorator that prints how long each call to *func* takes.

    The wrapper forwards all positional and keyword arguments, returns
    whatever *func* returns, and preserves *func*'s name and docstring via
    ``functools.wraps``. Timing uses ``time.perf_counter`` (monotonic,
    high resolution), which is intended for measuring short intervals.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"{func.__name__} executed in {elapsed:.4f} seconds")
        return result
    return wrapper
@time_logger
def precision_test():
    """Exercise the FFT on a very small (1e-6 amplitude) 10 Hz cosine.

    Plots the spectrum to ``precision_test.png`` and prints the maximum
    absolute error of an FFT -> inverse-FFT round trip.
    """
    print("\n--- Precision Test ---")
    rate = 1000
    _, tiny_wave = generate_cosine_wave(10, rate, 1.0, amplitude=1e-6)
    freqs, spectrum = test_fft_1d(tiny_wave, rate)
    plot_and_save_fft_result(freqs, spectrum, "FFT of Small Amplitude Signal", "precision_test.png")
    roundtrip_error = validate_inverse_fft(tiny_wave)
    print(f"Inverse FFT Reconstruction Error: {roundtrip_error}")
@time_logger
def performance_scaling_test(sizes=None):
    """Time ``np.fft.fft2`` on random square matrices of increasing size.

    Args:
        sizes: Edge lengths to benchmark. ``None`` uses
            ``[128, 256, 512, 1024, 2048, 4096, 8192]``. Note that the largest
            default size allocates well over a gigabyte of temporaries
            (float64 random matrix, float32 copy and the complex FFT output).

    Saves a size-vs-time plot to ``performance_scaling_test.png`` and shows it.
    """
    print("\n--- Performance Scaling Test ---")
    if sizes is None:
        sizes = [128, 256, 512, 1024, 2048, 4096, 8192]
    times = []
    for size in sizes:
        signal = np.random.rand(size, size).astype(np.float32)
        # perf_counter is monotonic and high-resolution; time.time() is not
        # meant for measuring short intervals.
        start = time.perf_counter()
        np.fft.fft2(signal)
        times.append(time.perf_counter() - start)
    # Use a dedicated figure: without it this plot lands on whatever figure is
    # current (e.g. a previous test's) when show() does not block.
    fig = plt.figure()
    plt.plot(sizes, times, marker='o')
    plt.title("Performance Scaling of 2D FFT")
    plt.xlabel("Matrix Size")
    plt.ylabel("Time (seconds)")
    fig.savefig("performance_scaling_test.png")
    plt.show()
    plt.close(fig)
@time_logger
def noisy_signal_test():
    """Plot the spectrum of a 10 Hz cosine with added Gaussian noise (std 0.5).

    The noise is drawn from the unseeded global NumPy RNG, so the plot saved
    to ``noisy_signal_test.png`` differs from run to run.
    """
    print("\n--- Noisy Signal Test ---")
    rate = 1000
    clean = generate_cosine_wave(10, rate, 1.0)[1]
    noise = 0.5 * np.random.normal(size=len(clean))
    freqs, spectrum = test_fft_1d(clean + noise, rate)
    plot_and_save_fft_result(freqs, spectrum, "FFT of Noisy Signal", "noisy_signal_test.png")
@time_logger
def windowing_test():
    """Overlay the spectra of a 10 Hz cosine with and without a Hamming window.

    The figure is saved to ``windowing_test.png`` and then shown.
    """
    print("\n--- Windowing Test ---")
    sample_rate = 1000
    _, signal = generate_cosine_wave(10, sample_rate, 1.0)
    window = np.hamming(len(signal))
    windowed_signal = signal * window
    fft_freq, fft_no_window = test_fft_1d(signal, sample_rate)
    _, fft_with_window = test_fft_1d(windowed_signal, sample_rate)
    # NOTE(review): 1000 samples at 1000 Hz give a 1 Hz bin spacing, so a 10 Hz
    # tone sits exactly on a bin and the unwindowed spectrum shows no leakage
    # for the window to suppress. A non-integer tone (e.g. 10.5 Hz) would
    # demonstrate the window's benefit -- confirm which comparison is intended.

    # fftfreq order puts negative bins after positive ones; shifting keeps the
    # line plot monotonic instead of drawing a segment across the whole axis.
    freq_axis = np.fft.fftshift(fft_freq)
    fig = plt.figure()
    plt.plot(freq_axis, np.fft.fftshift(np.abs(fft_no_window)), label="Without Window")
    plt.plot(freq_axis, np.fft.fftshift(np.abs(fft_with_window)), label="With Hamming Window")
    plt.title("Effect of Windowing on FFT")
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Amplitude")
    plt.legend()
    fig.savefig("windowing_test.png")
    plt.show()
    # Close explicitly so the next test does not draw onto this figure when
    # show() does not block (e.g. a headless Agg backend).
    plt.close(fig)
@time_logger
def cosine_2d_test():
    """Show a 256x256 2D cosine pattern next to its centred log-magnitude spectrum.

    The side-by-side figure is saved to ``cosine_2d_test.png`` and then shown.
    """
    print("\n--- 2D Cosine Pattern Test ---")
    signal = generate_2d_cosine(10, 15, 256)
    # fftshift moves the zero-frequency bin to the centre of the image.
    fft_result = np.fft.fftshift(np.fft.fft2(signal))
    # plt.subplots always creates a new figure, unlike plt.subplot, which
    # draws into whatever figure happens to be current.
    fig, (ax_signal, ax_spectrum) = plt.subplots(1, 2)
    ax_signal.imshow(signal, cmap='gray')
    ax_signal.set_title("2D Cosine Pattern")
    # log1p(x) == log(x + 1): compresses the dynamic range of the magnitudes.
    ax_spectrum.imshow(np.log1p(np.abs(fft_result)), cmap='gray')
    ax_spectrum.set_title("FFT Magnitude Spectrum")
    fig.savefig("cosine_2d_test.png")
    plt.show()
    plt.close(fig)
@time_logger
def high_frequency_test():
    """Plot the spectrum of a 400 Hz cosine sampled at 1000 Hz.

    400 Hz is below the 500 Hz Nyquist limit for this rate, so the peak
    appears at +/-400 Hz. The plot is saved to ``high_frequency_test.png``.
    """
    print("\n--- High-Frequency Test ---")
    fs = 1000
    wave = generate_cosine_wave(400, fs, 1.0)[1]
    freqs, spectrum = test_fft_1d(wave, fs)
    plot_and_save_fft_result(freqs, spectrum, "FFT of High-Frequency Signal", "high_frequency_test.png")
@time_logger
def batch_processing_test():
    """FFT a batch of 10 cosines (10..19 Hz) in one call and plot the first spectrum.

    The figure is saved to ``batch_processing_test.png`` and then shown.
    """
    print("\n--- Batch Processing Test ---")
    batch_size = 10
    sample_rate = 1000
    # One row per signal: shape (batch_size, n_samples).
    signals = np.stack(
        [generate_cosine_wave(10 + i, sample_rate, 1.0)[1] for i in range(batch_size)]
    )
    # A single vectorised call transforms every row (axis=-1) instead of a
    # Python loop of per-signal FFTs.
    batch_fft = np.fft.fft(signals, axis=-1)
    fft_freq = np.fft.fftfreq(signals.shape[-1], d=1 / sample_rate)
    fig = plt.figure()
    # Shift into ascending-frequency order so the line plot is monotonic.
    plt.plot(np.fft.fftshift(fft_freq), np.fft.fftshift(np.abs(batch_fft[0])))
    plt.title("FFT of First Signal in Batch")
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Amplitude")
    fig.savefig("batch_processing_test.png")
    plt.show()
    plt.close(fig)
def test_fft_1d(signal, sample_rate):
    """Return ``(fft_freq, fft_result)`` for a 1D *signal* sampled at *sample_rate* Hz.

    ``fft_result`` is ``np.fft.fft(signal)`` and ``fft_freq`` holds the matching
    bin frequencies in ``np.fft.fftfreq`` order (non-negative bins first, then
    negative bins).

    Despite the ``test_`` prefix this is a helper, not a test case; the
    ``__test__`` flag set right after the definition stops pytest collecting it.
    """
    fft_result = np.fft.fft(signal)
    fft_freq = np.fft.fftfreq(len(signal), d=1 / sample_rate)
    return fft_freq, fft_result


# The name matches pytest's default ``test_*`` pattern; without this flag pytest
# (including via ``from module import *`` in a test file) would try to run the
# helper and fail looking for fixtures named ``signal`` and ``sample_rate``.
test_fft_1d.__test__ = False
def validate_inverse_fft(signal):
    """Measure how well ``ifft(fft(signal))`` reproduces *signal*.

    Returns the largest absolute element-wise difference between the input and
    its FFT round trip (0 means a perfect reconstruction).
    """
    round_trip = np.fft.ifft(np.fft.fft(signal))
    return np.abs(signal - round_trip).max()
def main(argv=None):
    """Parse command-line options and run the selected FFT test(s).

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, i.e. the normal command-line behaviour.

    ``--test all`` (the default) runs every test in the order listed below.
    Unknown test names are rejected by argparse (usage message, exit status 2)
    instead of being silently reported with a zero exit status.
    """
    # Dispatch table; insertion order is the order used for "all".
    tests = {
        "precision": precision_test,
        "scaling": performance_scaling_test,
        "noisy": noisy_signal_test,
        "windowing": windowing_test,
        "2d": cosine_2d_test,
        "highfreq": high_frequency_test,
        "batch": batch_processing_test,
    }
    parser = argparse.ArgumentParser(description="FFT Testing Suite")
    parser.add_argument(
        "--test",
        type=str,
        choices=["all", *tests],
        default="all",
        help="Select test to run",
    )
    args = parser.parse_args(argv)
    selected = tests.values() if args.test == "all" else [tests[args.test]]
    for run_test in selected:
        run_test()


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement