Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(old) Move from 2d to 3d array operations #11

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
f09a64f
enh: add scipy fft interface
Aug 26, 2024
41b77e4
tests: correct path for ci use
Aug 26, 2024
97e7c5c
tests: fft interface checks and lint code
Aug 26, 2024
93618fb
docs: update docstrings and docs example output image
Aug 26, 2024
b19a817
enh, tests: add cupy3D class and basic tests for debugging
Sep 17, 2024
8525abd
enh: making the code 3d-proof
Sep 18, 2024
839214e
test: compare 2D and 3D mean value consistency
Nov 7, 2024
7d1ca55
enh: allow fourier base to subtract mean from 3D data arrays
Nov 7, 2024
4f01683
enh: create utility functions for padding and subt_mean
Nov 7, 2024
63b0c4d
tests: add tests for new util funcs and comparing 2D and 3D Fourier p…
Nov 7, 2024
5054745
tests: remove use of matplotlib from tests
Nov 7, 2024
5319043
docs: create speed comparison example for Cupy 3d
Nov 7, 2024
35975f3
test: reorg the cupy tests for clarity
Nov 8, 2024
23d367f
ci: ignore cupy tests for now during the cicd pipeline
Nov 8, 2024
b812c07
docs: update figure labels; ci: correct github actions syntax
Nov 8, 2024
2361769
fix: correct define FFTFilters upon import
Nov 8, 2024
32cf2ce
test: check ci tests to see if Pyfftw is the issue
Nov 8, 2024
59f67ec
docs: lint examples
Nov 8, 2024
2468569
docs: update README
Nov 8, 2024
7d9f1ae
ref: use negative indexing for np array shape
Nov 8, 2024
d2e2f61
enh: add use of 3D array for the FFTFilter init, not incl padding
Nov 8, 2024
e427cc5
enh: ensure ifft with padding works with 3D stack
Nov 8, 2024
d7d84f5
enh: ensure all data is converted to 3D image stack
Nov 20, 2024
2da1582
enh: add data format conversion functions
Nov 20, 2024
57947e6
ref: use the data format conversion convenience fnuctions to handle f…
Nov 20, 2024
1ac98ff
test: refactor relevant oah tests to expect new format and shape
Nov 20, 2024
fe33bf0
enh: add rgb warning for user
Nov 20, 2024
568d511
test: ensure the users provided data format is returned
Nov 20, 2024
7d88bb9
enh: add 3d array usage to qsli
Nov 20, 2024
d7457e1
ref: align the 2d qlsi code to 3d
Nov 21, 2024
528a038
fix: match qlsi 2d with new 3d implementation
Dec 5, 2024
33c14b8
ref: remove mentioned of cupy and scipy FFTFilters
Dec 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.10']
python-version: ['3.10', '3.11']
os: [macos-latest, ubuntu-latest, windows-latest]

steps:
Expand Down Expand Up @@ -39,7 +39,8 @@ jobs:
pip freeze
- name: Test with pytest
run: |
coverage run --source=qpretrieve -m pytest tests
# ignore the cupy imports, as we don't have a gpu-enabled pipeline setup
coverage run --source=qpretrieve -m pytest tests --ignore=tests/test_cupy_gpu
- name: Lint with flake8
run: |
flake8 .
Expand Down
1 change: 1 addition & 0 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ To install the requirements for building the documentation, run

To compile the documentation, run

cd docs
sphinx-build . _build


Expand Down
7 changes: 3 additions & 4 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
sphinx==4.3.0
sphinxcontrib.bibtex>=2.0
sphinx_rtd_theme==1.0

sphinx
sphinxcontrib.bibtex
sphinx_rtd_theme
3 changes: 1 addition & 2 deletions docs/sec_code_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,20 @@ Fourier transform methods
=========================

.. _sec_code_fourier_numpy:

Numpy
-----
.. automodule:: qpretrieve.fourier.ff_numpy
:members:
:inherited-members:

.. _sec_code_fourier_pyfftw:

PyFFTW
------
.. automodule:: qpretrieve.fourier.ff_pyfftw
:members:
:inherited-members:


.. _sec_code_ifer:

Interference image analysis
Expand Down
1 change: 1 addition & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
matplotlib
84 changes: 84 additions & 0 deletions qpretrieve/data_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy as np
import warnings

allowed_data_formats = [
"rgb",
"rgba",
"3d",
"2d",
]


