Skip to content

Commit

Permalink
enh: ensure ifft with padding works with 3D stack
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Nov 8, 2024
1 parent d2e2f61 commit e427cc5
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 27 deletions.
15 changes: 9 additions & 6 deletions qpretrieve/fourier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ def filter(self, filter_name: str, filter_size: float,
fft_filtered = self.fft_origin * filt_array
px = int(freq_pos[0] * self.shape[-2])
py = int(freq_pos[1] * self.shape[-1])
fft_used = np.roll(np.roll(fft_filtered, -px, axis=-2), -py, axis=-1)
fft_used = np.roll(np.roll(
fft_filtered, -px, axis=-2), -py, axis=-1)
if scale_to_filter:
# Determine the size of the cropping region.
# We compute the "radius" of the region, so we can
Expand All @@ -263,16 +264,18 @@ def filter(self, filter_name: str, filter_size: float,
fft_used = fft_used[:, cslice, cslice]

field = self._ifft(np.fft.ifftshift(fft_used))
if len(self.origin.shape) != 2:
# todo: this must be corrected
self.padding = False

if self.padding:
# revert padding
sx, sy = self.origin.shape
sx, sy = self.origin.shape[-2:]
if scale_to_filter:
sx = int(np.ceil(sx * 2 * crad / osize))
sy = int(np.ceil(sy * 2 * crad / osize))
field = field[:sx, :sy]
if len(fft_used.shape) == 2:
field = field[:sx, :sy]
elif len(fft_used.shape) == 3:
field = field[:, :sx, :sy]

if scale_to_filter:
# Scale the absolute value of the field. This does not
# have any influence on the phase, but on the amplitude.
Expand Down
58 changes: 37 additions & 21 deletions tests/test_cupy_gpu/test_oah_from_qpimage_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,27 @@
from qpretrieve.fourier import FFTFilterCupy3D, FFTFilterCupy, FFTFilterNumpy


def test_get_field_compare_FFTFilters(hologram):
data1 = hologram

holo1 = qpretrieve.OffAxisHologram(data1,
fft_interface=FFTFilterNumpy,
padding=False)
kwargs = dict(filter_name="disk", filter_size=1 / 3)
res1 = holo1.run_pipeline(**kwargs)

holo1 = qpretrieve.OffAxisHologram(data1,
fft_interface=FFTFilterCupy,
padding=False)
kwargs = dict(filter_name="disk", filter_size=1 / 3)
res2 = holo1.run_pipeline(**kwargs)

assert res1.shape == (64, 64)
assert res2.shape == (64, 64)
assert not np.all(res1 == res2)
assert np.allclose(res1, res2)


def test_get_field_cupy3d(hologram):
data1 = hologram
data_rp = np.array([data1, data1, data1, data1, data1])
Expand All @@ -25,30 +46,25 @@ def test_get_field_cupy3d(hologram):

assert not np.all(res1[0] == res2)

# import matplotlib.pyplot as plt
# fig, axes = plt.subplots(3, 1)
# ax1, ax2, ax3 = axes
# ax1.imshow(np.abs(res1[0]))
# ax2.imshow(np.abs(res2))
# ax3.imshow(np.abs(res2)-np.abs(res1[0]))
# plt.show()


def test_get_field_compare_FFTFilters(hologram):
def test_get_field_cupy3d_scale_to_filter(hologram):
data1 = hologram
data_rp = np.array([data1, data1, data1, data1, data1])

holo1 = qpretrieve.OffAxisHologram(data1,
fft_interface=FFTFilterNumpy,
padding=False)
kwargs = dict(filter_name="disk", filter_size=1 / 3)
holo1 = qpretrieve.OffAxisHologram(data_rp,
fft_interface=FFTFilterCupy3D,
padding=True)
kwargs = dict(filter_name="disk", filter_size=1 / 3,
scale_to_filter=True)
res1 = holo1.run_pipeline(**kwargs)
assert res1.shape == (64, 64)

holo1 = qpretrieve.OffAxisHologram(data1,
fft_interface=FFTFilterCupy,
padding=False)
kwargs = dict(filter_name="disk", filter_size=1 / 3)
res2 = holo1.run_pipeline(**kwargs)
assert res2.shape == (64, 64)
holo2 = qpretrieve.OffAxisHologram(data1,
fft_interface=FFTFilterNumpy,
padding=True)
kwargs = dict(filter_name="disk", filter_size=1 / 3,
scale_to_filter=True)
res2 = holo2.run_pipeline(**kwargs)

assert not np.all(res1 == res2)
assert res1.shape == (5, 18, 18)
assert res2.shape == (18, 18)
assert np.allclose(res1[0], res2)

0 comments on commit e427cc5

Please sign in to comment.