From e427cc54cbe28a0fb120197367e6af4e77c2ef82 Mon Sep 17 00:00:00 2001 From: Eoghan O'Connell Date: Fri, 8 Nov 2024 17:34:37 +0100 Subject: [PATCH] enh: ensure ifft with padding works with 3D stack --- qpretrieve/fourier/base.py | 15 +++-- .../test_oah_from_qpimage_cupy.py | 58 ++++++++++++------- 2 files changed, 46 insertions(+), 27 deletions(-) diff --git a/qpretrieve/fourier/base.py b/qpretrieve/fourier/base.py index aa8c097..f9806e8 100644 --- a/qpretrieve/fourier/base.py +++ b/qpretrieve/fourier/base.py @@ -245,7 +245,8 @@ def filter(self, filter_name: str, filter_size: float, fft_filtered = self.fft_origin * filt_array px = int(freq_pos[0] * self.shape[-2]) py = int(freq_pos[1] * self.shape[-1]) - fft_used = np.roll(np.roll(fft_filtered, -px, axis=-2), -py, axis=-1) + fft_used = np.roll(np.roll( + fft_filtered, -px, axis=-2), -py, axis=-1) if scale_to_filter: # Determine the size of the cropping region. # We compute the "radius" of the region, so we can @@ -263,16 +264,18 @@ def filter(self, filter_name: str, filter_size: float, fft_used = fft_used[:, cslice, cslice] field = self._ifft(np.fft.ifftshift(fft_used)) - if len(self.origin.shape) != 2: - # todo: this must be corrected - self.padding = False + if self.padding: # revert padding - sx, sy = self.origin.shape + sx, sy = self.origin.shape[-2:] if scale_to_filter: sx = int(np.ceil(sx * 2 * crad / osize)) sy = int(np.ceil(sy * 2 * crad / osize)) - field = field[:sx, :sy] + if len(fft_used.shape) == 2: + field = field[:sx, :sy] + elif len(fft_used.shape) == 3: + field = field[:, :sx, :sy] + if scale_to_filter: # Scale the absolute value of the field. This does not # have any influence on the phase, but on the amplitude. diff --git a/tests/test_cupy_gpu/test_oah_from_qpimage_cupy.py b/tests/test_cupy_gpu/test_oah_from_qpimage_cupy.py index 8a8a1f2..c412485 100644 --- a/tests/test_cupy_gpu/test_oah_from_qpimage_cupy.py +++ b/tests/test_cupy_gpu/test_oah_from_qpimage_cupy.py @@ -5,6 +5,27 @@ from qpretrieve.fourier import FFTFilterCupy3D, FFTFilterCupy, FFTFilterNumpy +def test_get_field_compare_FFTFilters(hologram): + data1 = hologram + + holo1 = qpretrieve.OffAxisHologram(data1, + fft_interface=FFTFilterNumpy, + padding=False) + kwargs = dict(filter_name="disk", filter_size=1 / 3) + res1 = holo1.run_pipeline(**kwargs) + + holo1 = qpretrieve.OffAxisHologram(data1, + fft_interface=FFTFilterCupy, + padding=False) + kwargs = dict(filter_name="disk", filter_size=1 / 3) + res2 = holo1.run_pipeline(**kwargs) + + assert res1.shape == (64, 64) + assert res2.shape == (64, 64) + assert not np.all(res1 == res2) + assert np.allclose(res1, res2) + + def test_get_field_cupy3d(hologram): data1 = hologram data_rp = np.array([data1, data1, data1, data1, data1]) @@ -25,30 +46,25 @@ def test_get_field_cupy3d(hologram): assert not np.all(res1[0] == res2) - # import matplotlib.pyplot as plt - # fig, axes = plt.subplots(3, 1) - # ax1, ax2, ax3 = axes - # ax1.imshow(np.abs(res1[0])) - # ax2.imshow(np.abs(res2)) - # ax3.imshow(np.abs(res2)-np.abs(res1[0])) - # plt.show() - -def test_get_field_compare_FFTFilters(hologram): +def test_get_field_cupy3d_scale_to_filter(hologram): data1 = hologram + data_rp = np.array([data1, data1, data1, data1, data1]) - holo1 = qpretrieve.OffAxisHologram(data1, - fft_interface=FFTFilterNumpy, - padding=False) - kwargs = dict(filter_name="disk", filter_size=1 / 3) + holo1 = qpretrieve.OffAxisHologram(data_rp, + fft_interface=FFTFilterCupy3D, + padding=True) + kwargs = dict(filter_name="disk", filter_size=1 / 3, + scale_to_filter=True) res1 = holo1.run_pipeline(**kwargs) - assert res1.shape == (64, 64) - holo1 = qpretrieve.OffAxisHologram(data1, - fft_interface=FFTFilterCupy, - padding=False) - kwargs = dict(filter_name="disk", filter_size=1 / 3) - res2 = holo1.run_pipeline(**kwargs) - assert res2.shape == (64, 64) + holo2 = qpretrieve.OffAxisHologram(data1, + fft_interface=FFTFilterNumpy, + padding=True) + kwargs = dict(filter_name="disk", filter_size=1 / 3, + scale_to_filter=True) + res2 = holo2.run_pipeline(**kwargs) - assert not np.all(res1 == res2) + assert res1.shape == (5, 18, 18) + assert res2.shape == (18, 18) + assert np.allclose(res1[0], res2)