diff --git a/tests/test_fourier_base.py b/tests/test_fourier_base.py index c8a092c..d73d04c 100644 --- a/tests/test_fourier_base.py +++ b/tests/test_fourier_base.py @@ -183,13 +183,44 @@ def test_scale_to_filter_qlsi(): assert np.allclose(phase_scaled.mean(), 0.1257080793074251, atol=1e-6) +def test_fft_dimensionality_consistency(): + """Compare using fft algorithms on 2d and 3d data.""" + image_3d = np.arange(1000).reshape(10, 10, 10) + image_2d = image_3d[0].copy() + + # fft with shift + fft_3d = np.fft.fftshift(np.fft.fft2(image_3d, axes=(-2, -1)), + axes=(-2, -1)) + fft_2d = np.fft.fftshift(np.fft.fft2(image_2d)) # old qpretrieve + assert fft_3d.shape == (10, 10, 10) + assert fft_2d.shape == (10, 10) + assert np.allclose(fft_3d[0], fft_2d, rtol=0, atol=1e-8) + + # ifftshift + fft_3d_shifted = np.fft.ifftshift(fft_3d, axes=(-2, -1)) + fft_2d_shifted = np.fft.ifftshift(fft_2d) # old qpretrieve + assert fft_3d_shifted.shape == (10, 10, 10) + assert fft_2d_shifted.shape == (10, 10) + assert np.allclose(fft_3d_shifted[0], fft_2d_shifted, rtol=0, atol=1e-8) + + # ifft + ifft_3d_shifted = np.fft.ifft2(fft_3d_shifted, axes=(-2, -1)) + ifft_2d_shifted = np.fft.ifft2(fft_2d_shifted) # old qpretrieve + assert ifft_3d_shifted.shape == (10, 10, 10) + assert ifft_2d_shifted.shape == (10, 10) + assert np.allclose(ifft_3d_shifted[0], ifft_2d_shifted, rtol=0, atol=1e-8) + + assert np.allclose(ifft_3d_shifted.real, image_3d, rtol=0, atol=1e-8) + assert np.allclose(ifft_2d_shifted.real, image_2d, rtol=0, atol=1e-8) + + def test_fft_comparison_FFTFilter(): image = np.arange(1000).reshape(10, 10, 10) ff_np = fourier.FFTFilterNumpy(image, subtract_mean=False, padding=False) ff_tw = fourier.FFTFilterPyFFTW(image, subtract_mean=False, padding=False) assert ff_np.fft_origin.shape == ff_tw.fft_origin.shape == (10, 10, 10) - assert np.allclose(ff_np.fft_origin, ff_np.fft_origin, rtol=0, atol=1e-8) + assert np.allclose(ff_np.fft_origin, ff_tw.fft_origin, rtol=0, atol=1e-8) assert np.allclose( np.fft.ifft2(np.fft.ifftshift(ff_np.fft_origin, axes=(-2, -1))).real, np.fft.ifft2(np.fft.ifftshift(ff_tw.fft_origin, axes=(-2, -1))).real,