Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Brightness normalisation #25

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 121 additions & 6 deletions brainglobe_stitch/image_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def __init__(self, directory: Path):

self.load_mesospim_directory()

self.scale_factors: Optional[npt.NDArray] = None
self.intensity_adjusted: List[bool] = [False] * len(
self.tiles[0].resolution_pyramid
)

def __del__(self):
if self.h5_file:
self.h5_file.close()
Expand Down Expand Up @@ -355,6 +360,25 @@ def read_big_stitcher_transforms(self) -> None:
stitched_position = stitched_translations[tile.id]
tile.position = stitched_position

def reload_resolution_pyramid_level(self, resolution_level: int) -> None:
"""
Reload the data for a given resolution level.

Parameters
----------
resolution_level: int
The resolution level to reload the data for.
"""
if self.h5_file:
for tile in self.tiles:
tile.data_pyramid[resolution_level] = da.from_array(
self.h5_file[
f"t00000/{tile.name}/{resolution_level}/cells"
]
)

self.intensity_adjusted[resolution_level] = False

def calculate_overlaps(self) -> None:
"""
Calculate the overlaps between the tiles in the ImageMosaic.
Expand Down Expand Up @@ -397,9 +421,97 @@ def calculate_overlaps(self) -> None:
)
tile_i.neighbours.append(tile_j.id)

def normalise_intensity(
self, resolution_level: int = 0, percentile: int = 80
) -> None:
"""
Normalise the intensity of the image at a given resolution level.

Parameters
----------
resolution_level: int
The resolution level to normalise the intensity at.
percentile: int
The percentile based on which the normalisation is done.
"""
if self.intensity_adjusted[resolution_level]:
print("Intensity already adjusted at this resolution scale.")
return

if self.scale_factors is None:
# Calculate scale factors on at least resolution level 2
# The tiles are adjusted as the scale factors are calculated
self.calculate_intensity_scale_factors(
max(resolution_level, 2), percentile
)

if self.intensity_adjusted[resolution_level]:
return

assert self.scale_factors is not None

# Adjust the intensity of each tile based on the scale factors
for tile in self.tiles:
if self.scale_factors[tile.id] != 1.0:
tile.data_pyramid[resolution_level] = da.multiply(
tile.data_pyramid[resolution_level],
self.scale_factors[tile.id],
).astype(tile.data_pyramid[resolution_level].dtype)

self.intensity_adjusted[resolution_level] = True

def calculate_intensity_scale_factors(
self, resolution_level: int, percentile: int
):
"""
Calculate the scale factors for normalising the intensity of the image.

Parameters
----------
resolution_level: int
The resolution level to calculate the scale factors at.
percentile: int
The percentile based on which the normalisation is done.
"""
num_tiles = len(self.tiles)
scale_factors = np.ones((num_tiles, num_tiles))

for tile_i in self.tiles:
# Iterate through the neighbours of each tile
print(f"Calculating scale factors for tile {tile_i.id}")
for neighbour_id in tile_i.neighbours:
tile_j = self.tiles[neighbour_id]
overlap = self.overlaps[(tile_i.id, tile_j.id)]

# Extract the overlapping data from both tiles
i_overlap, j_overlap = overlap.extract_tile_overlaps(
resolution_level
)

# Calculate the percentile intensity of the overlapping data
median_i = da.percentile(i_overlap.ravel(), percentile)
median_j = da.percentile(j_overlap.ravel(), percentile)

curr_scale_factor = (median_i / median_j).compute()
scale_factors[tile_i.id][tile_j.id] = curr_scale_factor[0]

# Adjust the tile intensity based on the scale factor
tile_j.data_pyramid[resolution_level] = da.multiply(
tile_j.data_pyramid[resolution_level],
curr_scale_factor,
).astype(tile_j.data_pyramid[resolution_level].dtype)

self.intensity_adjusted[resolution_level] = True
# Calculate the product of the scale factors for each tile's neighbours
# The product is the final scale factor for that tile
self.scale_factors = np.prod(scale_factors, axis=0)

return

def fuse(
self,
output_file_name: str,
output_path: Path,
normalise_intensity: bool = False,
downscale_factors: Tuple[int, int, int] = (1, 2, 2),
chunk_shape: Tuple[int, int, int] = (128, 128, 128),
pyramid_depth: int = 5,
Expand All @@ -411,9 +523,11 @@ def fuse(

Parameters
----------
output_file_name: str
output_path: Path
The name of the output file, suffix dictates the output file type.
Accepts .zarr and .h5 extensions.
Accepts .zarr and .h5 extensions.#
normalise_intensity: bool, default: False
Normalise the intensity differences between tiles.
downscale_factors: Tuple[int, int, int], default: (1, 2, 2)
The factors to downscale the image by in the z, y, x dimensions.
chunk_shape: Tuple[int, ...], default: (128, 128, 128)
Expand All @@ -425,8 +539,6 @@ def fuse(
compression_level: int, default: 6
The compression level to use (only used for zarr).
"""
output_path = self.directory / output_file_name

