diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index d04fca2..21b78b7 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -10,7 +10,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ['3.10'] + python-version: ['3.10', '3.11'] os: [macos-latest, ubuntu-latest, windows-latest] steps: diff --git a/examples/fft_options.py b/examples/fft_options.py new file mode 100644 index 0000000..1a0821b --- /dev/null +++ b/examples/fft_options.py @@ -0,0 +1,51 @@ +"""Fourier Transform options available + +This example visualizes the different backends and packages available to the +user for performing Fourier transforms. +""" +import matplotlib.pylab as plt +import numpy as np +import qpretrieve +from skimage.restoration import unwrap_phase + +# load the experimental data +edata = np.load("./data/hologram_cell.npz") + +# get the available fft interfaces +interfaces_available = qpretrieve.fourier.get_available_interfaces() + +prange = (-1, 5) +frange = (0, 12) + +results = {} + +for fft_interface in interfaces_available: + holo = qpretrieve.OffAxisHologram(data=edata["data"], + fft_interface=fft_interface) + holo.run_pipeline(filter_name="disk", filter_size=1/2) + bg = qpretrieve.OffAxisHologram(data=edata["bg_data"]) + bg.process_like(holo) + phase = unwrap_phase(holo.phase - bg.phase) + mask = np.log(1 + np.abs(holo.fft_filtered)) + results[fft_interface.__name__] = mask, phase + +num_filters = len(results) + +# plot the properties of `qpi` +fig = plt.figure(figsize=(8, 22)) + +for row, name in enumerate(results): + ax1 = plt.subplot(num_filters, 2, 2*row+1) + ax1.set_title(name, loc="left") + ax1.imshow(results[name][0], vmin=frange[0], vmax=frange[1]) + + ax2 = plt.subplot(num_filters, 2, 2*row+2) + map2 = ax2.imshow(results[name][1], cmap="coolwarm", + vmin=prange[0], vmax=prange[1]) + plt.colorbar(map2, ax=ax2, fraction=.046, pad=0.02, label="phase [rad]") + + ax1.axis("off") + ax2.axis("off") + +plt.tight_layout() +plt.show() diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..6ccafc3 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1 @@ +matplotlib diff --git a/qpretrieve/filter.py b/qpretrieve/filter.py index aa8302c..43de0cc 100644 --- a/qpretrieve/filter.py +++ b/qpretrieve/filter.py @@ -104,8 +104,8 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): # TODO: avoid the np.roll, instead use the indices directly alpha = 0.1 rsize = int(min(fx.size, fy.size) * filter_size) * 2 - tukey_window_x = signal.tukey(rsize, alpha=alpha).reshape(-1, 1) - tukey_window_y = signal.tukey(rsize, alpha=alpha).reshape(1, -1) + tukey_window_x = signal.windows.tukey(rsize, alpha=alpha).reshape(-1, 1) + tukey_window_y = signal.windows.tukey(rsize, alpha=alpha).reshape(1, -1) tukey = tukey_window_x * tukey_window_y base = np.zeros(fft_shape) s1 = (np.array(fft_shape) - rsize) // 2 diff --git a/qpretrieve/fourier/__init__.py b/qpretrieve/fourier/__init__.py index 62d479c..b9252ef 100644 --- a/qpretrieve/fourier/__init__.py +++ b/qpretrieve/fourier/__init__.py @@ -2,6 +2,7 @@ import warnings from .ff_numpy import FFTFilterNumpy +from .ff_scipy import FFTFilterScipy try: from .ff_pyfftw import FFTFilterPyFFTW @@ -11,6 +12,20 @@ PREFERRED_INTERFACE = None +def get_available_interfaces(): + """Return a list of available FFT algorithms""" + interfaces = [ + FFTFilterPyFFTW, + FFTFilterNumpy, + FFTFilterScipy, + ] + interfaces_available = [] + for interface in interfaces: + if interface is not None and interface.is_available: + interfaces_available.append(interface) + return interfaces_available + + def get_best_interface(): """Return the fastest refocusing interface available diff --git a/qpretrieve/fourier/base.py b/qpretrieve/fourier/base.py index 0c430ac..36e758b 100644 --- a/qpretrieve/fourier/base.py +++ b/qpretrieve/fourier/base.py @@ -70,8 +70,10 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True): else: # convert integer-arrays to floating point arrays dtype = float + if not copy: + copy = None # numpy v2.x behaviour requires asarray with copy=False data_ed = np.array(data, dtype=dtype, copy=copy) - #: original data (with subtracted mean) +#: original data (with subtracted mean) self.origin = data_ed #: whether padding is enabled self.padding = padding diff --git a/qpretrieve/fourier/ff_scipy.py b/qpretrieve/fourier/ff_scipy.py new file mode 100644 index 0000000..81736a2 --- /dev/null +++ b/qpretrieve/fourier/ff_scipy.py @@ -0,0 +1,30 @@ +import scipy as sp + + +from .base import FFTFilter + + +class FFTFilterScipy(FFTFilter): + """Wraps the scipy Fourier transform + """ + # always available, because scipy is a dependency + is_available = True + + def _init_fft(self, data): + """Perform initial Fourier transform of the input data + + Parameters + ---------- + data: 2d real-valued np.ndarray + Input field to be refocused + + Returns + ------- + fft_fdata: 2d complex-valued ndarray + Fourier transform `data` + """ + return sp.fft.fft2(data) + + def _ifft(self, data): + """Perform inverse Fourier transform""" + return sp.fft.ifft2(data) diff --git a/qpretrieve/interfere/base.py b/qpretrieve/interfere/base.py index f1a637a..116fe6c 100644 --- a/qpretrieve/interfere/base.py +++ b/qpretrieve/interfere/base.py @@ -2,7 +2,8 @@ import numpy as np -from ..fourier import get_best_interface +from ..fourier import get_best_interface, get_available_interfaces +from ..fourier.base import FFTFilter class BaseInterferogram(ABC): @@ -15,7 +16,8 @@ class BaseInterferogram(ABC): "invert_phase": False, } - def __init__(self, data, subtract_mean=True, padding=2, copy=True, + def __init__(self, data, fft_interface: FFTFilter = None, + subtract_mean=True, padding=2, copy=True, **pipeline_kws): """ Parameters @@ -38,12 +40,22 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True, Any additional keyword arguments for :func:`run_pipeline` as defined in :const:`default_pipeline_kws`. """ - ff_iface = get_best_interface() + if fft_interface == 'auto' or fft_interface is None: + self.ff_iface = get_best_interface() + else: + if fft_interface in get_available_interfaces(): + self.ff_iface = fft_interface + else: + raise ValueError( + f"User-chosen FFT Interface '{fft_interface}' is not available. " + f"The available interfaces are: {get_available_interfaces()}.\n" + f"You can use `fft_interface='auto'` to get the best " + f"available interface.") if len(data.shape) == 3: # take the first slice (we have alpha or RGB information) data = data[:, :, 0] #: qpretrieve Fourier transform interface class - self.fft = ff_iface(data=data, + self.fft = self.ff_iface(data=data, subtract_mean=subtract_mean, padding=padding, copy=copy) diff --git a/tests/test_fourier_scipy.py b/tests/test_fourier_scipy.py new file mode 100644 index 0000000..2ccc379 --- /dev/null +++ b/tests/test_fourier_scipy.py @@ -0,0 +1,15 @@ +import numpy as np +import scipy as sp + +from qpretrieve import fourier + + +def test_fft_correct(): + image = np.arange(100).reshape(10, 10) + ff = fourier.FFTFilterScipy(image, subtract_mean=False, padding=False) + assert np.allclose( + sp.fft.ifft2(np.fft.ifftshift(ff.fft_origin)).real, + image, + rtol=0, + atol=1e-8 + ) diff --git a/tests/test_interfere_base.py b/tests/test_interfere_base.py new file mode 100644 index 0000000..2fb622f --- /dev/null +++ b/tests/test_interfere_base.py @@ -0,0 +1,12 @@ +import numpy as np + +import qpretrieve + + +def test_interfere_base_best_interface(): + edata = np.load("./data/hologram_cell.npz") + + holo = qpretrieve.OffAxisHologram(data=edata["data"]) + assert holo.ff_iface.is_available + assert issubclass(holo.ff_iface, qpretrieve.fourier.base.FFTFilter) + assert issubclass(holo.ff_iface, qpretrieve.fourier.ff_numpy.FFTFilterNumpy)