Skip to content

Commit

Permalink
tests: ensure new array processing results in identical values
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Jan 7, 2025
1 parent e9a623a commit 284ebe1
Show file tree
Hide file tree
Showing 10 changed files with 411 additions and 67 deletions.
1 change: 1 addition & 0 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ To install the requirements for building the documentation, run

To compile the documentation, run

cd docs
sphinx-build . _build


Expand Down
7 changes: 3 additions & 4 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
sphinx==4.3.0
sphinxcontrib.bibtex>=2.0
sphinx_rtd_theme==1.0

sphinx
sphinxcontrib.bibtex
sphinx_rtd_theme
1 change: 1 addition & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
matplotlib
23 changes: 22 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import shutil
import tempfile
import time
import numpy as np

import qpretrieve
import pytest

import qpretrieve

TMPDIR = tempfile.mkdtemp(prefix=time.strftime(
"qpretrieve_test_%H.%M_"))
Expand All @@ -22,3 +24,22 @@ def pytest_configure(config):
# creating FFTW wisdom. Also, it makes the tests more reproducible
# by sticking to simple numpy FFTs.
qpretrieve.fourier.PREFERRED_INTERFACE = "FFTFilterNumpy"


@pytest.fixture
def hologram(size=64):
x = np.arange(size).reshape(-1, 1) - size / 2
y = np.arange(size).reshape(1, -1) - size / 2

amp = np.linspace(.9, 1.1, size * size).reshape(size, size)
pha = np.linspace(0, 2, size * size).reshape(size, size)

rad = x ** 2 + y ** 2 > (size / 3) ** 2
pha[rad] = 0
amp[rad] = 1

# frequencies must match pixel in Fourier space
kx = 2 * np.pi * -.3
ky = 2 * np.pi * -.3
image = (amp ** 2 + np.sin(kx * x + ky * y + pha) + 1) * 255
return image
46 changes: 46 additions & 0 deletions tests/test_data_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy as np
import pytest

from qpretrieve.data_input import check_data_input_format


def test_check_data_input_2d():
data = np.zeros(shape=(256, 256))

data_new, data_format = check_data_input_format(data)

assert data_new.shape == (1, 256, 256)
assert np.array_equal(data_new[0], data)
assert data_format == "2d"


def test_check_data_input_3d_image_stack():
data = np.zeros(shape=(50, 256, 256))

data_new, data_format = check_data_input_format(data)

assert data_new.shape == (50, 256, 256)
assert np.array_equal(data_new, data)
assert data_format == "3d"


def test_check_data_input_3d_rgb():
data = np.zeros(shape=(256, 256, 3))

with pytest.warns(UserWarning):
data_new, data_format = check_data_input_format(data)

assert data_new.shape == (1, 256, 256)
assert np.array_equal(data_new[0], data[:, :, 0])
assert data_format == "rgb"


def test_check_data_input_3d_rgba():
data = np.zeros(shape=(256, 256, 4))

with pytest.warns(UserWarning):
data_new, data_format = check_data_input_format(data)

assert data_new.shape == (1, 256, 256)
assert np.array_equal(data_new[0], data[:, :, 0])
assert data_format == "rgba"
17 changes: 14 additions & 3 deletions tests/test_fourier_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,23 @@ def test_scale_to_filter_qlsi():
}

ifh = interfere.QLSInterferogram(image, **pipeline_kws)
ifh.run_pipeline()
raw_wavefront = ifh.run_pipeline()
assert raw_wavefront.shape == (1, 720, 720)
assert ifh.phase.shape == (1, 720, 720)
assert ifh.amplitude.shape == (1, 720, 720)
assert ifh.field.shape == (1, 720, 720)

ifr = interfere.QLSInterferogram(refer, **pipeline_kws)
ifr.run_pipeline()
assert ifr.phase.shape == (1, 720, 720)
assert ifr.amplitude.shape == (1, 720, 720)
assert ifr.field.shape == (1, 720, 720)

ifh_phase = ifh.phase[0]
ifr_phase = ifr.phase[0]

phase = unwrap_phase(ifh_phase - ifr_phase)

phase = unwrap_phase(ifh.phase - ifr.phase)
assert phase.shape == (720, 720)
assert np.allclose(phase.mean(), 0.12434563269684816, atol=1e-6)

Expand All @@ -152,6 +163,6 @@ def test_scale_to_filter_qlsi():
ifr.run_pipeline(**pipeline_kws_scale)
phase_scaled = unwrap_phase(ifh.phase - ifr.phase)

assert phase_scaled.shape == (126, 126)
assert phase_scaled.shape == (1, 126, 126)

assert np.allclose(phase_scaled.mean(), 0.1257080793074251, atol=1e-6)
40 changes: 40 additions & 0 deletions tests/test_interfere_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pathlib
import numpy as np
import pytest

import qpretrieve

data_path = pathlib.Path(__file__).parent / "data"


def test_interfere_base_best_interface():
edata = np.load(data_path / "hologram_cell.npz")

holo = qpretrieve.OffAxisHologram(data=edata["data"])
assert holo.ff_iface.is_available
assert issubclass(holo.ff_iface,
qpretrieve.fourier.base.FFTFilter)
assert issubclass(holo.ff_iface,
qpretrieve.fourier.ff_numpy.FFTFilterNumpy)


def test_interfere_base_choose_interface():
edata = np.load(data_path / "hologram_cell.npz")

holo = qpretrieve.OffAxisHologram(
data=edata["data"],
fft_interface=qpretrieve.fourier.FFTFilterNumpy)
assert holo.ff_iface.is_available
assert issubclass(holo.ff_iface,
qpretrieve.fourier.base.FFTFilter)
assert issubclass(holo.ff_iface,
qpretrieve.fourier.ff_numpy.FFTFilterNumpy)


def test_interfere_base_bad_interface():
edata = np.load(data_path / "hologram_cell.npz")

with pytest.raises(ValueError):
_ = qpretrieve.OffAxisHologram(
data=edata["data"],
fft_interface="MyReallyCoolFFTInterface")
Loading

0 comments on commit 284ebe1

Please sign in to comment.