Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector reconstruction CLI prototype #484

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
46 changes: 45 additions & 1 deletion recOrder/cli/apply_inverse_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import numpy as np
import torch
from waveorder.models import (
inplane_oriented_thick_pol3d_vector,
inplane_oriented_thick_pol3d,
isotropic_fluorescent_thick_3d,
isotropic_thin_3d,
phase_thick_3d,
)
from waveorder.stokes import stokes_after_adr, _s12_to_orientation


def radians_to_nanometers(retardance_rad, wavelength_illumination_um):
Expand Down Expand Up @@ -219,10 +221,52 @@ def birefringence_and_phase(
retardance = radians_to_nanometers(
reconstructed_parameters_3d[0], wavelength_illumination
)
# Load singular system
U = torch.tensor(
np.array(transfer_function_dataset["singular_system_U"])
)
S = torch.tensor(
np.array(transfer_function_dataset["singular_system_S"][0])
)
Vh = torch.tensor(
np.array(transfer_function_dataset["singular_system_Vh"])
)
singular_system = (U, S, Vh)

# Convert retardance and orientation to stokes
stokes = stokes_after_adr(*reconstructed_parameters_3d)

stokes = torch.nan_to_num_(torch.stack(stokes), nan=0.0) # very rare nans from previous like

# Apply reconstruction
joint_recon_params = inplane_oriented_thick_pol3d_vector.apply_inverse_transfer_function(
szyx_data=stokes,
singular_system=singular_system,
intensity_to_stokes_matrix=None,
**settings_phase.apply_inverse.dict(),
)

new_ret = (
joint_recon_params[1] ** 2 + joint_recon_params[2] ** 2
) ** (0.5)
new_ori = _s12_to_orientation(
joint_recon_params[1], -joint_recon_params[2]
)

# Convert stokes to retardance and orientation
# new_ret, new_ori, _ = estimate_ar_from_stokes012(*joint_recon_params)

# Convert retardance
new_ret_nm = radians_to_nanometers(new_ret, wavelength_illumination)

# Save
output = torch.stack(
(retardance,) + reconstructed_parameters_3d[1:] + (zyx_phase,)
(retardance,)
+ reconstructed_parameters_3d[1:]
+ (zyx_phase,)
+ (new_ret_nm,)
+ (new_ori,)
+ (joint_recon_params[0],)
)
return output

Expand Down
10 changes: 7 additions & 3 deletions recOrder/cli/apply_inverse_transfer_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def get_reconstruction_output_metadata(position_path: Path, config_path: Path):
channel_names.append("Phase2D")
elif recon_dim == 3:
channel_names.append("Phase3D")
if recon_biref and recon_phase:
channel_names.append("Retardance_Joint_Decon")
channel_names.append("Orientation_Joint_Decon")
channel_names.append("Phase_Joint_Decon")
if recon_fluo:
fluor_name = settings.input_channel_names[0]
if recon_dim == 2:
Expand Down Expand Up @@ -313,7 +317,7 @@ def apply_inverse_transfer_function_cli(

settings = utils.yaml_to_model(config_filepath, ReconstructionSettings)
gb_ram_request = 0
gb_per_element = 4 / 2**30 # bytes_per_float32 / bytes_per_gb
gb_per_element = 4 / 2 ** 30 # bytes_per_float32 / bytes_per_gb
voxel_resource_multiplier = 4
fourier_resource_multiplier = 32
input_memory = Z * Y * X * gb_per_element
Expand All @@ -336,13 +340,13 @@ def apply_inverse_transfer_function_cli(
f"{cpu_request} CPU{'s' if cpu_request > 1 else ''} and "
f"{gb_ram_request} GB of memory per CPU."
)
executor = submitit.AutoExecutor(folder="logs")
executor = submitit.AutoExecutor(folder="logs") #, cluster="debug")

executor.update_parameters(
slurm_array_parallelism=np.min([50, num_jobs]),
slurm_mem_per_cpu=f"{gb_ram_request}G",
slurm_cpus_per_task=cpu_request,
slurm_time=60,
slurm_time=600,
slurm_partition="cpu",
# more slurm_*** resource parameters here
)
Expand Down
75 changes: 72 additions & 3 deletions recOrder/cli/compute_transfer_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from iohub.ngff import open_ome_zarr, Position
from waveorder.models import (
inplane_oriented_thick_pol3d_vector,
inplane_oriented_thick_pol3d,
isotropic_fluorescent_thick_3d,
isotropic_thin_3d,
Expand All @@ -20,6 +21,70 @@
from recOrder.io import utils


def generate_and_save_vector_birefringence_transfer_function(
settings: ReconstructionSettings, dataset: Position, zyx_shape: tuple
):
"""Generates and saves the vector birefringence transfer function
to the dataset, based on the settings.

Parameters
----------
settings : ReconstructionSettings
dataset : NGFF Node
The dataset that will be updated.
zyx_shape : tuple
A tuple of integers specifying the input data's shape in (Z, Y, X) order
"""
echo_headline(
"Generating vector birefringence transfer function with settings:"
)
echo_settings(settings.birefringence.transfer_function)
echo_settings(settings.phase.transfer_function)

num_elements = np.array(zyx_shape).prod()
max_tf_elements = 1e7 # empirical, based on memory usage
transverse_downsample_factor = np.ceil(np.sqrt(num_elements / max_tf_elements))
echo_headline(f"Downsampling transfer function in X and Y by {transverse_downsample_factor}x")

sfZYX_transfer_function, _, singular_system= (
inplane_oriented_thick_pol3d_vector.calculate_transfer_function(
zyx_shape=zyx_shape,
scheme=str(len(settings.input_channel_names)) + "-State",
**settings.birefringence.transfer_function.dict(),
**settings.phase.transfer_function.dict(),
transverse_downsample_factor=transverse_downsample_factor,
)
)

U, S, Vh = singular_system
chunks = (1, 1, 1, zyx_shape[1], zyx_shape[2])

# Add dummy channels
for i in range(3):
dataset.append_channel(f"ch{i}")

dataset.create_image(
"vector_transfer_function",
sfZYX_transfer_function.cpu().numpy(),
chunks=chunks,
)
dataset.create_image(
"singular_system_U",
U.cpu().numpy(),
chunks=chunks,
)
dataset.create_image(
"singular_system_S",
S[None].cpu().numpy(),
chunks=chunks,
)
dataset.create_image(
"singular_system_Vh",
Vh.cpu().numpy(),
chunks=chunks,
)


def generate_and_save_birefringence_transfer_function(settings, dataset):
"""Generates and saves the birefringence transfer function to the dataset, based on the settings.

Expand All @@ -40,9 +105,9 @@ def generate_and_save_birefringence_transfer_function(settings, dataset):
)
)
# Save
dataset[
"intensity_to_stokes_matrix"
] = intensity_to_stokes_matrix.cpu().numpy()[None, None, None, ...]
dataset["intensity_to_stokes_matrix"] = (
intensity_to_stokes_matrix.cpu().numpy()[None, None, None, ...]
)


def generate_and_save_phase_transfer_function(
Expand Down Expand Up @@ -200,6 +265,10 @@ def compute_transfer_function_cli(
generate_and_save_fluorescence_transfer_function(
settings, output_dataset, zyx_shape
)
if settings.birefringence is not None and settings.phase is not None:
generate_and_save_vector_birefringence_transfer_function(
settings, output_dataset, zyx_shape
)

# Write settings to metadata
output_dataset.zattrs["settings"] = settings.dict()
Expand Down
1 change: 0 additions & 1 deletion recOrder/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from iohub import open_ome_zarr



def add_index_to_path(path: Path):
"""Takes a path to a file or folder and appends the smallest index that does
not already exist in that folder.
Expand Down
13 changes: 8 additions & 5 deletions recOrder/io/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,21 @@ def ret_ori_overlay(
overlay_final = np.zeros_like(retardance)

if cmap == "JCh":
J = ret_
C = np.ones_like(J) * 60
J_MAX = 65
C_MAX = 60

J = (ret_ / ret_max) * J_MAX
C = np.ones_like(J) * C_MAX
C[ret_ < ret_min] = 0
h = ori_

JCh = np.stack((J, C, h), axis=-1)
JCh_rgb = cspace_convert(JCh, "JCh", "sRGB1")
JCh_rgb = cspace_convert(JCh, "JCh", "sRGB255")

JCh_rgb[JCh_rgb < 0] = 0
JCh_rgb[JCh_rgb > 1] = 1
JCh_rgb[JCh_rgb > 255] = 255

overlay_final = JCh_rgb
overlay_final = JCh_rgb.astype(np.uint8)
elif cmap == "HSV":
I_hsv = np.moveaxis(
np.stack(
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ include_package_data = True
python_requires = >=3.10
setup_requires = setuptools_scm
install_requires =
waveorder==2.2.0rc0
waveorder @ git+https://github.com/mehta-lab/waveorder.git@1cb7d53a135771368e26065d9e427535e0475858
click>=8.0.1
natsort>=7.1.1
colorspacious>=1.1.2
Expand Down
Loading