Skip to content

Commit

Permalink
enh: add scipy fft interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Aug 26, 2024
1 parent 4446d0d commit f09a64f
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.10']
python-version: ['3.10', '3.11']
os: [macos-latest, ubuntu-latest, windows-latest]

steps:
Expand Down
51 changes: 51 additions & 0 deletions examples/fft_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Fourier Transform options available
This example visualizes the different backends and packages available to the
user for performing Fourier transforms.
"""
import matplotlib.pylab as plt
import numpy as np
import qpretrieve
from skimage.restoration import unwrap_phase

# load the experimental data
edata = np.load("./data/hologram_cell.npz")

# get the available fft interfaces
interfaces_available = qpretrieve.fourier.get_available_interfaces()

prange = (-1, 5)
frange = (0, 12)

results = {}

for fft_interface in interfaces_available:
holo = qpretrieve.OffAxisHologram(data=edata["data"],
fft_interface=fft_interface)
holo.run_pipeline(filter_name="disk", filter_size=1/2)
bg = qpretrieve.OffAxisHologram(data=edata["bg_data"])
bg.process_like(holo)
phase = unwrap_phase(holo.phase - bg.phase)
mask = np.log(1 + np.abs(holo.fft_filtered))
results[fft_interface.__name__] = mask, phase

num_filters = len(results)

# plot the properties of `qpi`
fig = plt.figure(figsize=(8, 22))

for row, name in enumerate(results):
ax1 = plt.subplot(num_filters, 2, 2*row+1)
ax1.set_title(name, loc="left")
ax1.imshow(results[name][0], vmin=frange[0], vmax=frange[1])

ax2 = plt.subplot(num_filters, 2, 2*row+2)
map2 = ax2.imshow(results[name][1], cmap="coolwarm",
vmin=prange[0], vmax=prange[1])
plt.colorbar(map2, ax=ax2, fraction=.046, pad=0.02, label="phase [rad]")

ax1.axis("off")
ax2.axis("off")

plt.tight_layout()
plt.show()
1 change: 1 addition & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
matplotlib
4 changes: 2 additions & 2 deletions qpretrieve/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape):
# TODO: avoid the np.roll, instead use the indices directly
alpha = 0.1
rsize = int(min(fx.size, fy.size) * filter_size) * 2
tukey_window_x = signal.tukey(rsize, alpha=alpha).reshape(-1, 1)
tukey_window_y = signal.tukey(rsize, alpha=alpha).reshape(1, -1)
tukey_window_x = signal.windows.tukey(rsize, alpha=alpha).reshape(-1, 1)
tukey_window_y = signal.windows.tukey(rsize, alpha=alpha).reshape(1, -1)
tukey = tukey_window_x * tukey_window_y
base = np.zeros(fft_shape)
s1 = (np.array(fft_shape) - rsize) // 2
Expand Down
15 changes: 15 additions & 0 deletions qpretrieve/fourier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings

from .ff_numpy import FFTFilterNumpy
from .ff_scipy import FFTFilterScipy

try:
from .ff_pyfftw import FFTFilterPyFFTW
Expand All @@ -11,6 +12,20 @@
PREFERRED_INTERFACE = None


def get_available_interfaces():
"""Return a list of available FFT algorithms"""
interfaces = [
FFTFilterPyFFTW,
FFTFilterNumpy,
FFTFilterScipy,
]
interfaces_available = []
for interface in interfaces:
if interface is not None and interface.is_available:
interfaces_available.append(interface)
return interfaces_available


def get_best_interface():
"""Return the fastest refocusing interface available
Expand Down
4 changes: 3 additions & 1 deletion qpretrieve/fourier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True):
else:
# convert integer-arrays to floating point arrays
dtype = float
if not copy:
copy = None # numpy v2.x behaviour requires asarray with copy=False
data_ed = np.array(data, dtype=dtype, copy=copy)
#: original data (with subtracted mean)
#: original data (with subtracted mean)
self.origin = data_ed
#: whether padding is enabled
self.padding = padding
Expand Down
30 changes: 30 additions & 0 deletions qpretrieve/fourier/ff_scipy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import scipy as sp


from .base import FFTFilter


class FFTFilterScipy(FFTFilter):
"""Wraps the scipy Fourier transform
"""
# always available, because scipy is a dependency
is_available = True

def _init_fft(self, data):
"""Perform initial Fourier transform of the input data
Parameters
----------
data: 2d real-valued np.ndarray
Input field to be refocused
Returns
-------
fft_fdata: 2d complex-valued ndarray
Fourier transform `data`
"""
return sp.fft.fft2(data)

def _ifft(self, data):
"""Perform inverse Fourier transform"""
return sp.fft.ifft2(data)
20 changes: 16 additions & 4 deletions qpretrieve/interfere/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import numpy as np

from ..fourier import get_best_interface
from ..fourier import get_best_interface, get_available_interfaces
from ..fourier.base import FFTFilter


class BaseInterferogram(ABC):
Expand All @@ -15,7 +16,8 @@ class BaseInterferogram(ABC):
"invert_phase": False,
}

def __init__(self, data, subtract_mean=True, padding=2, copy=True,
def __init__(self, data, fft_interface: FFTFilter = None,
subtract_mean=True, padding=2, copy=True,
**pipeline_kws):
"""
Parameters
Expand All @@ -38,12 +40,22 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True,
Any additional keyword arguments for :func:`run_pipeline`
as defined in :const:`default_pipeline_kws`.
"""
ff_iface = get_best_interface()
if fft_interface == 'auto' or fft_interface is None:
self.ff_iface = get_best_interface()
else:
if fft_interface in get_available_interfaces():
self.ff_iface = fft_interface
else:
raise ValueError(
f"User-chosen FFT Interface '{fft_interface}' is not available. "
f"The available interfaces are: {get_available_interfaces()}.\n"
f"You can use `fft_interface='auto'` to get the best "
f"available interface.")
if len(data.shape) == 3:
# take the first slice (we have alpha or RGB information)
data = data[:, :, 0]
#: qpretrieve Fourier transform interface class
self.fft = ff_iface(data=data,
self.fft = self.ff_iface(data=data,
subtract_mean=subtract_mean,
padding=padding,
copy=copy)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_fourier_scipy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np
import scipy as sp

from qpretrieve import fourier


def test_fft_correct():
image = np.arange(100).reshape(10, 10)
ff = fourier.FFTFilterScipy(image, subtract_mean=False, padding=False)
assert np.allclose(
sp.fft.ifft2(np.fft.ifftshift(ff.fft_origin)).real,
image,
rtol=0,
atol=1e-8
)
12 changes: 12 additions & 0 deletions tests/test_interfere_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import numpy as np

import qpretrieve


def test_interfere_base_best_interface():
edata = np.load("./data/hologram_cell.npz")

holo = qpretrieve.OffAxisHologram(data=edata["data"])
assert holo.ff_iface.is_available
assert issubclass(holo.ff_iface, qpretrieve.fourier.base.FFTFilter)
assert issubclass(holo.ff_iface, qpretrieve.fourier.ff_numpy.FFTFilterNumpy)

0 comments on commit f09a64f

Please sign in to comment.