Skip to content

Commit

Permalink
docs: add type hints for utils module; ref: make 2d funcs private
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Jan 24, 2025
1 parent 880591a commit b64aa82
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
10 changes: 6 additions & 4 deletions qpretrieve/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
from scipy import signal


available_filters = [
"disk",
"smooth disk",
Expand All @@ -15,7 +14,10 @@


@lru_cache(maxsize=32)
def get_filter_array(filter_name, filter_size, freq_pos, fft_shape):
def get_filter_array(
filter_name: str, filter_size: float,
freq_pos: tuple[float, float],
fft_shape: tuple[int, int]) -> np.ndarray:
"""Create a Fourier filter for holography
Parameters
Expand Down Expand Up @@ -55,15 +57,15 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape):
raise ValueError("The Fourier transformed data must have a squared "
+ f"shape, but the input shape is '{fft_shape}'! "
+ "Please pad your data properly before FFT.")
if not (0 < filter_size < max(fft_shape)/2):
if not (0 < filter_size < max(fft_shape) / 2):
raise ValueError("The filter size cannot exceed more than half of "
+ "the Fourier space or be negative. Got a filter "
+ f"size of '{filter_size}' and a shape of "
+ f"'{fft_shape}'!")
if not (0
<= min(np.abs(freq_pos))
<= max(np.abs(freq_pos))
< max(fft_shape)/2):
< max(fft_shape) / 2):
raise ValueError("The frequency position must be within the Fourier "
+ f"domain. Got '{freq_pos}' and shape "
+ f"'{fft_shape}'!")
Expand Down
24 changes: 19 additions & 5 deletions qpretrieve/utils.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,42 @@
import numpy as np


def mean_2d(data):
def _mean_2d(data):
"""Exists for testing against mean_3d"""
data -= data.mean()
return data


def mean_3d(data):
# calculate mean of the images along the z-axis.
def mean_3d(data: np.ndarray) -> np.ndarray:
"""Calculate mean of the data along the z-axis."""
# The mean array here is (1000,), so we need to add newaxes for subtraction
# (1000, 5, 5) -= (1000, 1, 1)
data -= data.mean(axis=(-2, -1))[:, np.newaxis, np.newaxis]
return data


def padding_2d(data, order, dtype):
def _padding_2d(data, order, dtype):
"""Exists for testing against padding_3d"""
# this is faster than np.pad
datapad = np.zeros((order, order), dtype=dtype)
# we could of course use np.atleast_3d here
datapad[:data.shape[0], :data.shape[1]] = data
return datapad


def padding_3d(data, order, dtype):
def padding_3d(data: np.ndarray, order: int, dtype: np.dtype) -> np.ndarray:
"""Calculate padding of the data along the z-axis.
Parameters
----------
data
3d array. The padding will be applied to the axes (y,x) only.
order
The data will be padded to this size.
dtype
data type of the padded array.
"""
z, y, x = data.shape
# this is faster than np.pad
datapad = np.zeros((z, order, order), dtype=dtype)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import numpy as np

from qpretrieve.utils import padding_2d, padding_3d, mean_2d, mean_3d
from qpretrieve.utils import _padding_2d, padding_3d, _mean_2d, mean_3d


def test_mean_subtraction():
data_3d = np.random.rand(1000, 5, 5).astype(np.float32)
ind = 5
data_2d = data_3d.copy()[ind]

data_2d = mean_2d(data_2d)
data_2d = _mean_2d(data_2d)
data_3d = mean_3d(data_3d)

assert np.array_equal(data_3d[ind], data_2d)
Expand Down Expand Up @@ -38,7 +38,7 @@ def test_batch_padding():
order = 512
dtype = float

data_2d_padded = padding_2d(data_2d, order, dtype)
data_2d_padded = _padding_2d(data_2d, order, dtype)
data_3d_padded = padding_3d(data_3d, order, dtype)

assert np.array_equal(data_3d_padded[ind], data_2d_padded)

0 comments on commit b64aa82

Please sign in to comment.