Skip to content

Commit

Permalink
tests: ensure old and new uses of fft algorithms are consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Jan 22, 2025
1 parent da0f046 commit 10c59b6
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion tests/test_fourier_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,44 @@ def test_scale_to_filter_qlsi():
assert np.allclose(phase_scaled.mean(), 0.1257080793074251, atol=1e-6)


def test_fft_dimensionality_consistency():
"""Compare using fft algorithms on 2d and 3d data."""
image_3d = np.arange(1000).reshape(10, 10, 10)
image_2d = image_3d[0].copy()

# fft with shift
fft_3d = np.fft.fftshift(np.fft.fft2(image_3d, axes=(-2, -1)),
axes=(-2, -1))
fft_2d = np.fft.fftshift(np.fft.fft2(image_2d)) # old qpretrieve
assert fft_3d.shape == (10, 10, 10)
assert fft_2d.shape == (10, 10)
assert np.allclose(fft_3d[0], fft_2d, rtol=0, atol=1e-8)

# ifftshift
fft_3d_shifted = np.fft.ifftshift(fft_3d, axes=(-2, -1))
fft_2d_shifted = np.fft.ifftshift(fft_2d) # old qpretrieve
assert fft_3d_shifted.shape == (10, 10, 10)
assert fft_2d_shifted.shape == (10, 10)
assert np.allclose(fft_3d_shifted[0], fft_2d_shifted, rtol=0, atol=1e-8)

# ifft
ifft_3d_shifted = np.fft.ifft2(fft_3d_shifted, axes=(-2, -1))
ifft_2d_shifted = np.fft.ifft2(fft_2d_shifted) # old qpretrieve
assert ifft_3d_shifted.shape == (10, 10, 10)
assert ifft_2d_shifted.shape == (10, 10)
assert np.allclose(ifft_3d_shifted[0], ifft_2d_shifted, rtol=0, atol=1e-8)

assert np.allclose(ifft_3d_shifted.real, image_3d, rtol=0, atol=1e-8)
assert np.allclose(ifft_2d_shifted.real, image_2d, rtol=0, atol=1e-8)


def test_fft_comparison_FFTFilter():
image = np.arange(1000).reshape(10, 10, 10)
ff_np = fourier.FFTFilterNumpy(image, subtract_mean=False, padding=False)
ff_tw = fourier.FFTFilterPyFFTW(image, subtract_mean=False, padding=False)
assert ff_np.fft_origin.shape == ff_tw.fft_origin.shape == (10, 10, 10)

assert np.allclose(ff_np.fft_origin, ff_np.fft_origin, rtol=0, atol=1e-8)
assert np.allclose(ff_np.fft_origin, ff_tw.fft_origin, rtol=0, atol=1e-8)
assert np.allclose(
np.fft.ifft2(np.fft.ifftshift(ff_np.fft_origin, axes=(-2, -1))).real,
np.fft.ifft2(np.fft.ifftshift(ff_tw.fft_origin, axes=(-2, -1))).real,
Expand Down

0 comments on commit 10c59b6

Please sign in to comment.