diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index d04fca2..12b1bdb 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -10,7 +10,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ['3.10'] + python-version: ['3.10', '3.11'] os: [macos-latest, ubuntu-latest, windows-latest] steps: @@ -39,7 +39,8 @@ jobs: pip freeze - name: Test with pytest run: | - coverage run --source=qpretrieve -m pytest tests + # ignore the cupy imports, as we don't have a gpu-enabled pipeline setup + coverage run --source=qpretrieve -m pytest tests --ignore=tests/test_cupy_gpu - name: Lint with flake8 run: | flake8 . diff --git a/docs/README.md b/docs/README.md index 694748f..e8441e1 100644 --- a/docs/README.md +++ b/docs/README.md @@ -6,6 +6,7 @@ To install the requirements for building the documentation, run To compile the documentation, run + cd docs sphinx-build . _build diff --git a/docs/requirements.txt b/docs/requirements.txt index 15408de..87e4bc3 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,3 @@ -sphinx==4.3.0 -sphinxcontrib.bibtex>=2.0 -sphinx_rtd_theme==1.0 - +sphinx +sphinxcontrib.bibtex +sphinx_rtd_theme diff --git a/docs/sec_code_reference.rst b/docs/sec_code_reference.rst index dcf88ec..422e3dc 100644 --- a/docs/sec_code_reference.rst +++ b/docs/sec_code_reference.rst @@ -23,7 +23,6 @@ Fourier transform methods ========================= .. _sec_code_fourier_numpy: - Numpy ----- .. automodule:: qpretrieve.fourier.ff_numpy @@ -31,13 +30,13 @@ Numpy :inherited-members: .. _sec_code_fourier_pyfftw: - PyFFTW ------ .. automodule:: qpretrieve.fourier.ff_pyfftw :members: :inherited-members: + .. _sec_code_ifer: Interference image analysis diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..6ccafc3 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1 @@ +matplotlib diff --git a/qpretrieve/data_input.py b/qpretrieve/data_input.py new file mode 100644 index 0000000..7eef4d1 --- /dev/null +++ b/qpretrieve/data_input.py @@ -0,0 +1,84 @@ +import numpy as np +import warnings + +allowed_data_formats = [ + "rgb", + "rgba", + "3d", + "2d", +] + + +def check_data_input_format(data_input): + """Figure out what data input is provided.""" + if len(data_input.shape) == 3: + if data_input.shape[-1] in [1, 2, 3]: + # take the first slice (we have alpha or RGB information) + data, data_format = _convert_rgb_to_3d(data_input) + elif data_input.shape[-1] == 4: + # take the first slice (we have alpha or RGB information) + data, data_format = _convert_rgba_to_3d(data_input) + else: + # we have a 3D image stack (z, y, x) + data, data_format = data_input, "3d" + elif len(data_input.shape) == 2: + # we have a 2D image (y, x). convert to (z, y, z) + data, data_format = _convert_2d_to_3d(data_input) + else: + raise ValueError(f"data_input shape must be 2d or 3d, " + f"got shape {data_input.shape}.") + return data.copy(), data_format + + +def revert_to_data_input_format(data_format, field): + """Convert the outputted field shape to the original input shape, + for user convenience.""" + assert data_format in allowed_data_formats + assert len(field.shape) == 3, "the field should be 3d" + field = field.copy() + if data_format == "rgb": + field = _revert_3d_to_rgb(field) + elif data_format == "rgba": + field = _revert_3d_to_rgba(field) + elif data_format == "3d": + field = field + else: + field = _revert_3d_to_2d(field) + return field + + +def _convert_rgb_to_3d(data_input): + data = data_input[:, :, 0] + data = data[np.newaxis, :, :] + data_format = "rgb" + warnings.warn(f"Format of input data detected as {data_format}. " + f"The first channel will be used for processing") + return data, data_format + + +def _convert_rgba_to_3d(data_input): + data, _ = _convert_rgb_to_3d(data_input) + data_format = "rgba" + return data, data_format + + +def _convert_2d_to_3d(data_input): + data = data_input[np.newaxis, :, :] + data_format = "2d" + return data, data_format + + +def _revert_3d_to_rgb(data_input): + data = data_input[0] + data = np.dstack((data, data, data)) + return data + + +def _revert_3d_to_rgba(data_input): + data = data_input[0] + data = np.dstack((data, data, data, np.ones_like(data))) + return data + + +def _revert_3d_to_2d(data_input): + return data_input[0] diff --git a/qpretrieve/filter.py b/qpretrieve/filter.py index aa8302c..d14c0eb 100644 --- a/qpretrieve/filter.py +++ b/qpretrieve/filter.py @@ -38,9 +38,9 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): and must be between 0 and `max(fft_shape)/2` freq_pos: tuple of floats The position of the filter in frequency coordinates as - returned by :func:`nunpy.fft.fftfreq`. + returned by :func:`numpy.fft.fftfreq`. fft_shape: tuple of int - The shape of the Fourier transformed image for which the + The shape of the Fourier transformed image (2d) for which the filter will be applied. The shape must be squared (two identical integers). @@ -104,8 +104,10 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): # TODO: avoid the np.roll, instead use the indices directly alpha = 0.1 rsize = int(min(fx.size, fy.size) * filter_size) * 2 - tukey_window_x = signal.tukey(rsize, alpha=alpha).reshape(-1, 1) - tukey_window_y = signal.tukey(rsize, alpha=alpha).reshape(1, -1) + tukey_window_x = signal.windows.tukey( + rsize, alpha=alpha).reshape(-1, 1) + tukey_window_y = signal.windows.tukey( + rsize, alpha=alpha).reshape(1, -1) tukey = tukey_window_x * tukey_window_y base = np.zeros(fft_shape) s1 = (np.array(fft_shape) - rsize) // 2 diff --git a/qpretrieve/fourier/__init__.py b/qpretrieve/fourier/__init__.py index 62d479c..bfb4143 100644 --- a/qpretrieve/fourier/__init__.py +++ b/qpretrieve/fourier/__init__.py @@ -11,6 +11,19 @@ PREFERRED_INTERFACE = None +def get_available_interfaces(): + """Return a list of available FFT algorithms""" + interfaces = [ + FFTFilterPyFFTW, + FFTFilterNumpy, + ] + interfaces_available = [] + for interface in interfaces: + if interface is not None and interface.is_available: + interfaces_available.append(interface) + return interfaces_available + + def get_best_interface(): """Return the fastest refocusing interface available diff --git a/qpretrieve/fourier/base.py b/qpretrieve/fourier/base.py index 0c430ac..e6a13c2 100644 --- a/qpretrieve/fourier/base.py +++ b/qpretrieve/fourier/base.py @@ -6,6 +6,8 @@ import numpy as np from .. import filter +from ..utils import padding_3d, mean_3d +from ..data_input import check_data_input_format class FFTCache: @@ -35,12 +37,19 @@ def cleanup(key): class FFTFilter(ABC): - def __init__(self, data, subtract_mean=True, padding=2, copy=True): + def __init__(self, + data: np.ndarray, + subtract_mean: bool = True, + padding: int = 2, + copy: bool = True): r""" Parameters ---------- - data: 2d real-valued np.ndarray - The experimental input image + data + The experimental input real-valued image. Allowed input shapes are: + - 2d (y, x) + - 3d (z, y, x) + - 3d rgb (y, x, 3) or rgba (y, x, 4) subtract_mean: bool If True, subtract the mean of `data` before performing the Fourier transform. This setting is recommended as it @@ -70,9 +79,15 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True): else: # convert integer-arrays to floating point arrays dtype = float + if not copy: + # numpy v2.x behaviour requires asarray with copy=False + copy = None data_ed = np.array(data, dtype=dtype, copy=copy) + # figure out what type of data we have + data_ed, self.data_format = check_data_input_format(data_ed) #: original data (with subtracted mean) self.origin = data_ed + # for `subtract_mean` and `padding`, we could use `np.atleast_3d` #: whether padding is enabled self.padding = padding #: whether the mean was subtracted @@ -81,14 +96,13 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True): # remove contributions of the central band # (this affects more than one pixel in the FFT # because of zero-padding) - data_ed -= data_ed.mean() + data_ed = mean_3d(data_ed) if padding: # zero padding size is next order of 2 logfact = np.log(padding * max(data_ed.shape)) order = int(2 ** np.ceil(logfact / np.log(2))) - # this is faster than np.pad - datapad = np.zeros((order, order), dtype=dtype) - datapad[:data_ed.shape[0], :data_ed.shape[1]] = data_ed + + datapad = padding_3d(data_ed, order, dtype) #: padded input data self.origin_padded = datapad data_ed = datapad @@ -175,7 +189,7 @@ def filter(self, filter_name: str, filter_size: float, and must be between 0 and `max(fft_shape)/2` freq_pos: tuple of floats The position of the filter in frequency coordinates as - returned by :func:`nunpy.fft.fftfreq`. + returned by :func:`numpy.fft.fftfreq`. scale_to_filter: bool or float Crop the image in Fourier space after applying the filter, effectively removing surplus (zero-padding) data and @@ -220,36 +234,41 @@ def filter(self, filter_name: str, filter_size: float, filter_name=filter_name, filter_size=filter_size, freq_pos=freq_pos, - fft_shape=self.fft_origin.shape) + # only take shape of a single fft + fft_shape=self.fft_origin.shape[-2:]) fft_filtered = self.fft_origin * filt_array - px = int(freq_pos[0] * self.shape[0]) - py = int(freq_pos[1] * self.shape[1]) - fft_used = np.roll(np.roll(fft_filtered, -px, axis=0), -py, axis=1) + px = int(freq_pos[0] * self.shape[-2]) + py = int(freq_pos[1] * self.shape[-1]) + fft_used = np.roll(np.roll( + fft_filtered, -px, axis=-2), -py, axis=-1) if scale_to_filter: # Determine the size of the cropping region. # We compute the "radius" of the region, so we can # crop the data left and right from the center of the # Fourier domain. - osize = fft_filtered.shape[0] # square shaped + osize = fft_filtered.shape[-2] # square shaped crad = int(np.ceil(filter_size * osize * scale_to_filter)) ccent = osize // 2 cslice = slice(ccent - crad, ccent + crad) # We now have the interesting peak already shifted to # the first entry of our array in `shifted`. - fft_used = fft_used[cslice, cslice] + fft_used = fft_used[:, cslice, cslice] field = self._ifft(np.fft.ifftshift(fft_used)) + if self.padding: # revert padding - sx, sy = self.origin.shape + sx, sy = self.origin.shape[-2:] if scale_to_filter: sx = int(np.ceil(sx * 2 * crad / osize)) sy = int(np.ceil(sy * 2 * crad / osize)) - field = field[:sx, :sy] + + field = field[:, :sx, :sy] + if scale_to_filter: # Scale the absolute value of the field. This does not # have any influence on the phase, but on the amplitude. - field *= (2 * crad / osize)**2 + field *= (2 * crad / osize) ** 2 # Add FFT to cache # (The cache will only be cleared if this instance is deleted) FFTCache.add_item(weakref_key, self.fft_origin, diff --git a/qpretrieve/interfere/base.py b/qpretrieve/interfere/base.py index f1a637a..c5f3ec0 100644 --- a/qpretrieve/interfere/base.py +++ b/qpretrieve/interfere/base.py @@ -2,7 +2,8 @@ import numpy as np -from ..fourier import get_best_interface +from ..fourier import get_best_interface, get_available_interfaces +from ..fourier.base import FFTFilter class BaseInterferogram(ABC): @@ -15,11 +16,19 @@ class BaseInterferogram(ABC): "invert_phase": False, } - def __init__(self, data, subtract_mean=True, padding=2, copy=True, + def __init__(self, data, fft_interface: FFTFilter = None, + subtract_mean=True, padding=2, copy=True, **pipeline_kws): """ Parameters ---------- + fft_interface: FFTFilter + A Fourier transform interface. + See :func:`qpretrieve.fourier.get_available_interfaces` + to get a list of implemented interfaces. + Default is None, which will use + :func:`qpretrieve.fourier.get_best_interface`. This is in line + with old behaviour. subtract_mean: bool If True, remove the mean of the hologram before performing the Fourier transform. This setting is recommended as it @@ -38,15 +47,24 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True, Any additional keyword arguments for :func:`run_pipeline` as defined in :const:`default_pipeline_kws`. """ - ff_iface = get_best_interface() - if len(data.shape) == 3: - # take the first slice (we have alpha or RGB information) - data = data[:, :, 0] + if fft_interface == 'auto' or fft_interface is None: + self.ff_iface = get_best_interface() + else: + if fft_interface in get_available_interfaces(): + self.ff_iface = fft_interface + else: + raise ValueError( + f"User-chosen FFT Interface '{fft_interface}' is not " + f"available. The available interfaces are: " + f"{get_available_interfaces()}.\n" + f"You can use `fft_interface='auto'` to get the best " + f"available interface.") + #: qpretrieve Fourier transform interface class - self.fft = ff_iface(data=data, - subtract_mean=subtract_mean, - padding=padding, - copy=copy) + self.fft = self.ff_iface(data=data, + subtract_mean=subtract_mean, + padding=padding, + copy=copy) #: originally computed Fourier transform self.fft_origin = self.fft.fft_origin #: filtered Fourier data from last run of `run_pipeline` @@ -94,18 +112,18 @@ def compute_filter_size(self, filter_size, filter_size_interpretation, raise ValueError("For sideband distance interpretation, " "`filter_size` must be between 0 and 1; " f"got '{filter_size}'!") - fsize = np.sqrt(np.sum(np.array(sideband_freq)**2)) * filter_size + fsize = np.sqrt(np.sum(np.array(sideband_freq) ** 2)) * filter_size elif filter_size_interpretation == "frequency index": # filter size given in Fourier index (number of Fourier pixels) # The user probably does not know that we are padding in # Fourier space, so we use the unpadded size and translate it. - if filter_size <= 0 or filter_size >= self.fft.shape[0] / 2: + if filter_size <= 0 or filter_size >= self.fft.shape[-2] / 2: raise ValueError("For frequency index interpretation, " + "`filter_size` must be between 0 and " - + f"{self.fft.shape[0] / 2}, got " + + f"{self.fft.shape[-2] / 2}, got " + f"'{filter_size}'!") # convert to frequencies (compatible with fx and fy) - fsize = filter_size / self.fft.shape[0] + fsize = filter_size / self.fft.shape[-2] else: raise ValueError("Invalid value for `filter_size_interpretation`: " + f"'{filter_size_interpretation}'") diff --git a/qpretrieve/interfere/if_oah.py b/qpretrieve/interfere/if_oah.py index 73bd4bd..707d253 100644 --- a/qpretrieve/interfere/if_oah.py +++ b/qpretrieve/interfere/if_oah.py @@ -1,6 +1,7 @@ import numpy as np from .base import BaseInterferogram +from ..data_input import revert_to_data_input_format class OffAxisHologram(BaseInterferogram): @@ -73,7 +74,7 @@ def run_pipeline(self, **pipeline_kws): if pipeline_kws["sideband_freq"] is None: pipeline_kws["sideband_freq"] = find_peak_cosine( - self.fft.fft_origin) + self.fft.fft_origin[0]) # convert filter_size to frequency coordinates fsize = self.compute_filter_size( @@ -92,6 +93,7 @@ def run_pipeline(self, **pipeline_kws): if pipeline_kws["invert_phase"]: field.imag *= -1 + field = revert_to_data_input_format(self.fft.data_format, field) self._field = field self._phase = None self._amplitude = None @@ -101,7 +103,7 @@ def run_pipeline(self, **pipeline_kws): def find_peak_cosine(ft_data, copy=True): - """Find the side band position of a regular off-axis hologram + """Find the side band position of a 2d regular off-axis hologram The Fourier transform of a cosine function (known as the striped fringe pattern in off-axis holography) results in diff --git a/qpretrieve/interfere/if_qlsi.py b/qpretrieve/interfere/if_qlsi.py index 38f2c39..2b107ec 100644 --- a/qpretrieve/interfere/if_qlsi.py +++ b/qpretrieve/interfere/if_qlsi.py @@ -47,7 +47,7 @@ def amplitude(self): @property def field(self): if self._field is None: - self._field = self.amplitude * np.exp(1j*2*np.pi*self.phase) + self._field = self.amplitude * np.exp(1j * 2 * np.pi * self.phase) return self._field @property @@ -120,7 +120,7 @@ def run_pipeline(self, **pipeline_kws): if pipeline_kws["sideband_freq"] is None: pipeline_kws["sideband_freq"] = find_peaks_qlsi( - self.fft.fft_origin) + self.fft.fft_origin[0]) # convert filter_size to frequency coordinates fsize = self.compute_filter_size( @@ -172,8 +172,14 @@ def run_pipeline(self, **pipeline_kws): # Obtain the phase gradients in x and y by taking the argument # of Hx and Hy. - px = unwrap_phase(np.angle(hx)) - py = unwrap_phase(np.angle(hy)) + # need to do this along the z axis, as skimage `unwrap_3d` does not + # work for our use-case + # todo: maybe use np.unwrap for the xy axes instead + px = np.zeros_like(hx, dtype=float) + py = np.zeros_like(hy, dtype=float) + for i, (_hx, _hy) in enumerate(zip(hx, hy)): + px[i] = unwrap_phase(np.angle(_hx)) + py[i] = unwrap_phase(np.angle(_hy)) # Determine the angle by which we have to rotate the gradients in # order for them to be aligned with x and y. This angle is defined @@ -183,15 +189,15 @@ def run_pipeline(self, **pipeline_kws): # Pad the gradient information so that we can rotate with cropping # (keeping the image shape the same). # TODO: Make padding dependent on rotation angle to save time? - sx, sy = px.shape - gradpad1 = np.pad(px, ((sx // 2, sx // 2), (sy // 2, sy // 2)), + sx, sy = px.shape[-2:] + gradpad1 = np.pad(px, ((0, 0), (sx // 2, sx // 2), (sy // 2, sy // 2)), mode="constant", constant_values=0) - gradpad2 = np.pad(py, ((sx // 2, sx // 2), (sy // 2, sy // 2)), + gradpad2 = np.pad(py, ((0, 0), (sx // 2, sx // 2), (sy // 2, sy // 2)), mode="constant", constant_values=0) # Perform rotation of the gradients. - rotated1 = rotate_noreshape(gradpad1, -angle) - rotated2 = rotate_noreshape(gradpad2, -angle) + rotated1 = rotate_noreshape(gradpad1, -angle, axes=(-1, -2)) + rotated2 = rotate_noreshape(gradpad2, -angle, axes=(-1, -2)) # Retrieve the wavefront by integrating the vectorial components # (integrate the total differential). This magical approach @@ -204,22 +210,22 @@ def run_pipeline(self, **pipeline_kws): copy=False) # Compute the frequencies that correspond to the frequencies of the # Fourier-transformed image. - fx = np.fft.fftfreq(rfft.shape[0]).reshape(-1, 1) - fy = np.fft.fftfreq(rfft.shape[1]).reshape(1, -1) - fxy = -2*np.pi*1j * (fx + 1j*fy) - fxy[0, 0] = 1 + fx = np.fft.fftfreq(rfft.shape[-2]).reshape(rfft.shape[0], -1, 1) + fy = np.fft.fftfreq(rfft.shape[-1]).reshape(rfft.shape[0], 1, -1) + fxy = -2 * np.pi * 1j * (fx + 1j * fy) + fxy[:, 0, 0] = 1 # The wavefront is the real part of the inverse Fourier transform # of the filtered (divided by frequencies) data. - wfr = rfft._ifft(np.fft.ifftshift(rfft.fft_origin)/fxy).real + wfr = rfft._ifft(np.fft.ifftshift(rfft.fft_origin) / fxy).real # Rotate the wavefront back and crop it so that the FOV matches # the input data. - raw_wavefront = rotate_noreshape(wfr, - angle)[sx//2:-sx//2, sy//2:-sy//2] + raw_wavefront = rotate_noreshape( + wfr, angle, axes=(-1, -2))[:, sx // 2:-sx // 2, sy // 2:-sy // 2] # Multiply by qlsi pitch term and the scaling factor to get # the quantitative wavefront. - scaling_factor = self.fft_origin.shape[0] / wfr.shape[0] + scaling_factor = self.fft_origin.shape[-2] / wfr.shape[-2] raw_wavefront *= qlsi_pitch_term * scaling_factor self._phase = raw_wavefront / wavelength * 2 * np.pi @@ -230,6 +236,8 @@ def run_pipeline(self, **pipeline_kws): self.pipeline_kws.update(pipeline_kws) + # raw_wavefront = revert_to_data_input_format( + # self.fft.data_format, raw_wavefront) self.wavefront = raw_wavefront return raw_wavefront @@ -285,24 +293,25 @@ def find_peaks_qlsi(ft_data, periodicity=4, copy=True): ft_data[:, cy - 3:cy + 3] = 0 # circular bandpass according to periodicity - fx = np.fft.fftshift(np.fft.fftfreq(ft_data.shape[0])).reshape(-1, 1) - fy = np.fft.fftshift(np.fft.fftfreq(ft_data.shape[1])).reshape(1, -1) - frmask1 = np.sqrt(fx**2 + fy**2) > 1/(periodicity*.8) + fx = np.fft.fftshift(np.fft.fftfreq(ft_data.shape[-2])).reshape(-1, 1) + fy = np.fft.fftshift(np.fft.fftfreq(ft_data.shape[-1])).reshape(1, -1) + frmask1 = np.sqrt(fx ** 2 + fy ** 2) > 1 / (periodicity * .8) frmask2 = np.sqrt(fx ** 2 + fy ** 2) < 1 / (periodicity * 1.2) ft_data[np.logical_or(frmask1, frmask2)] = 0 # find the peak in the left part - am1 = np.argmax(np.abs(ft_data*(fy < 0))) + am1 = np.argmax(np.abs(ft_data * (fy < 0))) i1y = am1 % oy i1x = int((am1 - i1y) / oy) return fx[i1x, 0], fy[0, i1y] -def rotate_noreshape(arr, angle, mode="mirror", reshape=False): +def rotate_noreshape(arr, angle, axes, mode="mirror", reshape=False): return scipy.ndimage.rotate( arr, # input angle=np.rad2deg(angle), # angle + axes=axes, reshape=reshape, # reshape order=0, # order mode=mode, # mode diff --git a/qpretrieve/utils.py b/qpretrieve/utils.py new file mode 100644 index 0000000..e8142e5 --- /dev/null +++ b/qpretrieve/utils.py @@ -0,0 +1,30 @@ +import numpy as np + + +def mean_2d(data): + data -= data.mean() + return data + + +def mean_3d(data): + # calculate mean of the images along the z-axis. + # The mean array here is (1000,), so we need to add newaxes for subtraction + # (1000, 5, 5) -= (1000, 1, 1) + data -= data.mean(axis=(-2, -1))[:, np.newaxis, np.newaxis] + return data + + +def padding_2d(data, order, dtype): + # this is faster than np.pad + datapad = np.zeros((order, order), dtype=dtype) + # we could of course use np.atleast_3d here + datapad[:data.shape[0], :data.shape[1]] = data + return datapad + + +def padding_3d(data, order, dtype): + z, y, x = data.shape + # this is faster than np.pad + datapad = np.zeros((z, order, order), dtype=dtype) + datapad[:, :y, :x] = data + return datapad diff --git a/setup.py b/setup.py index 3472257..bfc597f 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,6 @@ from setuptools import setup, find_packages import sys - author = u"Paul Müller" authors = [author] description = 'library for phase retrieval from holograms' @@ -27,8 +26,10 @@ "numpy>=1.9.0", "scikit-image>=0.11.0", "scipy>=0.18.0", - ], - extras_require={"FFTW": "pyfftw>=0.12.0"}, + ], + extras_require={ + "FFTW": "pyfftw>=0.12.0", + }, python_requires='>=3.10, <4', keywords=["digital holographic microscopy", "optics", @@ -41,6 +42,6 @@ 'Operating System :: OS Independent', 'Programming Language :: Python :: 3', 'Intended Audience :: Science/Research' - ], + ], platforms=['ALL'], - ) +) diff --git a/tests/conftest.py b/tests/conftest.py index 759db0d..cdbc989 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,9 +2,11 @@ import shutil import tempfile import time +import numpy as np -import qpretrieve +import pytest +import qpretrieve TMPDIR = tempfile.mkdtemp(prefix=time.strftime( "qpretrieve_test_%H.%M_")) @@ -22,3 +24,22 @@ def pytest_configure(config): # creating FFTW wisdom. Also, it makes the tests more reproducible # by sticking to simple numpy FFTs. qpretrieve.fourier.PREFERRED_INTERFACE = "FFTFilterNumpy" + + +@pytest.fixture +def hologram(size=64): + x = np.arange(size).reshape(-1, 1) - size / 2 + y = np.arange(size).reshape(1, -1) - size / 2 + + amp = np.linspace(.9, 1.1, size * size).reshape(size, size) + pha = np.linspace(0, 2, size * size).reshape(size, size) + + rad = x ** 2 + y ** 2 > (size / 3) ** 2 + pha[rad] = 0 + amp[rad] = 1 + + # frequencies must match pixel in Fourier space + kx = 2 * np.pi * -.3 + ky = 2 * np.pi * -.3 + image = (amp ** 2 + np.sin(kx * x + ky * y + pha) + 1) * 255 + return image diff --git a/tests/test_data_input.py b/tests/test_data_input.py new file mode 100644 index 0000000..bd04222 --- /dev/null +++ b/tests/test_data_input.py @@ -0,0 +1,46 @@ +import numpy as np +import pytest + +from qpretrieve.data_input import check_data_input_format + + +def test_check_data_input_2d(): + data = np.zeros(shape=(256, 256)) + + data_new, data_format = check_data_input_format(data) + + assert data_new.shape == (1, 256, 256) + assert np.array_equal(data_new[0], data) + assert data_format == "2d" + + +def test_check_data_input_3d_image_stack(): + data = np.zeros(shape=(50, 256, 256)) + + data_new, data_format = check_data_input_format(data) + + assert data_new.shape == (50, 256, 256) + assert np.array_equal(data_new, data) + assert data_format == "3d" + + +def test_check_data_input_3d_rgb(): + data = np.zeros(shape=(256, 256, 3)) + + with pytest.warns(UserWarning): + data_new, data_format = check_data_input_format(data) + + assert data_new.shape == (1, 256, 256) + assert np.array_equal(data_new[0], data[:, :, 0]) + assert data_format == "rgb" + + +def test_check_data_input_3d_rgba(): + data = np.zeros(shape=(256, 256, 4)) + + with pytest.warns(UserWarning): + data_new, data_format = check_data_input_format(data) + + assert data_new.shape == (1, 256, 256) + assert np.array_equal(data_new[0], data[:, :, 0]) + assert data_format == "rgba" diff --git a/tests/test_fourier_base.py b/tests/test_fourier_base.py index c364ebb..2f58ea0 100644 --- a/tests/test_fourier_base.py +++ b/tests/test_fourier_base.py @@ -136,12 +136,23 @@ def test_scale_to_filter_qlsi(): } ifh = interfere.QLSInterferogram(image, **pipeline_kws) - ifh.run_pipeline() + raw_wavefront = ifh.run_pipeline() + assert raw_wavefront.shape == (1, 720, 720) + assert ifh.phase.shape == (1, 720, 720) + assert ifh.amplitude.shape == (1, 720, 720) + assert ifh.field.shape == (1, 720, 720) ifr = interfere.QLSInterferogram(refer, **pipeline_kws) ifr.run_pipeline() + assert ifr.phase.shape == (1, 720, 720) + assert ifr.amplitude.shape == (1, 720, 720) + assert ifr.field.shape == (1, 720, 720) + + ifh_phase = ifh.phase[0] + ifr_phase = ifr.phase[0] + + phase = unwrap_phase(ifh_phase - ifr_phase) - phase = unwrap_phase(ifh.phase - ifr.phase) assert phase.shape == (720, 720) assert np.allclose(phase.mean(), 0.12434563269684816, atol=1e-6) @@ -152,6 +163,6 @@ def test_scale_to_filter_qlsi(): ifr.run_pipeline(**pipeline_kws_scale) phase_scaled = unwrap_phase(ifh.phase - ifr.phase) - assert phase_scaled.shape == (126, 126) + assert phase_scaled.shape == (1, 126, 126) assert np.allclose(phase_scaled.mean(), 0.1257080793074251, atol=1e-6) diff --git a/tests/test_interfere_base.py b/tests/test_interfere_base.py new file mode 100644 index 0000000..dec2027 --- /dev/null +++ b/tests/test_interfere_base.py @@ -0,0 +1,40 @@ +import pathlib +import numpy as np +import pytest + +import qpretrieve + +data_path = pathlib.Path(__file__).parent / "data" + + +def test_interfere_base_best_interface(): + edata = np.load(data_path / "hologram_cell.npz") + + holo = qpretrieve.OffAxisHologram(data=edata["data"]) + assert holo.ff_iface.is_available + assert issubclass(holo.ff_iface, + qpretrieve.fourier.base.FFTFilter) + assert issubclass(holo.ff_iface, + qpretrieve.fourier.ff_numpy.FFTFilterNumpy) + + +def test_interfere_base_choose_interface(): + edata = np.load(data_path / "hologram_cell.npz") + + holo = qpretrieve.OffAxisHologram( + data=edata["data"], + fft_interface=qpretrieve.fourier.FFTFilterNumpy) + assert holo.ff_iface.is_available + assert issubclass(holo.ff_iface, + qpretrieve.fourier.base.FFTFilter) + assert issubclass(holo.ff_iface, + qpretrieve.fourier.ff_numpy.FFTFilterNumpy) + + +def test_interfere_base_bad_interface(): + edata = np.load(data_path / "hologram_cell.npz") + + with pytest.raises(ValueError): + _ = qpretrieve.OffAxisHologram( + data=edata["data"], + fft_interface="MyReallyCoolFFTInterface") diff --git a/tests/test_oah_from_qpimage.py b/tests/test_oah_from_qpimage.py index 838b180..e80bd33 100644 --- a/tests/test_oah_from_qpimage.py +++ b/tests/test_oah_from_qpimage.py @@ -4,24 +4,10 @@ import qpretrieve from qpretrieve.interfere import if_oah - - -def hologram(size=64): - x = np.arange(size).reshape(-1, 1) - size / 2 - y = np.arange(size).reshape(1, -1) - size / 2 - - amp = np.linspace(.9, 1.1, size * size).reshape(size, size) - pha = np.linspace(0, 2, size * size).reshape(size, size) - - rad = x**2 + y**2 > (size / 3)**2 - pha[rad] = 0 - amp[rad] = 1 - - # frequencies must match pixel in Fourier space - kx = 2 * np.pi * -.3 - ky = 2 * np.pi * -.3 - image = (amp**2 + np.sin(kx * x + ky * y + pha) + 1) * 255 - return image +from qpretrieve.fourier import FFTFilterNumpy, FFTFilterPyFFTW +from qpretrieve.data_input import ( + _convert_2d_to_3d, _revert_3d_to_rgb, _revert_3d_to_rgba, +) def test_find_sideband(): @@ -36,24 +22,26 @@ def test_find_sideband(): def test_fourier2dpad(): - data = np.zeros((100, 120)) + y, x = 100, 120 + data = np.zeros((y, x)) fft1 = qpretrieve.fourier.FFTFilterNumpy(data) - assert fft1.shape == (256, 256) + assert fft1.shape == (1, 256, 256) fft2 = qpretrieve.fourier.FFTFilterNumpy(data, padding=False) - assert fft2.shape == data.shape + assert fft2.shape == (1, y, x) -def test_get_field_error_bad_filter_size(): - data = hologram() +def test_get_field_error_bad_filter_size(hologram): + data = hologram holo = qpretrieve.OffAxisHologram(data) with pytest.raises(ValueError, match="must be between 0 and 1"): holo.run_pipeline(filter_size=2) -def test_get_field_error_bad_filter_size_interpretation_frequency_index(): - data = hologram(size=64) +def test_get_field_error_bad_filter_size_interpretation_frequency_index( + hologram): + data = hologram holo = qpretrieve.OffAxisHologram(data) with pytest.raises(ValueError, @@ -62,8 +50,8 @@ def test_get_field_error_bad_filter_size_interpretation_frequency_index(): filter_size=64) -def test_get_field_error_invalid_interpretation(): - data = hologram() +def test_get_field_error_invalid_interpretation(hologram): + data = hologram holo = qpretrieve.OffAxisHologram(data) with pytest.raises(ValueError, @@ -71,8 +59,8 @@ def test_get_field_error_invalid_interpretation(): holo.run_pipeline(filter_size_interpretation="blequency") -def test_get_field_filter_names(): - data = hologram() +def test_get_field_filter_names(hologram): + data = hologram holo = qpretrieve.OffAxisHologram(data) kwargs = dict(sideband=+1, @@ -84,11 +72,11 @@ def test_get_field_filter_names(): r_smooth_disk = holo.run_pipeline(filter_name="smooth disk", **kwargs) assert np.allclose(r_smooth_disk[32, 32], - 108.36438759594623-67.1806221692573j) + 108.36438759594623 - 67.1806221692573j) r_gauss = holo.run_pipeline(filter_name="gauss", **kwargs) assert np.allclose(r_gauss[32, 32], - 108.2914187451138-67.1823527237741j) + 108.2914187451138 - 67.1823527237741j) r_square = holo.run_pipeline(filter_name="square", **kwargs) assert np.allclose( @@ -96,7 +84,7 @@ def test_get_field_filter_names(): r_smsquare = holo.run_pipeline(filter_name="smooth square", **kwargs) assert np.allclose( - r_smsquare[32, 32], 108.36651862466393-67.17988960794392j) + r_smsquare[32, 32], 108.36651862466393 - 67.17988960794392j) r_tukey = holo.run_pipeline(filter_name="tukey", **kwargs) assert np.allclose( @@ -110,10 +98,10 @@ def test_get_field_filter_names(): assert False, "unknown filter accepted" -@pytest.mark.parametrize("size", [62, 63, 64]) -def test_get_field_interpretation_fourier_index(size): +@pytest.mark.parametrize("hologram", [62, 63, 64], indirect=["hologram"]) +def test_get_field_interpretation_fourier_index(hologram): """Filter size in Fourier space using Fourier index new in 0.7.0""" - data = hologram(size=size) + data = hologram holo = qpretrieve.OffAxisHologram(data) ft_data = holo.fft_origin @@ -121,23 +109,26 @@ def test_get_field_interpretation_fourier_index(size): fsx, fsy = holo.pipeline_kws["sideband_freq"] kwargs1 = dict(filter_name="disk", - filter_size=1/3, + filter_size=1 / 3, filter_size_interpretation="sideband distance") res1 = holo.run_pipeline(**kwargs1) - filter_size_fi = np.sqrt(fsx**2 + fsy**2) / 3 * ft_data.shape[0] + filter_size_fi = np.sqrt(fsx ** 2 + fsy ** 2) / 3 * ft_data.shape[-2] kwargs2 = dict(filter_name="disk", filter_size=filter_size_fi, filter_size_interpretation="frequency index", ) res2 = holo.run_pipeline(**kwargs2) + + assert res1.shape == hologram.shape + assert res2.shape == hologram.shape assert np.all(res1 == res2) -@pytest.mark.parametrize("size", [62, 63, 64]) -def test_get_field_interpretation_fourier_index_control(size): +@pytest.mark.parametrize("hologram", [62, 63, 64], indirect=["hologram"]) +def test_get_field_interpretation_fourier_index_control(hologram): """Filter size in Fourier space using Fourier index new in 0.7.0""" - data = hologram(size=size) + data = hologram holo = qpretrieve.OffAxisHologram(data) ft_data = holo.fft_origin @@ -147,12 +138,12 @@ def test_get_field_interpretation_fourier_index_control(size): evil_factor = 1.1 kwargs1 = dict(filter_name="disk", - filter_size=1/3 * evil_factor, + filter_size=1 / 3 * evil_factor, filter_size_interpretation="sideband distance" ) res1 = holo.run_pipeline(**kwargs1) - filter_size_fi = np.sqrt(fsx**2 + fsy**2) / 3 * ft_data.shape[0] + filter_size_fi = np.sqrt(fsx ** 2 + fsy ** 2) / 3 * ft_data.shape[-2] kwargs2 = dict(filter_name="disk", filter_size=filter_size_fi, filter_size_interpretation="frequency index", @@ -161,11 +152,12 @@ def test_get_field_interpretation_fourier_index_control(size): assert not np.all(res1 == res2) -@pytest.mark.parametrize("size", [62, 63, 64, 134, 135]) +@pytest.mark.parametrize("hologram", [62, 63, 64, 134, 135], + indirect=["hologram"]) @pytest.mark.parametrize("filter_size", [17, 17.01]) -def test_get_field_interpretation_fourier_index_mask_1(size, filter_size): +def test_get_field_interpretation_fourier_index_mask_1(hologram, filter_size): """Make sure filter size in Fourier space pixels is correct""" - data = hologram(size=size) + data = hologram holo = qpretrieve.OffAxisHologram(data) kwargs2 = dict(filter_name="disk", @@ -178,13 +170,14 @@ def test_get_field_interpretation_fourier_index_mask_1(size, filter_size): # We get 17*2+1, because we measure from the center of Fourier # space and a pixel is included if its center is withing the # perimeter of the disk. - assert np.sum(np.sum(mask, axis=0) != 0) == 17*2 + 1 + assert np.sum(np.sum(mask, axis=-2) != 0) == 17 * 2 + 1 -@pytest.mark.parametrize("size", [62, 63, 64, 134, 135]) -def test_get_field_interpretation_fourier_index_mask_2(size): +@pytest.mark.parametrize("hologram", [62, 63, 64, 134, 135], + indirect=["hologram"]) +def test_get_field_interpretation_fourier_index_mask_2(hologram): """Filter size in Fourier space using Fourier index new in 0.7.0""" - data = hologram(size=size) + data = hologram holo = qpretrieve.OffAxisHologram(data) kwargs2 = dict(filter_name="disk", @@ -196,11 +189,11 @@ def test_get_field_interpretation_fourier_index_mask_2(size): # We get two points less than in the previous test, because we # loose on each side of the spectrum. - assert np.sum(np.sum(mask, axis=0) != 0) == 17*2 - 1 + assert np.sum(np.sum(mask, axis=-2) != 0) == 17 * 2 - 1 -def test_get_field_int_copy(): - data = hologram() +def test_get_field_int_copy(hologram): + data = hologram data = np.array(data, dtype=int) kwargs = dict(filter_size=1 / 3) @@ -218,8 +211,8 @@ def test_get_field_int_copy(): assert np.all(res1 == res3) -def test_get_field_sideband(): - data = hologram() +def test_get_field_sideband(hologram): + data = hologram holo = qpretrieve.OffAxisHologram(data) holo.run_pipeline() invert_phase = holo.pipeline_kws["invert_phase"] @@ -232,10 +225,10 @@ def test_get_field_sideband(): assert np.all(res1 == res2) -def test_get_field_three_axes(): - data1 = hologram() +def test_get_field_three_axes(hologram): + data1 = hologram # create a copy with empty entry in third axis - data2 = np.zeros((data1.shape[0], data1.shape[1], 2)) + data2 = np.zeros((data1.shape[0], data1.shape[1], 3)) data2[:, :, 0] = data1 holo1 = qpretrieve.OffAxisHologram(data1) @@ -245,4 +238,56 @@ def test_get_field_three_axes(): filter_size=1 / 3) res1 = holo1.run_pipeline(**kwargs) res2 = holo2.run_pipeline(**kwargs) - assert np.all(res1 == res2) + + assert res1.shape == (data1.shape[0], data1.shape[1]) + assert res2.shape == (data1.shape[0], data1.shape[1], 3) + + assert np.all(res1 == res2[:, :, 0]) + + +def test_get_field_compare_FFTFilters(hologram): + data1 = hologram + + holo1 = qpretrieve.OffAxisHologram(data1, + fft_interface=FFTFilterNumpy, + padding=False) + kwargs = dict(filter_name="disk", filter_size=1 / 3) + res1 = holo1.run_pipeline(**kwargs) + assert res1.shape == (64, 64) + + holo2 = qpretrieve.OffAxisHologram(data1, + fft_interface=FFTFilterPyFFTW, + padding=False) + kwargs = dict(filter_name="disk", filter_size=1 / 3) + res2 = holo2.run_pipeline(**kwargs) + assert res2.shape == (64, 64) + + assert not np.all(res1 == res2) + + +def test_field_format_consistency(hologram): + """The data format provided by the user should be returned""" + data_2d = hologram + + # 2d data format + holo1 = qpretrieve.OffAxisHologram(data_2d) + res1 = holo1.run_pipeline() + assert res1.shape == data_2d.shape + + # 3d data format + data_3d, _ = _convert_2d_to_3d(data_2d) + holo_3d = qpretrieve.OffAxisHologram(data_3d) + res_3d = holo_3d.run_pipeline() + assert res_3d.shape == data_3d.shape + + # rgb data format + data_rgb = _revert_3d_to_rgb(data_3d) + holo_rgb = qpretrieve.OffAxisHologram(data_rgb) + res_rgb = holo_rgb.run_pipeline() + assert res_rgb.shape == data_rgb.shape + + # rgba data format + data_rgba = _revert_3d_to_rgba(data_3d) + holo_rgba = qpretrieve.OffAxisHologram(data_rgba) + res_rgba = holo_rgba.run_pipeline() + assert res_rgba.shape == data_rgba.shape diff --git a/tests/test_qlsi.py b/tests/test_qlsi.py index f839114..c541919 100644 --- a/tests/test_qlsi.py +++ b/tests/test_qlsi.py @@ -2,8 +2,9 @@ import h5py import numpy as np -import qpretrieve +from skimage.restoration import unwrap_phase +import qpretrieve data_path = pathlib.Path(__file__).parent / "data" @@ -29,3 +30,138 @@ def test_qlsi_phase(): assert qlsi.phase.argmax() == 242294 assert np.allclose(qlsi.phase.max(), 0.9343997734657971, atol=0, rtol=1e-12) + + +def test_qlsi_fftfreq_reshape_2d_3d(hologram): + data_2d = hologram + data_3d, _ = qpretrieve.data_input._convert_2d_to_3d(data_2d) + + fx_2d = np.fft.fftfreq(data_2d.shape[-1]).reshape(-1, 1) + fx_3d = np.fft.fftfreq(data_3d.shape[-1]).reshape(data_3d.shape[0], -1, 1) + + fy_2d = np.fft.fftfreq(data_2d.shape[-2]).reshape(1, -1) + fy_3d = np.fft.fftfreq(data_3d.shape[-2]).reshape(data_3d.shape[0], 1, -1) + + assert np.array_equal(fx_2d, fx_3d[0]) + assert np.array_equal(fy_2d, fy_3d[0]) + + +def test_qlsi_unwrap_phase_2d_3d(): + """ + Check whether skimage unwrap_2d and unwrap_3d give the same result. + In other words, does unwrap_3d apply th unwrapping along the z axis. + + Answer is no, they are different. unwrap_3d is designed for 3D data that + is to be unwrapped on all axes at once. + + """ + with h5py.File(data_path / "qlsi_paa_bead.h5") as h5: + image = h5["0"][:] + + # Standard analysis pipeline + pipeline_kws = { + 'wavelength': 550e-9, + 'qlsi_pitch_term': 1.87711e-08, + 'filter_name': "disk", + 'filter_size': 180, + 'filter_size_interpretation': "frequency index", + 'scale_to_filter': False, + 'invert_phase': False + } + + data_2d = image + data_3d, _ = qpretrieve.data_input._convert_2d_to_3d(data_2d) + + ft_2d = qpretrieve.fourier.FFTFilterNumpy(data_2d, subtract_mean=False) + ft_3d = qpretrieve.fourier.FFTFilterNumpy(data_3d, subtract_mean=False) + + pipeline_kws["sideband_freq"] = qpretrieve.interfere. \ + if_qlsi.find_peaks_qlsi(ft_2d.fft_origin[0]) + + hx_2d = ft_2d.filter(filter_name=pipeline_kws["filter_name"], + filter_size=pipeline_kws["filter_size"], + scale_to_filter=pipeline_kws["scale_to_filter"], + freq_pos=pipeline_kws["sideband_freq"]) + hx_3d = ft_3d.filter(filter_name=pipeline_kws["filter_name"], + filter_size=pipeline_kws["filter_size"], + scale_to_filter=pipeline_kws["scale_to_filter"], + freq_pos=pipeline_kws["sideband_freq"]) + + assert np.array_equal(hx_2d, hx_3d) + + px_2d = unwrap_phase(np.angle(hx_2d[0])) + + px_3d_loop = np.zeros_like(hx_3d) + for i, _hx in enumerate(hx_3d): + px_3d_loop[i] = unwrap_phase(np.angle(_hx)) + + assert np.array_equal(px_2d, px_3d_loop[0]) # this passes + + px_3d = unwrap_phase(np.angle(hx_3d)) # this is not equivalent + assert not np.array_equal(px_2d, px_3d[0]) + + +def test_qlsi_rotate_2d_3d(hologram): + data_2d = hologram + data_3d, _ = qpretrieve.data_input._convert_2d_to_3d(data_2d) + + rot_2d = qpretrieve.interfere.if_qlsi.rotate_noreshape( + data_2d, + angle=2, + axes=(1, 0), # this was the default used before + reshape=False, + ) + rot_3d = qpretrieve.interfere.if_qlsi.rotate_noreshape( + data_3d, + angle=2, + axes=(-1, -2), # the y and x axes + reshape=False, + ) + rot_3d_2 = qpretrieve.interfere.if_qlsi.rotate_noreshape( + data_3d, + angle=2, + axes=(-2, -1), # the y and x axes + reshape=False, + ) + + assert rot_2d.dtype == rot_3d.dtype + assert np.array_equal(rot_2d, rot_3d[0]) + assert np.array_equal(rot_2d, rot_3d_2[0]) + + +def test_qlsi_pad_2d_3d(hologram): + data_2d = hologram + data_3d, _ = qpretrieve.data_input._convert_2d_to_3d(data_2d) + + sx, sy = data_2d.shape[-2:] + gradpad_2d = np.pad( + data_2d, ((sx // 2, sx // 2), (sy // 2, sy // 2)), + mode="constant", constant_values=0) + gradpad_3d = np.pad( + data_3d, ((0, 0), (sx // 2, sx // 2), (sy // 2, sy // 2)), + mode="constant", constant_values=0) + + assert gradpad_2d.dtype == gradpad_3d.dtype + assert np.array_equal(gradpad_2d, gradpad_3d[0]) + + +def test_fxy_complex_mul(hologram): + data_2d = hologram + data_3d, _ = qpretrieve.data_input._convert_2d_to_3d(data_2d) + + assert np.array_equal(data_2d, data_3d[0]) + + # 2d + fx_2d = data_2d.reshape(-1, 1) + fy_2d = data_2d.reshape(1, -1) + fxy_2d = -2 * np.pi * 1j * (fx_2d + 1j * fy_2d) + fxy_2d[0, 0] = 1 + + # 3d + fx_3d = data_3d.reshape(data_3d.shape[0], -1, 1) + fy_3d = data_3d.reshape(data_3d.shape[0], 1, -1) + fxy_3d = -2 * np.pi * 1j * (fx_3d + 1j * fy_3d) + fxy_3d[:, 0, 0] = 1 + + assert np.array_equal(fx_2d, fx_3d[0]) + assert np.array_equal(fxy_2d, fxy_3d[0]) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..c7be6d5 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,44 @@ +import numpy as np + +from qpretrieve.utils import padding_2d, padding_3d, mean_2d, mean_3d + + +def test_mean_subtraction(): + data_3d = np.random.rand(1000, 5, 5).astype(np.float32) + ind = 5 + data_2d = data_3d.copy()[ind] + + data_2d = mean_2d(data_2d) + data_3d = mean_3d(data_3d) + + assert np.array_equal(data_3d[ind], data_2d) + + +def test_mean_subtraction_consistent_2d_3d(): + """Probably a bit too cumbersome, and changes the default 2d pipeline.""" + data_3d = np.random.rand(1000, 5, 5).astype(np.float32) + ind = 5 + data_2d = data_3d.copy()[ind] + + # too cumbersome + data_2d = np.atleast_3d(data_2d) + data_2d = np.swapaxes(np.swapaxes(data_2d, 0, 2), 1, 2) + data_2d -= data_2d.mean(axis=(-2, -1))[:, np.newaxis, np.newaxis] + + data_3d = np.atleast_3d(data_3d.copy()) + data_3d -= data_3d.mean(axis=(-2, -1))[:, np.newaxis, np.newaxis] + + assert np.array_equal(data_3d[ind], data_2d[0]) + + +def test_batch_padding(): + data_3d = np.random.rand(1000, 100, 320).astype(np.float32) + ind = 5 + data_2d = data_3d.copy()[ind] + order = 512 + dtype = float + + data_2d_padded = padding_2d(data_2d, order, dtype) + data_3d_padded = padding_3d(data_3d, order, dtype) + + assert np.array_equal(data_3d_padded[ind], data_2d_padded)