def check_data_input_format(data_input):
"""Figure out what data input is provided."""
if len(data_input.shape) == 3:
if data_input.shape[-1] in [1, 2, 3]:
# take the first slice (we have alpha or RGB information)
data, data_format = _convert_rgb_to_3d(data_input)
elif data_input.shape[-1] == 4:
# take the first slice (we have alpha or RGB information)
data, data_format = _convert_rgba_to_3d(data_input)
else:
# we have a 3D image stack (z, y, x)
data, data_format = data_input, "3d"
elif len(data_input.shape) == 2:
# we have a 2D image (y, x). convert to (z, y, z)
data, data_format = _convert_2d_to_3d(data_input)
else:
raise ValueError(f"data_input shape must be 2d or 3d, "
f"got shape {data_input.shape}.")
return data.copy(), data_format


def revert_to_data_input_format(data_format, field):
"""Convert the outputted field shape to the original input shape,
for user convenience."""
assert data_format in allowed_data_formats
assert len(field.shape) == 3, "the field should be 3d"
field = field.copy()
if data_format == "rgb":
field = _revert_3d_to_rgb(field)
elif data_format == "rgba":
field = _revert_3d_to_rgba(field)
elif data_format == "3d":
field = field
else:
field = _revert_3d_to_2d(field)
return field


def _convert_rgb_to_3d(data_input):
data = data_input[:, :, 0]
data = data[np.newaxis, :, :]
data_format = "rgb"
warnings.warn(f"Format of input data detected as {data_format}. "
f"The first channel will be used for processing")
return data, data_format


def _convert_rgba_to_3d(data_input):
data, _ = _convert_rgb_to_3d(data_input)
data_format = "rgba"
return data, data_format


def _convert_2d_to_3d(data_input):
data = data_input[np.newaxis, :, :]
data_format = "2d"
return data, data_format


def _revert_3d_to_rgb(data_input):
data = data_input[0]
data = np.dstack((data, data, data))
return data


def _revert_3d_to_rgba(data_input):
data = data_input[0]
data = np.dstack((data, data, data, np.ones_like(data)))
return data


def _revert_3d_to_2d(data_input):
return data_input[0]
10 changes: 6 additions & 4 deletions qpretrieve/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape):
and must be between 0 and `max(fft_shape)/2`
freq_pos: tuple of floats
The position of the filter in frequency coordinates as
returned by :func:`nunpy.fft.fftfreq`.
returned by :func:`numpy.fft.fftfreq`.
fft_shape: tuple of int
The shape of the Fourier transformed image for which the
The shape of the Fourier transformed image (2d) for which the
filter will be applied. The shape must be squared (two
identical integers).

Expand Down Expand Up @@ -104,8 +104,10 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape):
# TODO: avoid the np.roll, instead use the indices directly
alpha = 0.1
rsize = int(min(fx.size, fy.size) * filter_size) * 2
tukey_window_x = signal.tukey(rsize, alpha=alpha).reshape(-1, 1)
tukey_window_y = signal.tukey(rsize, alpha=alpha).reshape(1, -1)
tukey_window_x = signal.windows.tukey(
rsize, alpha=alpha).reshape(-1, 1)
tukey_window_y = signal.windows.tukey(
rsize, alpha=alpha).reshape(1, -1)
tukey = tukey_window_x * tukey_window_y
base = np.zeros(fft_shape)
s1 = (np.array(fft_shape) - rsize) // 2
Expand Down
13 changes: 13 additions & 0 deletions qpretrieve/fourier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@
PREFERRED_INTERFACE = None


def get_available_interfaces():
"""Return a list of available FFT algorithms"""
interfaces = [
FFTFilterPyFFTW,
FFTFilterNumpy,
]
interfaces_available = []
for interface in interfaces:
if interface is not None and interface.is_available:
interfaces_available.append(interface)
return interfaces_available