z_size, y_size, x_size = self.tiles[0].data_pyramid[0].shape
# Calculate the shape of the fused image
fused_image_shape: Tuple[int, int, int] = (
Expand All @@ -435,6 +547,9 @@ def fuse(
max([tile.position[2] for tile in self.tiles]) + x_size,
)

if normalise_intensity:
self.normalise_intensity(0, 80)

if output_path.suffix == ".zarr":
self._fuse_to_zarr(
output_path,
Expand Down Expand Up @@ -586,7 +701,7 @@ def _fuse_to_bdv_h5(
output_path: Path,
fused_image_shape: Tuple[int, int, int],
downscale_factors: Tuple[int, int, int],
pyramid_depth,
pyramid_depth: int,
chunk_shape: Tuple[int, int, int],
) -> None:
"""
Expand Down
129 changes: 114 additions & 15 deletions brainglobe_stitch/stitching_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from napari.qt.threading import create_worker
from napari.utils.notifications import show_info, show_warning
from qtpy.QtWidgets import (
QCheckBox,
QComboBox,
QFileDialog,
QFormLayout,
Expand All @@ -22,9 +23,11 @@
QLineEdit,
QProgressBar,
QPushButton,
QSpinBox,
QVBoxLayout,
QWidget,
)
from superqt import QCollapsible

from brainglobe_stitch.file_utils import (
check_mesospim_directory,
Expand Down Expand Up @@ -208,12 +211,67 @@
self.stitch_button.setEnabled(False)
self.layout().addWidget(self.stitch_button)

self.fuse_option_widget = QWidget()
self.fuse_option_widget.setLayout(QFormLayout())
self.output_file_name_field = QLineEdit()
self.adjust_intensity_button = QPushButton("Adjust Intensity")
self.adjust_intensity_button.clicked.connect(
self._on_adjust_intensity_button_clicked
)
self.adjust_intensity_button.setEnabled(False)
self.layout().addWidget(self.adjust_intensity_button)

self.adjust_intensity_collapsible = QCollapsible(
"Intensity Adjustment Options"
)
self.adjust_intensity_menu = QWidget()
self.adjust_intensity_menu.setLayout(
QFormLayout(parent=self.adjust_intensity_menu)
)

self.percentile_field = QSpinBox(parent=self.adjust_intensity_menu)
self.percentile_field.setRange(0, 100)
self.percentile_field.setValue(80)
self.adjust_intensity_menu.layout().addRow(
"Percentile", self.percentile_field
)

self.adjust_intensity_collapsible.setContent(
self.adjust_intensity_menu
)

self.layout().addWidget(self.adjust_intensity_collapsible)
self.adjust_intensity_collapsible.collapse(animate=False)

self.fuse_option_widget = QWidget(parent=self)
self.fuse_option_widget.setLayout(
QFormLayout(parent=self.fuse_option_widget)
)
self.normalise_intensity_toggle = QCheckBox()

self.select_output_path = QWidget()
self.select_output_path.setLayout(QHBoxLayout())

self.select_output_path_text_field = QLineEdit()
self.select_output_path_text_field.setText(str(self.working_directory))
self.select_output_path.layout().addWidget(
self.select_output_path_text_field
)

self.open_file_dialog_output = QPushButton("Browse")
self.open_file_dialog_output.clicked.connect(
self._on_open_file_dialog_output_clicked
)
self.select_output_path.layout().addWidget(
self.open_file_dialog_output
)

self.fuse_option_widget.layout().addWidget(self.select_output_path)

self.fuse_option_widget.layout().addRow(
"Output file name:", self.output_file_name_field
"Normalise intensity:", self.normalise_intensity_toggle
)
self.fuse_option_widget.layout().addRow(QLabel("Output file name:"))
self.fuse_option_widget.layout().addRow(self.select_output_path)

self.layout().addWidget(self.fuse_option_widget)

self.layout().addWidget(self.fuse_option_widget)

Expand Down Expand Up @@ -294,6 +352,7 @@
)
worker.yielded.connect(self._set_tile_layers)
worker.start()
self.adjust_intensity_button.setEnabled(True)

def _set_tile_layers(self, tile_layer: napari.layers.Image) -> None:
"""
Expand Down Expand Up @@ -370,12 +429,20 @@
display_info(self, "Warning", error_message)
return

self.image_mosaic.stitch(
worker = create_worker(
self.image_mosaic.stitch,
self.imagej_path,
resolution_level=2,
selected_channel=self.fuse_channel_dropdown.currentText(),
)

self.fuse_button.setEnabled(False)
self.stitch_button.setEnabled(False)
self.adjust_intensity_button.setEnabled(False)
worker.finished.connect(self._on_stitch_finished)
worker.start()

def _on_stitch_finished(self):
show_info("Stitching complete")

napari_data = self.image_mosaic.data_for_napari(
Expand All @@ -384,9 +451,39 @@

self.update_tiles_from_mosaic(napari_data)
self.fuse_button.setEnabled(True)
self.stitch_button.setEnabled(True)
self.adjust_intensity_button.setEnabled(True)

def _on_adjust_intensity_button_clicked(self):
self.image_mosaic.normalise_intensity(
resolution_level=self.resolution_to_display,
percentile=self.percentile_field.value(),
)

data_for_napari = self.image_mosaic.data_for_napari(
self.resolution_to_display
)

self.update_tiles_from_mosaic(data_for_napari)

def _on_open_file_dialog_output_clicked(self) -> None:
"""
Open a file dialog to select the output file path.
"""
output_file_str = QFileDialog.getSaveFileName(
self, "Select output file", str(self.working_directory)
)[0]
# A blank string is returned if the user cancels the dialog
if not output_file_str:
return

Check warning on line 478 in brainglobe_stitch/stitching_widget.py

View check run for this annotation

Codecov / codecov/patch

brainglobe_stitch/stitching_widget.py#L478

Added line #L478 was not covered by tests

self.select_output_path_text_field.setText(output_file_str)

def _on_fuse_button_clicked(self) -> None:
if not self.output_file_name_field.text():
if (
self.select_output_path_text_field.text()
== str(self.working_directory)
) or (not self.select_output_path_text_field.text()):
error_message = "Output file name not specified"
show_warning(error_message)
display_info(self, "Warning", error_message)
Expand All @@ -398,10 +495,10 @@
display_info(self, "Warning", error_message)
return

path = self.working_directory / self.output_file_name_field.text()
output_path = Path(self.select_output_path_text_field.text())
valid_extensions = [".zarr", ".h5"]

if path.suffix not in valid_extensions:
if output_path.suffix not in valid_extensions:
error_message = (
f"Output file name should end with "
f"{', '.join(valid_extensions)}"
Expand All @@ -410,15 +507,16 @@
display_info(self, "Warning", error_message)
return

if path.exists():
if output_path.exists():
error_message = (
f"Output file {path} already exists. Replace existing file?"
f"Output file {output_path} already exists. "
f"Replace existing file?"
)
if display_warning(self, "Warning", error_message):
(
shutil.rmtree(path)
if path.suffix == ".zarr"
else path.unlink()
shutil.rmtree(output_path)
if output_path.suffix == ".zarr"
else output_path.unlink()
)
else:
show_warning(
Expand All @@ -428,11 +526,12 @@
return

self.image_mosaic.fuse(
self.output_file_name_field.text(),
output_path,
normalise_intensity=self.normalise_intensity_toggle.isChecked(),
)

show_info("Fusing complete")
display_info(self, "Info", f"Fused image saved to {path}")
display_info(self, "Info", f"Fused image saved to {output_path}")

def check_imagej_path(self) -> None:
"""
Expand Down
Loading
Loading