From 284ebe10db4860cf526302ee9f2a94c809f555f9 Mon Sep 17 00:00:00 2001 From: Eoghan O'Connell Date: Tue, 7 Jan 2025 14:34:27 +0100 Subject: [PATCH] tests: ensure new array processing results in identical values --- docs/README.md | 1 + docs/requirements.txt | 7 +- examples/requirements.txt | 1 + tests/conftest.py | 23 ++++- tests/test_data_input.py | 46 ++++++++++ tests/test_fourier_base.py | 17 +++- tests/test_interfere_base.py | 40 ++++++++ tests/test_oah_from_qpimage.py | 161 +++++++++++++++++++++------------ tests/test_qlsi.py | 138 +++++++++++++++++++++++++++- tests/test_utils.py | 44 +++++++++ 10 files changed, 411 insertions(+), 67 deletions(-) create mode 100644 examples/requirements.txt create mode 100644 tests/test_data_input.py create mode 100644 tests/test_interfere_base.py create mode 100644 tests/test_utils.py diff --git a/docs/README.md b/docs/README.md index 694748f..e8441e1 100644 --- a/docs/README.md +++ b/docs/README.md @@ -6,6 +6,7 @@ To install the requirements for building the documentation, run To compile the documentation, run + cd docs sphinx-build . _build diff --git a/docs/requirements.txt b/docs/requirements.txt index 15408de..87e4bc3 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,3 @@ -sphinx==4.3.0 -sphinxcontrib.bibtex>=2.0 -sphinx_rtd_theme==1.0 - +sphinx +sphinxcontrib.bibtex +sphinx_rtd_theme diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..6ccafc3 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1 @@ +matplotlib diff --git a/tests/conftest.py b/tests/conftest.py index 759db0d..cdbc989 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,9 +2,11 @@ import shutil import tempfile import time +import numpy as np -import qpretrieve +import pytest +import qpretrieve TMPDIR = tempfile.mkdtemp(prefix=time.strftime( "qpretrieve_test_%H.%M_")) @@ -22,3 +24,22 @@ def pytest_configure(config): # creating FFTW wisdom. Also, it makes the tests more reproducible # by sticking to simple numpy FFTs. qpretrieve.fourier.PREFERRED_INTERFACE = "FFTFilterNumpy" + + +@pytest.fixture +def hologram(size=64): + x = np.arange(size).reshape(-1, 1) - size / 2 + y = np.arange(size).reshape(1, -1) - size / 2 + + amp = np.linspace(.9, 1.1, size * size).reshape(size, size) + pha = np.linspace(0, 2, size * size).reshape(size, size) + + rad = x ** 2 + y ** 2 > (size / 3) ** 2 + pha[rad] = 0 + amp[rad] = 1 + + # frequencies must match pixel in Fourier space + kx = 2 * np.pi * -.3 + ky = 2 * np.pi * -.3 + image = (amp ** 2 + np.sin(kx * x + ky * y + pha) + 1) * 255 + return image diff --git a/tests/test_data_input.py b/tests/test_data_input.py new file mode 100644 index 0000000..bd04222 --- /dev/null +++ b/tests/test_data_input.py @@ -0,0 +1,46 @@ +import numpy as np +import pytest + +from qpretrieve.data_input import check_data_input_format + + +def test_check_data_input_2d(): + data = np.zeros(shape=(256, 256)) + + data_new, data_format = check_data_input_format(data) + + assert data_new.shape == (1, 256, 256) + assert np.array_equal(data_new[0], data) + assert data_format == "2d" + + +def test_check_data_input_3d_image_stack(): + data = np.zeros(shape=(50, 256, 256)) + + data_new, data_format = check_data_input_format(data) + + assert data_new.shape == (50, 256, 256) + assert np.array_equal(data_new, data) + assert data_format == "3d" + + +def test_check_data_input_3d_rgb(): + data = np.zeros(shape=(256, 256, 3)) + + with pytest.warns(UserWarning): + data_new, data_format = check_data_input_format(data) + + assert data_new.shape == (1, 256, 256) + assert np.array_equal(data_new[0], data[:, :, 0]) + assert data_format == "rgb" + + +def test_check_data_input_3d_rgba(): + data = np.zeros(shape=(256, 256, 4)) + + with pytest.warns(UserWarning): + data_new, data_format = check_data_input_format(data) + + assert data_new.shape == (1, 256, 256) + assert np.array_equal(data_new[0], data[:, :, 0]) + assert data_format == "rgba" diff --git a/tests/test_fourier_base.py b/tests/test_fourier_base.py index c364ebb..2f58ea0 100644 --- a/tests/test_fourier_base.py +++ b/tests/test_fourier_base.py @@ -136,12 +136,23 @@ def test_scale_to_filter_qlsi(): } ifh = interfere.QLSInterferogram(image, **pipeline_kws) - ifh.run_pipeline() + raw_wavefront = ifh.run_pipeline() + assert raw_wavefront.shape == (1, 720, 720) + assert ifh.phase.shape == (1, 720, 720) + assert ifh.amplitude.shape == (1, 720, 720) + assert ifh.field.shape == (1, 720, 720) ifr = interfere.QLSInterferogram(refer, **pipeline_kws) ifr.run_pipeline() + assert ifr.phase.shape == (1, 720, 720) + assert ifr.amplitude.shape == (1, 720, 720) + assert ifr.field.shape == (1, 720, 720) + + ifh_phase = ifh.phase[0] + ifr_phase = ifr.phase[0] + + phase = unwrap_phase(ifh_phase - ifr_phase) - phase = unwrap_phase(ifh.phase - ifr.phase) assert phase.shape == (720, 720) assert np.allclose(phase.mean(), 0.12434563269684816, atol=1e-6) @@ -152,6 +163,6 @@ def test_scale_to_filter_qlsi(): ifr.run_pipeline(**pipeline_kws_scale) phase_scaled = unwrap_phase(ifh.phase - ifr.phase) - assert phase_scaled.shape == (126, 126) + assert phase_scaled.shape == (1, 126, 126) assert np.allclose(phase_scaled.mean(), 0.1257080793074251, atol=1e-6) diff --git a/tests/test_interfere_base.py b/tests/test_interfere_base.py new file mode 100644 index 0000000..dec2027 --- /dev/null +++ b/tests/test_interfere_base.py @@ -0,0 +1,40 @@ +import pathlib +import numpy as np +import pytest + +import qpretrieve + +data_path = pathlib.Path(__file__).parent / "data" + + +def test_interfere_base_best_interface(): + edata = np.load(data_path / "hologram_cell.npz") + + holo = qpretrieve.OffAxisHologram(data=edata["data"]) + assert holo.ff_iface.is_available + assert issubclass(holo.ff_iface, + qpretrieve.fourier.base.FFTFilter) + assert issubclass(holo.ff_iface, + qpretrieve.fourier.ff_numpy.FFTFilterNumpy) + + +def test_interfere_base_choose_interface(): + edata = np.load(data_path / "hologram_cell.npz") + + holo = qpretrieve.OffAxisHologram( + data=edata["data"], + fft_interface=qpretrieve.fourier.FFTFilterNumpy) + assert holo.ff_iface.is_available + assert issubclass(holo.ff_iface, + qpretrieve.fourier.base.FFTFilter) + assert issubclass(holo.ff_iface, + qpretrieve.fourier.ff_numpy.FFTFilterNumpy) + + +def test_interfere_base_bad_interface(): + edata = np.load(data_path / "hologram_cell.npz") + + with pytest.raises(ValueError): + _ = qpretrieve.OffAxisHologram( + data=edata["data"], + fft_interface="MyReallyCoolFFTInterface") diff --git a/tests/test_oah_from_qpimage.py b/tests/test_oah_from_qpimage.py index 838b180..e80bd33 100644 --- a/tests/test_oah_from_qpimage.py +++ b/tests/test_oah_from_qpimage.py @@ -4,24 +4,10 @@ import qpretrieve from qpretrieve.interfere import if_oah - - -def hologram(size=64): - x = np.arange(size).reshape(-1, 1) - size / 2 - y = np.arange(size).reshape(1, -1) - size / 2 - - amp = np.linspace(.9, 1.1, size * size).reshape(size, size) - pha = np.linspace(0, 2, size * size).reshape(size, size) - - rad = x**2 + y**2 > (size / 3)**2 - pha[rad] = 0 - amp[rad] = 1 - - # frequencies must match pixel in Fourier space - kx = 2 * np.pi * -.3 - ky = 2 * np.pi * -.3 - image = (amp**2 + np.sin(kx * x + ky * y + pha) + 1) * 255 - return image +from qpretrieve.fourier import FFTFilterNumpy, FFTFilterPyFFTW +from qpretrieve.data_input import ( + _convert_2d_to_3d, _revert_3d_to_rgb, _revert_3d_to_rgba, +) def test_find_sideband(): @@ -36,24 +22,26 @@ def test_find_sideband(): def test_fourier2dpad(): - data = np.zeros((100, 120)) + y, x = 100, 120 + data = np.zeros((y, x)) fft1 = qpretrieve.fourier.FFTFilterNumpy(data) - assert fft1.shape == (256, 256) + assert fft1.shape == (1, 256, 256) fft2 = qpretrieve.fourier.FFTFilterNumpy(data, padding=False) - assert fft2.shape == data.shape + assert fft2.shape == (1, y, x) -def test_get_field_error_bad_filter_size(): - data = hologram() +def test_get_field_error_bad_filter_size(hologram): + data = hologram holo = qpretrieve.OffAxisHologram(data) with pytest.raises(ValueError, match="must be between 0 and 1"): holo.run_pipeline(filter_size=2) -def test_get_field_error_bad_filter_size_interpretation_frequency_index(): - data = hologram(size=64) +def test_get_field_error_bad_filter_size_interpretation_frequency_index( + hologram): + data = hologram holo = qpretrieve.OffAxisHologram(data) with pytest.raises(ValueError, @@ -62,8 +50,8 @@ def test_get_field_error_bad_filter_size_interpretation_frequency_index(): filter_size=64) -def test_get_field_error_invalid_interpretation(): - data = hologram() +def test_get_field_error_invalid_interpretation(hologram): + data = hologram holo = qpretrieve.OffAxisHologram(data) with pytest.raises(ValueError, @@ -71,8 +59,8 @@ def test_get_field_error_invalid_interpretation(): holo.run_pipeline(filter_size_interpretation="blequency") -def test_get_field_filter_names(): - data = hologram() +def test_get_field_filter_names(hologram): + data = hologram holo = qpretrieve.OffAxisHologram(data) kwargs = dict(sideband=+1, @@ -84,11 +72,11 @@ def test_get_field_filter_names(): r_smooth_disk = holo.run_pipeline(filter_name="smooth disk", **kwargs) assert np.allclose(r_smooth_disk[32, 32], - 108.36438759594623-67.1806221692573j) + 108.36438759594623 - 67.1806221692573j) r_gauss = holo.run_pipeline(filter_name="gauss", **kwargs) assert np.allclose(r_gauss[32, 32], - 108.2914187451138-67.1823527237741j) + 108.2914187451138 - 67.1823527237741j) r_square = holo.run_pipeline(filter_name="square", **kwargs) assert np.allclose( @@ -96,7 +84,7 @@ def test_get_field_filter_names(): r_smsquare = holo.run_pipeline(filter_name="smooth square", **kwargs) assert np.allclose( - r_smsquare[32, 32], 108.36651862466393-67.17988960794392j) + r_smsquare[32, 32], 108.36651862466393 - 67.17988960794392j) r_tukey = holo.run_pipeline(filter_name="tukey", **kwargs) assert np.allclose( @@ -110,10 +98,10 @@ def test_get_field_filter_names(): assert False, "unknown filter accepted" -@pytest.mark.parametrize("size", [62, 63, 64]) -def test_get_field_interpretation_fourier_index(size): +@pytest.mark.parametrize("hologram", [62, 63, 64], indirect=["hologram"]) +def test_get_field_interpretation_fourier_index(hologram): """Filter size in Fourier space using Fourier index new in 0.7.0""" - data = hologram(size=size) + data = hologram holo = qpretrieve.OffAxisHologram(data) ft_data = holo.fft_origin @@ -121,23 +109,26 @@ def test_get_field_interpretation_fourier_index(size): fsx, fsy = holo.pipeline_kws["sideband_freq"] kwargs1 = dict(filter_name="disk", - filter_size=1/3, + filter_size=1 / 3, filter_size_interpretation="sideband distance") res1 = holo.run_pipeline(**kwargs1) - filter_size_fi = np.sqrt(fsx**2 + fsy**2) / 3 * ft_data.shape[0] + filter_size_fi = np.sqrt(fsx ** 2 + fsy ** 2) / 3 * ft_data.shape[-2] kwargs2 = dict(filter_name="disk", filter_size=filter_size_fi, filter_size_interpretation="frequency index", ) res2 = holo.run_pipeline(**kwargs2) + + assert res1.shape == hologram.shape + assert res2.shape == hologram.shape assert np.all(res1 == res2) -@pytest.mark.parametrize("size", [62, 63, 64]) -def test_get_field_interpretation_fourier_index_control(size): +@pytest.mark.parametrize("hologram", [62, 63, 64], indirect=["hologram"]) +def test_get_field_interpretation_fourier_index_control(hologram): """Filter size in Fourier space using Fourier index new in 0.7.0""" - data = hologram(size=size) + data = hologram holo = qpretrieve.OffAxisHologram(data) ft_data = holo.fft_origin @@ -147,12 +138,12 @@ def test_get_field_interpretation_fourier_index_control(size): evil_factor = 1.1 kwargs1 = dict(filter_name="disk", - filter_size=1/3 * evil_factor, + filter_size=1 / 3 * evil_factor, filter_size_interpretation="sideband distance" ) res1 = holo.run_pipeline(**kwargs1) - filter_size_fi = np.sqrt(fsx**2 + fsy**2) / 3 * ft_data.shape[0] + filter_size_fi = np.sqrt(fsx ** 2 + fsy ** 2) / 3 * ft_data.shape[-2] kwargs2 = dict(filter_name="disk", filter_size=filter_size_fi, filter_size_interpretation="frequency index", @@ -161,11 +152,12 @@ def test_get_field_interpretation_fourier_index_control(size): assert not np.all(res1 == res2) -@pytest.mark.parametrize("size", [62, 63, 64, 134, 135]) +@pytest.mark.parametrize("hologram", [62, 63, 64, 134, 135], + indirect=["hologram"]) @pytest.mark.parametrize("filter_size", [17, 17.01]) -def test_get_field_interpretation_fourier_index_mask_1(size, filter_size): +def test_get_field_interpretation_fourier_index_mask_1(hologram, filter_size): """Make sure filter size in Fourier space pixels is correct""" - data = hologram(size=size) + data = hologram holo = qpretrieve.OffAxisHologram(data) kwargs2 = dict(filter_name="disk", @@ -178,13 +170,14 @@ def test_get_field_interpretation_fourier_index_mask_1(size, filter_size): # We get 17*2+1, because we measure from the center of Fourier # space and a pixel is included if its center is withing the # perimeter of the disk. - assert np.sum(np.sum(mask, axis=0) != 0) == 17*2 + 1 + assert np.sum(np.sum(mask, axis=-2) != 0) == 17 * 2 + 1 -@pytest.mark.parametrize("size", [62, 63, 64, 134, 135]) -def test_get_field_interpretation_fourier_index_mask_2(size): +@pytest.mark.parametrize("hologram", [62, 63, 64, 134, 135], + indirect=["hologram"]) +def test_get_field_interpretation_fourier_index_mask_2(hologram): """Filter size in Fourier space using Fourier index new in 0.7.0""" - data = hologram(size=size) + data = hologram holo = qpretrieve.OffAxisHologram(data) kwargs2 = dict(filter_name="disk", @@ -196,11 +189,11 @@ def test_get_field_interpretation_fourier_index_mask_2(size): # We get two points less than in the previous test, because we # loose on each side of the spectrum. - assert np.sum(np.sum(mask, axis=0) != 0) == 17*2 - 1 + assert np.sum(np.sum(mask, axis=-2) != 0) == 17 * 2 - 1 -def test_get_field_int_copy(): - data = hologram() +def test_get_field_int_copy(hologram): + data = hologram data = np.array(data, dtype=int) kwargs = dict(filter_size=1 / 3) @@ -218,8 +211,8 @@ def test_get_field_int_copy(): assert np.all(res1 == res3) -def test_get_field_sideband(): - data = hologram() +def test_get_field_sideband(hologram): + data = hologram holo = qpretrieve.OffAxisHologram(data) holo.run_pipeline() invert_phase = holo.pipeline_kws["invert_phase"] @@ -232,10 +225,10 @@ def test_get_field_sideband(): assert np.all(res1 == res2) -def test_get_field_three_axes(): - data1 = hologram() +def test_get_field_three_axes(hologram): + data1 = hologram # create a copy with empty entry in third axis - data2 = np.zeros((data1.shape[0], data1.shape[1], 2)) + data2 = np.zeros((data1.shape[0], data1.shape[1], 3)) data2[:, :, 0] = data1 holo1 = qpretrieve.OffAxisHologram(data1) @@ -245,4 +238,56 @@ def test_get_field_three_axes(): filter_size=1 / 3) res1 = holo1.run_pipeline(**kwargs) res2 = holo2.run_pipeline(**kwargs) - assert np.all(res1 == res2) + + assert res1.shape == (data1.shape[0], data1.shape[1]) + assert res2.shape == (data1.shape[0], data1.shape[1], 3) + + assert np.all(res1 == res2[:, :, 0]) + + +def test_get_field_compare_FFTFilters(hologram): + data1 = hologram + + holo1 = qpretrieve.OffAxisHologram(data1, + fft_interface=FFTFilterNumpy, + padding=False) + kwargs = dict(filter_name="disk", filter_size=1 / 3) + res1 = holo1.run_pipeline(**kwargs) + assert res1.shape == (64, 64) + + holo2 = qpretrieve.OffAxisHologram(data1, + fft_interface=FFTFilterPyFFTW, + padding=False) + kwargs = dict(filter_name="disk", filter_size=1 / 3) + res2 = holo2.run_pipeline(**kwargs) + assert res2.shape == (64, 64) + + assert not np.all(res1 == res2) + + +def test_field_format_consistency(hologram): + """The data format provided by the user should be returned""" + data_2d = hologram + + # 2d data format + holo1 = qpretrieve.OffAxisHologram(data_2d) + res1 = holo1.run_pipeline() + assert res1.shape == data_2d.shape + + # 3d data format + data_3d, _ = _convert_2d_to_3d(data_2d) + holo_3d = qpretrieve.OffAxisHologram(data_3d) + res_3d = holo_3d.run_pipeline() + assert res_3d.shape == data_3d.shape + + # rgb data format + data_rgb = _revert_3d_to_rgb(data_3d) + holo_rgb = qpretrieve.OffAxisHologram(data_rgb) + res_rgb = holo_rgb.run_pipeline() + assert res_rgb.shape == data_rgb.shape + + # rgba data format + data_rgba = _revert_3d_to_rgba(data_3d) + holo_rgba = qpretrieve.OffAxisHologram(data_rgba) + res_rgba = holo_rgba.run_pipeline() + assert res_rgba.shape == data_rgba.shape diff --git a/tests/test_qlsi.py b/tests/test_qlsi.py index f839114..c541919 100644 --- a/tests/test_qlsi.py +++ b/tests/test_qlsi.py @@ -2,8 +2,9 @@ import h5py import numpy as np -import qpretrieve +from skimage.restoration import unwrap_phase +import qpretrieve data_path = pathlib.Path(__file__).parent / "data" @@ -29,3 +30,138 @@ def test_qlsi_phase(): assert qlsi.phase.argmax() == 242294 assert np.allclose(qlsi.phase.max(), 0.9343997734657971, atol=0, rtol=1e-12) + + +def test_qlsi_fftfreq_reshape_2d_3d(hologram): + data_2d = hologram + data_3d, _ = qpretrieve.data_input._convert_2d_to_3d(data_2d) + + fx_2d = np.fft.fftfreq(data_2d.shape[-1]).reshape(-1, 1) + fx_3d = np.fft.fftfreq(data_3d.shape[-1]).reshape(data_3d.shape[0], -1, 1) + + fy_2d = np.fft.fftfreq(data_2d.shape[-2]).reshape(1, -1) + fy_3d = np.fft.fftfreq(data_3d.shape[-2]).reshape(data_3d.shape[0], 1, -1) + + assert np.array_equal(fx_2d, fx_3d[0]) + assert np.array_equal(fy_2d, fy_3d[0]) + + +def test_qlsi_unwrap_phase_2d_3d(): + """ + Check whether skimage unwrap_2d and unwrap_3d give the same result. + In other words, does unwrap_3d apply th unwrapping along the z axis. + + Answer is no, they are different. unwrap_3d is designed for 3D data that + is to be unwrapped on all axes at once. + + """ + with h5py.File(data_path / "qlsi_paa_bead.h5") as h5: + image = h5["0"][:] + + # Standard analysis pipeline + pipeline_kws = { + 'wavelength': 550e-9, + 'qlsi_pitch_term': 1.87711e-08, + 'filter_name': "disk", + 'filter_size': 180, + 'filter_size_interpretation': "frequency index", + 'scale_to_filter': False, + 'invert_phase': False + } + + data_2d = image + data_3d, _ = qpretrieve.data_input._convert_2d_to_3d(data_2d) + + ft_2d = qpretrieve.fourier.FFTFilterNumpy(data_2d, subtract_mean=False) + ft_3d = qpretrieve.fourier.FFTFilterNumpy(data_3d, subtract_mean=False) + + pipeline_kws["sideband_freq"] = qpretrieve.interfere. \ + if_qlsi.find_peaks_qlsi(ft_2d.fft_origin[0]) + + hx_2d = ft_2d.filter(filter_name=pipeline_kws["filter_name"], + filter_size=pipeline_kws["filter_size"], + scale_to_filter=pipeline_kws["scale_to_filter"], + freq_pos=pipeline_kws["sideband_freq"]) + hx_3d = ft_3d.filter(filter_name=pipeline_kws["filter_name"], + filter_size=pipeline_kws["filter_size"], + scale_to_filter=pipeline_kws["scale_to_filter"], + freq_pos=pipeline_kws["sideband_freq"]) + + assert np.array_equal(hx_2d, hx_3d) + + px_2d = unwrap_phase(np.angle(hx_2d[0])) + + px_3d_loop = np.zeros_like(hx_3d) + for i, _hx in enumerate(hx_3d): + px_3d_loop[i] = unwrap_phase(np.angle(_hx)) + + assert np.array_equal(px_2d, px_3d_loop[0]) # this passes + + px_3d = unwrap_phase(np.angle(hx_3d)) # this is not equivalent + assert not np.array_equal(px_2d, px_3d[0]) + + +def test_qlsi_rotate_2d_3d(hologram): + data_2d = hologram + data_3d, _ = qpretrieve.data_input._convert_2d_to_3d(data_2d) + + rot_2d = qpretrieve.interfere.if_qlsi.rotate_noreshape( + data_2d, + angle=2, + axes=(1, 0), # this was the default used before + reshape=False, + ) + rot_3d = qpretrieve.interfere.if_qlsi.rotate_noreshape( + data_3d, + angle=2, + axes=(-1, -2), # the y and x axes + reshape=False, + ) + rot_3d_2 = qpretrieve.interfere.if_qlsi.rotate_noreshape( + data_3d, + angle=2, + axes=(-2, -1), # the y and x axes + reshape=False, + ) + + assert rot_2d.dtype == rot_3d.dtype + assert np.array_equal(rot_2d, rot_3d[0]) + assert np.array_equal(rot_2d, rot_3d_2[0]) + + +def test_qlsi_pad_2d_3d(hologram): + data_2d = hologram + data_3d, _ = qpretrieve.data_input._convert_2d_to_3d(data_2d) + + sx, sy = data_2d.shape[-2:] + gradpad_2d = np.pad( + data_2d, ((sx // 2, sx // 2), (sy // 2, sy // 2)), + mode="constant", constant_values=0) + gradpad_3d = np.pad( + data_3d, ((0, 0), (sx // 2, sx // 2), (sy // 2, sy // 2)), + mode="constant", constant_values=0) + + assert gradpad_2d.dtype == gradpad_3d.dtype + assert np.array_equal(gradpad_2d, gradpad_3d[0]) + + +def test_fxy_complex_mul(hologram): + data_2d = hologram + data_3d, _ = qpretrieve.data_input._convert_2d_to_3d(data_2d) + + assert np.array_equal(data_2d, data_3d[0]) + + # 2d + fx_2d = data_2d.reshape(-1, 1) + fy_2d = data_2d.reshape(1, -1) + fxy_2d = -2 * np.pi * 1j * (fx_2d + 1j * fy_2d) + fxy_2d[0, 0] = 1 + + # 3d + fx_3d = data_3d.reshape(data_3d.shape[0], -1, 1) + fy_3d = data_3d.reshape(data_3d.shape[0], 1, -1) + fxy_3d = -2 * np.pi * 1j * (fx_3d + 1j * fy_3d) + fxy_3d[:, 0, 0] = 1 + + assert np.array_equal(fx_2d, fx_3d[0]) + assert np.array_equal(fxy_2d, fxy_3d[0]) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..c7be6d5 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,44 @@ +import numpy as np + +from qpretrieve.utils import padding_2d, padding_3d, mean_2d, mean_3d + + +def test_mean_subtraction(): + data_3d = np.random.rand(1000, 5, 5).astype(np.float32) + ind = 5 + data_2d = data_3d.copy()[ind] + + data_2d = mean_2d(data_2d) + data_3d = mean_3d(data_3d) + + assert np.array_equal(data_3d[ind], data_2d) + + +def test_mean_subtraction_consistent_2d_3d(): + """Probably a bit too cumbersome, and changes the default 2d pipeline.""" + data_3d = np.random.rand(1000, 5, 5).astype(np.float32) + ind = 5 + data_2d = data_3d.copy()[ind] + + # too cumbersome + data_2d = np.atleast_3d(data_2d) + data_2d = np.swapaxes(np.swapaxes(data_2d, 0, 2), 1, 2) + data_2d -= data_2d.mean(axis=(-2, -1))[:, np.newaxis, np.newaxis] + + data_3d = np.atleast_3d(data_3d.copy()) + data_3d -= data_3d.mean(axis=(-2, -1))[:, np.newaxis, np.newaxis] + + assert np.array_equal(data_3d[ind], data_2d[0]) + + +def test_batch_padding(): + data_3d = np.random.rand(1000, 100, 320).astype(np.float32) + ind = 5 + data_2d = data_3d.copy()[ind] + order = 512 + dtype = float + + data_2d_padded = padding_2d(data_2d, order, dtype) + data_3d_padded = padding_3d(data_3d, order, dtype) + + assert np.array_equal(data_3d_padded[ind], data_2d_padded)