Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap-safe transfer functions #175

Merged
merged 6 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions waveorder/models/isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Literal

import numpy as np
import torch
from torch import Tensor

from waveorder import optics, util
from waveorder import optics, sampling, util


def generate_test_phantom(
Expand All @@ -28,6 +29,49 @@ def calculate_transfer_function(
index_of_refraction_media,
numerical_aperture_detection,
):

transverse_nyquist = sampling.transverse_nyquist(
wavelength_emission,
numerical_aperture_detection, # ill = det for fluorescence
numerical_aperture_detection,
)
axial_nyquist = sampling.axial_nyquist(
wavelength_emission,
numerical_aperture_detection,
index_of_refraction_media,
)

yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
z_factor = int(np.ceil(z_pixel_size / axial_nyquist))

optical_transfer_function = _calculate_wrap_unsafe_transfer_function(
(
zyx_shape[0] * z_factor,
zyx_shape[1] * yx_factor,
zyx_shape[2] * yx_factor,
),
yx_pixel_size / yx_factor,
z_pixel_size / z_factor,
wavelength_emission,
z_padding,
index_of_refraction_media,
numerical_aperture_detection,
)
zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
return sampling.nd_fourier_central_cuboid(
optical_transfer_function, zyx_out_shape
)


def _calculate_wrap_unsafe_transfer_function(
zyx_shape,
yx_pixel_size,
z_pixel_size,
wavelength_emission,
z_padding,
index_of_refraction_media,
numerical_aperture_detection,
):
radial_frequencies = util.generate_radial_frequencies(
zyx_shape[1:], yx_pixel_size
)
Expand Down Expand Up @@ -97,7 +141,7 @@ def apply_transfer_function(
Returns
-------
Simulated data : torch.Tensor

"""
if (
zyx_object.shape[0] + 2 * z_padding
Expand Down
61 changes: 60 additions & 1 deletion waveorder/models/isotropic_thin_3d.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Literal, Tuple

import numpy as np
import torch
from torch import Tensor

from waveorder import optics, util
from waveorder import optics, sampling, util


def generate_test_phantom(
Expand Down Expand Up @@ -42,6 +43,64 @@ def calculate_transfer_function(
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
transverse_nyquist = sampling.transverse_nyquist(
wavelength_illumination,
numerical_aperture_illumination,
numerical_aperture_detection,
)
yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))

absorption_2d_to_3d_transfer_function, phase_2d_to_3d_transfer_function = (
_calculate_wrap_unsafe_transfer_function(
(
yx_shape[0] * yx_factor,
yx_shape[1] * yx_factor,
),
yx_pixel_size / yx_factor,
z_position_list,
wavelength_illumination,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=invert_phase_contrast,
)
)

absorption_2d_to_3d_transfer_function_out = torch.zeros(
(len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64
)
phase_2d_to_3d_transfer_function_out = torch.zeros(
(len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64
)

for z in range(len(z_position_list)):
absorption_2d_to_3d_transfer_function_out[z] = (
sampling.nd_fourier_central_cuboid(
absorption_2d_to_3d_transfer_function[z], yx_shape
)
)
phase_2d_to_3d_transfer_function_out[z] = (
sampling.nd_fourier_central_cuboid(
phase_2d_to_3d_transfer_function[z], yx_shape
)
)

return (
absorption_2d_to_3d_transfer_function_out,
phase_2d_to_3d_transfer_function_out,
)


def _calculate_wrap_unsafe_transfer_function(
yx_shape,
yx_pixel_size,
z_position_list,
wavelength_illumination,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
if invert_phase_contrast:
z_position_list = torch.flip(torch.tensor(z_position_list), dims=(0,))
Expand Down
56 changes: 55 additions & 1 deletion waveorder/models/phase_thick_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch import Tensor

from waveorder import optics, util
from waveorder import optics, sampling, util
from waveorder.models import isotropic_fluorescent_thick_3d


Expand Down Expand Up @@ -40,6 +40,60 @@ def calculate_transfer_function(
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
transverse_nyquist = sampling.transverse_nyquist(
wavelength_illumination,
numerical_aperture_illumination,
numerical_aperture_detection,
)
axial_nyquist = sampling.axial_nyquist(
wavelength_illumination,
numerical_aperture_detection,
index_of_refraction_media,
)

yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
ziw-liu marked this conversation as resolved.
Show resolved Hide resolved

real_potential_transfer_function, imag_potential_transfer_function = (
_calculate_wrap_unsafe_transfer_function(
(
zyx_shape[0] * z_factor,
zyx_shape[1] * yx_factor,
zyx_shape[2] * yx_factor,
),
yx_pixel_size / yx_factor,
z_pixel_size / z_factor,
wavelength_illumination,
z_padding,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=invert_phase_contrast,
)
)

zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
return (
sampling.nd_fourier_central_cuboid(
real_potential_transfer_function, zyx_out_shape
),
sampling.nd_fourier_central_cuboid(
imag_potential_transfer_function, zyx_out_shape
),
)


def _calculate_wrap_unsafe_transfer_function(
zyx_shape,
yx_pixel_size,
z_pixel_size,
wavelength_illumination,
z_padding,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
radial_frequencies = util.generate_radial_frequencies(
zyx_shape[1:], yx_pixel_size
Expand Down
94 changes: 94 additions & 0 deletions waveorder/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np
import torch


def transverse_nyquist(
wavelength_emission,
numerical_aperture_illumination,
numerical_aperture_detection,
):
"""Transverse Nyquist sample spacing in `wavelength_emission` units.

For widefield label-free imaging, the transverse Nyquist sample spacing is
lambda / (2 * (NA_ill + NA_det)).

Perhaps surprisingly, the transverse Nyquist sample spacing for widefield
fluorescence is lambda / (4 * NA), which is equivalent to the above formula
when NA_ill = NA_det.

Parameters
----------
wavelength_emission : float
Output units match these units
numerical_aperture_illumination : float
For widefield fluorescence, set to numerical_aperture_detection
numerical_aperture_detection : float

Returns
-------
float
Transverse Nyquist sample spacing

"""
return wavelength_emission / (
2 * (numerical_aperture_detection + numerical_aperture_illumination)
)


def axial_nyquist(
wavelength_emission,
numerical_aperture_detection,
index_of_refraction_media,
):
"""Axial Nyquist sample spacing in `wavelength_emission` units.

For widefield microscopes, the axial Nyquist cutoff frequency is:

(n/lambda) - sqrt( (n/lambda)^2 - (NA_det/lambda)^2 ),

and the axial Nyquist sample spacing is 1 / (2 * cutoff_frequency).

Perhaps surprisingly, the axial Nyquist sample spacing is independent of
the illumination numerical aperture.

Parameters
----------
wavelength_emission : float
Output units match these units
numerical_aperture_detection : float
index_of_refraction_media: float

Returns
-------
float
Axial Nyquist sample spacing

"""
n_on_lambda = index_of_refraction_media / wavelength_emission
cutoff_frequency = n_on_lambda - np.sqrt(
n_on_lambda**2
- (numerical_aperture_detection / wavelength_emission) ** 2
)
return 1 / (2 * cutoff_frequency)


def nd_fourier_central_cuboid(source, target_shape):
"""Central cuboid of an N-D Fourier transform.

Parameters
----------
source : torch.Tensor
Source tensor
target_shape : tuple of int

Returns
-------
torch.Tensor
Center cuboid in Fourier space

"""
center_slices = tuple(
slice((s - o) // 2, (s - o) // 2 + o)
for s, o in zip(source.shape, target_shape)
)
return torch.fft.ifftshift(torch.fft.fftshift(source)[center_slices])
Loading