def get_best_interface():
"""Return the fastest refocusing interface available

Expand Down
53 changes: 36 additions & 17 deletions qpretrieve/fourier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np

from .. import filter
from ..utils import padding_3d, mean_3d
from ..data_input import check_data_input_format


class FFTCache:
Expand Down Expand Up @@ -35,12 +37,19 @@ def cleanup(key):


class FFTFilter(ABC):
def __init__(self, data, subtract_mean=True, padding=2, copy=True):
def __init__(self,
data: np.ndarray,
subtract_mean: bool = True,
padding: int = 2,
copy: bool = True):
r"""
Parameters
----------
data: 2d real-valued np.ndarray
The experimental input image
data
The experimental input real-valued image. Allowed input shapes are:
- 2d (y, x)
- 3d (z, y, x)
- 3d rgb (y, x, 3) or rgba (y, x, 4)
subtract_mean: bool
If True, subtract the mean of `data` before performing
the Fourier transform. This setting is recommended as it
Expand Down Expand Up @@ -70,9 +79,15 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True):
else:
# convert integer-arrays to floating point arrays
dtype = float
if not copy:
# numpy v2.x behaviour requires asarray with copy=False
copy = None
data_ed = np.array(data, dtype=dtype, copy=copy)
# figure out what type of data we have
data_ed, self.data_format = check_data_input_format(data_ed)
#: original data (with subtracted mean)
self.origin = data_ed
# for `subtract_mean` and `padding`, we could use `np.atleast_3d`
#: whether padding is enabled
self.padding = padding
#: whether the mean was subtracted
Expand All @@ -81,14 +96,13 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True):
# remove contributions of the central band
# (this affects more than one pixel in the FFT
# because of zero-padding)
data_ed -= data_ed.mean()
data_ed = mean_3d(data_ed)
if padding:
# zero padding size is next order of 2
logfact = np.log(padding * max(data_ed.shape))
order = int(2 ** np.ceil(logfact / np.log(2)))
# this is faster than np.pad
datapad = np.zeros((order, order), dtype=dtype)
datapad[:data_ed.shape[0], :data_ed.shape[1]] = data_ed

datapad = padding_3d(data_ed, order, dtype)
#: padded input data
self.origin_padded = datapad
data_ed = datapad
Expand Down Expand Up @@ -175,7 +189,7 @@ def filter(self, filter_name: str, filter_size: float,
and must be between 0 and `max(fft_shape)/2`
freq_pos: tuple of floats
The position of the filter in frequency coordinates as
returned by :func:`nunpy.fft.fftfreq`.
returned by :func:`numpy.fft.fftfreq`.
scale_to_filter: bool or float
Crop the image in Fourier space after applying the filter,
effectively removing surplus (zero-padding) data and
Expand Down Expand Up @@ -220,36 +234,41 @@ def filter(self, filter_name: str, filter_size: float,
filter_name=filter_name,
filter_size=filter_size,
freq_pos=freq_pos,
fft_shape=self.fft_origin.shape)
# only take shape of a single fft
fft_shape=self.fft_origin.shape[-2:])
fft_filtered = self.fft_origin * filt_array
px = int(freq_pos[0] * self.shape[0])
py = int(freq_pos[1] * self.shape[1])
fft_used = np.roll(np.roll(fft_filtered, -px, axis=0), -py, axis=1)
px = int(freq_pos[0] * self.shape[-2])
py = int(freq_pos[1] * self.shape[-1])
fft_used = np.roll(np.roll(
fft_filtered, -px, axis=-2), -py, axis=-1)
if scale_to_filter:
# Determine the size of the cropping region.
# We compute the "radius" of the region, so we can
# crop the data left and right from the center of the
# Fourier domain.
osize = fft_filtered.shape[0] # square shaped
osize = fft_filtered.shape[-2] # square shaped
crad = int(np.ceil(filter_size * osize * scale_to_filter))
ccent = osize // 2
cslice = slice(ccent - crad, ccent + crad)
# We now have the interesting peak already shifted to
# the first entry of our array in `shifted`.
fft_used = fft_used[cslice, cslice]
fft_used = fft_used[:, cslice, cslice]

field = self._ifft(np.fft.ifftshift(fft_used))

if self.padding:
# revert padding
sx, sy = self.origin.shape
sx, sy = self.origin.shape[-2:]
if scale_to_filter:
sx = int(np.ceil(sx * 2 * crad / osize))
sy = int(np.ceil(sy * 2 * crad / osize))
field = field[:sx, :sy]

field = field[:, :sx, :sy]

if scale_to_filter:
# Scale the absolute value of the field. This does not
# have any influence on the phase, but on the amplitude.
field *= (2 * crad / osize)**2
field *= (2 * crad / osize) ** 2
# Add FFT to cache
# (The cache will only be cleared if this instance is deleted)
FFTCache.add_item(weakref_key, self.fft_origin,
Expand Down
Loading
Loading