diff --git a/brainglobe_stitch/image_mosaic.py b/brainglobe_stitch/image_mosaic.py index 99cdc79..80da55d 100644 --- a/brainglobe_stitch/image_mosaic.py +++ b/brainglobe_stitch/image_mosaic.py @@ -80,6 +80,11 @@ def __init__(self, directory: Path): self.load_mesospim_directory() + self.scale_factors: Optional[npt.NDArray] = None + self.intensity_adjusted: List[bool] = [False] * len( + self.tiles[0].resolution_pyramid + ) + def __del__(self): if self.h5_file: self.h5_file.close() @@ -355,6 +360,25 @@ def read_big_stitcher_transforms(self) -> None: stitched_position = stitched_translations[tile.id] tile.position = stitched_position + def reload_resolution_pyramid_level(self, resolution_level: int) -> None: + """ + Reload the data for a given resolution level. + + Parameters + ---------- + resolution_level: int + The resolution level to reload the data for. + """ + if self.h5_file: + for tile in self.tiles: + tile.data_pyramid[resolution_level] = da.from_array( + self.h5_file[ + f"t00000/{tile.name}/{resolution_level}/cells" + ] + ) + + self.intensity_adjusted[resolution_level] = False + def calculate_overlaps(self) -> None: """ Calculate the overlaps between the tiles in the ImageMosaic. @@ -397,9 +421,97 @@ def calculate_overlaps(self) -> None: ) tile_i.neighbours.append(tile_j.id) + def normalise_intensity( + self, resolution_level: int = 0, percentile: int = 80 + ) -> None: + """ + Normalise the intensity of the image at a given resolution level. + + Parameters + ---------- + resolution_level: int + The resolution level to normalise the intensity at. + percentile: int + The percentile based on which the normalisation is done. + """ + if self.intensity_adjusted[resolution_level]: + print("Intensity already adjusted at this resolution scale.") + return + + if self.scale_factors is None: + # Calculate scale factors on at least resolution level 2 + # The tiles are adjusted as the scale factors are calculated + self.calculate_intensity_scale_factors( + max(resolution_level, 2), percentile + ) + + if self.intensity_adjusted[resolution_level]: + return + + assert self.scale_factors is not None + + # Adjust the intensity of each tile based on the scale factors + for tile in self.tiles: + if self.scale_factors[tile.id] != 1.0: + tile.data_pyramid[resolution_level] = da.multiply( + tile.data_pyramid[resolution_level], + self.scale_factors[tile.id], + ).astype(tile.data_pyramid[resolution_level].dtype) + + self.intensity_adjusted[resolution_level] = True + + def calculate_intensity_scale_factors( + self, resolution_level: int, percentile: int + ): + """ + Calculate the scale factors for normalising the intensity of the image. + + Parameters + ---------- + resolution_level: int + The resolution level to calculate the scale factors at. + percentile: int + The percentile based on which the normalisation is done. + """ + num_tiles = len(self.tiles) + scale_factors = np.ones((num_tiles, num_tiles)) + + for tile_i in self.tiles: + # Iterate through the neighbours of each tile + print(f"Calculating scale factors for tile {tile_i.id}") + for neighbour_id in tile_i.neighbours: + tile_j = self.tiles[neighbour_id] + overlap = self.overlaps[(tile_i.id, tile_j.id)] + + # Extract the overlapping data from both tiles + i_overlap, j_overlap = overlap.extract_tile_overlaps( + resolution_level + ) + + # Calculate the percentile intensity of the overlapping data + median_i = da.percentile(i_overlap.ravel(), percentile) + median_j = da.percentile(j_overlap.ravel(), percentile) + + curr_scale_factor = (median_i / median_j).compute() + scale_factors[tile_i.id][tile_j.id] = curr_scale_factor[0] + + # Adjust the tile intensity based on the scale factor + tile_j.data_pyramid[resolution_level] = da.multiply( + tile_j.data_pyramid[resolution_level], + curr_scale_factor, + ).astype(tile_j.data_pyramid[resolution_level].dtype) + + self.intensity_adjusted[resolution_level] = True + # Calculate the product of the scale factors for each tile's neighbours + # The product is the final scale factor for that tile + self.scale_factors = np.prod(scale_factors, axis=0) + + return + def fuse( self, - output_file_name: str, + output_path: Path, + normalise_intensity: bool = False, downscale_factors: Tuple[int, int, int] = (1, 2, 2), chunk_shape: Tuple[int, int, int] = (128, 128, 128), pyramid_depth: int = 5, @@ -411,9 +523,11 @@ def fuse( Parameters ---------- - output_file_name: str + output_path: Path The name of the output file, suffix dictates the output file type. - Accepts .zarr and .h5 extensions. + Accepts .zarr and .h5 extensions.# + normalise_intensity: bool, default: False + Normalise the intensity differences between tiles. downscale_factors: Tuple[int, int, int], default: (1, 2, 2) The factors to downscale the image by in the z, y, x dimensions. chunk_shape: Tuple[int, ...], default: (128, 128, 128) @@ -425,8 +539,6 @@ def fuse( compression_level: int, default: 6 The compression level to use (only used for zarr). """ - output_path = self.directory / output_file_name - z_size, y_size, x_size = self.tiles[0].data_pyramid[0].shape # Calculate the shape of the fused image fused_image_shape: Tuple[int, int, int] = ( @@ -435,6 +547,9 @@ def fuse( max([tile.position[2] for tile in self.tiles]) + x_size, ) + if normalise_intensity: + self.normalise_intensity(0, 80) + if output_path.suffix == ".zarr": self._fuse_to_zarr( output_path, @@ -586,7 +701,7 @@ def _fuse_to_bdv_h5( output_path: Path, fused_image_shape: Tuple[int, int, int], downscale_factors: Tuple[int, int, int], - pyramid_depth, + pyramid_depth: int, chunk_shape: Tuple[int, int, int], ) -> None: """ diff --git a/brainglobe_stitch/stitching_widget.py b/brainglobe_stitch/stitching_widget.py index 64f4764..5fca2e4 100644 --- a/brainglobe_stitch/stitching_widget.py +++ b/brainglobe_stitch/stitching_widget.py @@ -14,6 +14,7 @@ from napari.qt.threading import create_worker from napari.utils.notifications import show_info, show_warning from qtpy.QtWidgets import ( + QCheckBox, QComboBox, QFileDialog, QFormLayout, @@ -22,9 +23,11 @@ QLineEdit, QProgressBar, QPushButton, + QSpinBox, QVBoxLayout, QWidget, ) +from superqt import QCollapsible from brainglobe_stitch.file_utils import ( check_mesospim_directory, @@ -208,12 +211,67 @@ def __init__(self, napari_viewer: Viewer): self.stitch_button.setEnabled(False) self.layout().addWidget(self.stitch_button) - self.fuse_option_widget = QWidget() - self.fuse_option_widget.setLayout(QFormLayout()) - self.output_file_name_field = QLineEdit() + self.adjust_intensity_button = QPushButton("Adjust Intensity") + self.adjust_intensity_button.clicked.connect( + self._on_adjust_intensity_button_clicked + ) + self.adjust_intensity_button.setEnabled(False) + self.layout().addWidget(self.adjust_intensity_button) + + self.adjust_intensity_collapsible = QCollapsible( + "Intensity Adjustment Options" + ) + self.adjust_intensity_menu = QWidget() + self.adjust_intensity_menu.setLayout( + QFormLayout(parent=self.adjust_intensity_menu) + ) + + self.percentile_field = QSpinBox(parent=self.adjust_intensity_menu) + self.percentile_field.setRange(0, 100) + self.percentile_field.setValue(80) + self.adjust_intensity_menu.layout().addRow( + "Percentile", self.percentile_field + ) + + self.adjust_intensity_collapsible.setContent( + self.adjust_intensity_menu + ) + + self.layout().addWidget(self.adjust_intensity_collapsible) + self.adjust_intensity_collapsible.collapse(animate=False) + + self.fuse_option_widget = QWidget(parent=self) + self.fuse_option_widget.setLayout( + QFormLayout(parent=self.fuse_option_widget) + ) + self.normalise_intensity_toggle = QCheckBox() + + self.select_output_path = QWidget() + self.select_output_path.setLayout(QHBoxLayout()) + + self.select_output_path_text_field = QLineEdit() + self.select_output_path_text_field.setText(str(self.working_directory)) + self.select_output_path.layout().addWidget( + self.select_output_path_text_field + ) + + self.open_file_dialog_output = QPushButton("Browse") + self.open_file_dialog_output.clicked.connect( + self._on_open_file_dialog_output_clicked + ) + self.select_output_path.layout().addWidget( + self.open_file_dialog_output + ) + + self.fuse_option_widget.layout().addWidget(self.select_output_path) + self.fuse_option_widget.layout().addRow( - "Output file name:", self.output_file_name_field + "Normalise intensity:", self.normalise_intensity_toggle ) + self.fuse_option_widget.layout().addRow(QLabel("Output file name:")) + self.fuse_option_widget.layout().addRow(self.select_output_path) + + self.layout().addWidget(self.fuse_option_widget) self.layout().addWidget(self.fuse_option_widget) @@ -294,6 +352,7 @@ def _on_add_tiles_button_clicked(self) -> None: ) worker.yielded.connect(self._set_tile_layers) worker.start() + self.adjust_intensity_button.setEnabled(True) def _set_tile_layers(self, tile_layer: napari.layers.Image) -> None: """ @@ -370,12 +429,20 @@ def _on_stitch_button_clicked(self) -> None: display_info(self, "Warning", error_message) return - self.image_mosaic.stitch( + worker = create_worker( + self.image_mosaic.stitch, self.imagej_path, resolution_level=2, selected_channel=self.fuse_channel_dropdown.currentText(), ) + self.fuse_button.setEnabled(False) + self.stitch_button.setEnabled(False) + self.adjust_intensity_button.setEnabled(False) + worker.finished.connect(self._on_stitch_finished) + worker.start() + + def _on_stitch_finished(self): show_info("Stitching complete") napari_data = self.image_mosaic.data_for_napari( @@ -384,9 +451,39 @@ def _on_stitch_button_clicked(self) -> None: self.update_tiles_from_mosaic(napari_data) self.fuse_button.setEnabled(True) + self.stitch_button.setEnabled(True) + self.adjust_intensity_button.setEnabled(True) + + def _on_adjust_intensity_button_clicked(self): + self.image_mosaic.normalise_intensity( + resolution_level=self.resolution_to_display, + percentile=self.percentile_field.value(), + ) + + data_for_napari = self.image_mosaic.data_for_napari( + self.resolution_to_display + ) + + self.update_tiles_from_mosaic(data_for_napari) + + def _on_open_file_dialog_output_clicked(self) -> None: + """ + Open a file dialog to select the output file path. + """ + output_file_str = QFileDialog.getSaveFileName( + self, "Select output file", str(self.working_directory) + )[0] + # A blank string is returned if the user cancels the dialog + if not output_file_str: + return + + self.select_output_path_text_field.setText(output_file_str) def _on_fuse_button_clicked(self) -> None: - if not self.output_file_name_field.text(): + if ( + self.select_output_path_text_field.text() + == str(self.working_directory) + ) or (not self.select_output_path_text_field.text()): error_message = "Output file name not specified" show_warning(error_message) display_info(self, "Warning", error_message) @@ -398,10 +495,10 @@ def _on_fuse_button_clicked(self) -> None: display_info(self, "Warning", error_message) return - path = self.working_directory / self.output_file_name_field.text() + output_path = Path(self.select_output_path_text_field.text()) valid_extensions = [".zarr", ".h5"] - if path.suffix not in valid_extensions: + if output_path.suffix not in valid_extensions: error_message = ( f"Output file name should end with " f"{', '.join(valid_extensions)}" @@ -410,15 +507,16 @@ def _on_fuse_button_clicked(self) -> None: display_info(self, "Warning", error_message) return - if path.exists(): + if output_path.exists(): error_message = ( - f"Output file {path} already exists. Replace existing file?" + f"Output file {output_path} already exists. " + f"Replace existing file?" ) if display_warning(self, "Warning", error_message): ( - shutil.rmtree(path) - if path.suffix == ".zarr" - else path.unlink() + shutil.rmtree(output_path) + if output_path.suffix == ".zarr" + else output_path.unlink() ) else: show_warning( @@ -428,11 +526,12 @@ def _on_fuse_button_clicked(self) -> None: return self.image_mosaic.fuse( - self.output_file_name_field.text(), + output_path, + normalise_intensity=self.normalise_intensity_toggle.isChecked(), ) show_info("Fusing complete") - display_info(self, "Info", f"Fused image saved to {path}") + display_info(self, "Info", f"Fused image saved to {output_path}") def check_imagej_path(self) -> None: """ diff --git a/tests/test_unit/conftest.py b/tests/test_unit/conftest.py index 96992c2..b4b411c 100644 --- a/tests/test_unit/conftest.py +++ b/tests/test_unit/conftest.py @@ -211,6 +211,16 @@ def test_constants(imagej_path): [6, 7, 118], [5, 123, 116], ], + "EXPECTED_INTENSITY_FACTORS": [ + 1.00000, + 0.99636, + 1.00000, + 1.04878, + 0.58846, + 0.55362, + 1.06679, + 1.08642, + ], "EXPECTED_FUSED_SHAPE": (113, 251, 246), "CHANNELS": ["561 nm", "647 nm"], "PIXEL_SIZE_XY": 4.08, diff --git a/tests/test_unit/test_image_mosaic.py b/tests/test_unit/test_image_mosaic.py index 5135f9d..33216dd 100644 --- a/tests/test_unit/test_image_mosaic.py +++ b/tests/test_unit/test_image_mosaic.py @@ -109,20 +109,153 @@ def test_data_for_napari(image_mosaic, test_constants): assert (tile_data[1] == expected_pos).all() +@pytest.mark.parametrize( + "resolution_level", + [0, 1], +) +def test_normalise_intensity( + mocker, test_constants, image_mosaic, resolution_level +): + def force_set_scale_factors(*args, **kwargs): + image_mosaic.scale_factors = test_constants[ + "EXPECTED_INTENSITY_FACTORS" + ] + image_mosaic.intensity_adjusted[args[0]] = True + + mocker.patch( + "brainglobe_stitch.image_mosaic.ImageMosaic.calculate_intensity_scale_factors", + side_effect=force_set_scale_factors, + ) + + image_mosaic.reload_resolution_pyramid_level(resolution_level) + assert not image_mosaic.intensity_adjusted[resolution_level] + image_mosaic.scale_factors = None + + image_mosaic.normalise_intensity(resolution_level) + assert image_mosaic.intensity_adjusted[resolution_level] + assert len(image_mosaic.scale_factors) == test_constants["NUM_TILES"] + + for i in range(test_constants["NUM_TILES"]): + # Check that there each tile has the correct number of pending tasks + # in the dask graph + # Expect to have 4: 2 for loading the data, 2 for scaling the data + # Since the resolution levels are less than 2, the scaling factors are + # calculated on a different resolution level and then applied to the + # current resolution level + if test_constants["EXPECTED_INTENSITY_FACTORS"][i] != 1.0: + assert ( + len( + image_mosaic.tiles[i] + .data_pyramid[resolution_level] + .dask.layers + ) + == 4 + ) + + +@pytest.mark.parametrize( + "resolution_level", + [2, 3, 4], +) +def test_normalise_intensity_done_with_factors( + mocker, image_mosaic, resolution_level, test_constants +): + def force_set_scale_factors(*args, **kwargs): + image_mosaic.scale_factors = test_constants[ + "EXPECTED_INTENSITY_FACTORS" + ] + image_mosaic.intensity_adjusted[args[0]] = True + + mock_calc_intensity_factors = mocker.patch( + "brainglobe_stitch.image_mosaic.ImageMosaic.calculate_intensity_scale_factors", + side_effect=force_set_scale_factors, + ) + + image_mosaic.reload_resolution_pyramid_level(resolution_level) + assert not image_mosaic.intensity_adjusted[resolution_level] + image_mosaic.scale_factors = None + + image_mosaic.normalise_intensity(resolution_level) + assert image_mosaic.intensity_adjusted[resolution_level] + assert len(image_mosaic.scale_factors) == test_constants["NUM_TILES"] + + mock_calc_intensity_factors.assert_called_once_with(resolution_level, 80) + + # Check that no scale adjustment calculations are queued for the tiles + # at the specified resolution level as the correction factors were + # calculated based on this resolution level, + # therefore no calculations queued. + for i in range(test_constants["NUM_TILES"]): + assert ( + len( + image_mosaic.tiles[i] + .data_pyramid[resolution_level] + .dask.layers + ) + == 2 + ) + + +@pytest.mark.parametrize( + "resolution_level", + [0, 1, 2, 3, 4], +) +def test_normalise_intensity_already_adjusted( + image_mosaic, resolution_level, test_constants +): + image_mosaic.reload_resolution_pyramid_level(resolution_level) + image_mosaic.intensity_adjusted[resolution_level] = True + image_mosaic.normalise_intensity(resolution_level) + + assert image_mosaic.intensity_adjusted[resolution_level] + + # Check that no scale adjustment calculations are queued for the tiles + # at the specified resolution level + for i in range(test_constants["NUM_TILES"]): + assert ( + len( + image_mosaic.tiles[i] + .data_pyramid[resolution_level] + .dask.layers + ) + == 2 + ) + + +def test_calculate_intensity_scale_factors(image_mosaic, test_constants): + resolution_level = 2 + percentile = 50 + image_mosaic.reload_resolution_pyramid_level(resolution_level) + image_mosaic.scale_factors = None + + image_mosaic.calculate_intensity_scale_factors( + resolution_level, percentile + ) + + assert len(image_mosaic.scale_factors) == test_constants["NUM_TILES"] + # Check the relative tolerance + assert np.allclose( + image_mosaic.scale_factors, + test_constants["EXPECTED_INTENSITY_FACTORS"], + rtol=1e-2, + ) + + def test_fuse_invalid_file_type(image_mosaic): + output_file = image_mosaic.xml_path.parent / "fused.txt" with pytest.raises(ValueError): - image_mosaic.fuse("fused.txt") + image_mosaic.fuse(output_file) def test_fuse_bdv_h5_defaults(image_mosaic, mocker, test_constants): mock_fuse_function = mocker.patch( "brainglobe_stitch.image_mosaic.ImageMosaic._fuse_to_bdv_h5", ) - file_name = "fused.h5" + file_path = image_mosaic.xml_path.parent / "fused.h5" - image_mosaic.fuse(file_name) + image_mosaic.fuse(file_path) mock_fuse_function.assert_called_once_with( - image_mosaic.xml_path.parent / file_name, + file_path, test_constants["EXPECTED_FUSED_SHAPE"], test_constants["DEFAULT_DOWNSAMPLE_FACTORS"], test_constants["DEFAULT_PYRAMID_DEPTH"], @@ -130,6 +263,31 @@ def test_fuse_bdv_h5_defaults(image_mosaic, mocker, test_constants): ) +def test_fuse_zarr_normalise_intensity(image_mosaic, mocker, test_constants): + file_path = image_mosaic.xml_path.parent / "fused.zarr" + + mock_fuse_to_zarr = mocker.patch( + "brainglobe_stitch.image_mosaic.ImageMosaic._fuse_to_zarr" + ) + mock_normalise_intensity = mocker.patch( + "brainglobe_stitch.image_mosaic.ImageMosaic.normalise_intensity" + ) + + image_mosaic.normalise_intensity = mock_normalise_intensity + image_mosaic.fuse(file_path, normalise_intensity=True) + + mock_normalise_intensity.assert_called_once_with(0, 80) + mock_fuse_to_zarr.assert_called_once_with( + file_path, + test_constants["EXPECTED_FUSED_SHAPE"], + test_constants["DEFAULT_DOWNSAMPLE_FACTORS"], + test_constants["DEFAULT_PYRAMID_DEPTH"], + test_constants["DEFAULT_CHUNK_SHAPE"], + test_constants["DEFAULT_COMPRESSION_METHOD"], + test_constants["DEFAULT_COMPRESSION_LEVEL"], + ) + + @pytest.mark.parametrize( "downscale_factors, chunk_shape, pyramid_depth", [((2, 2, 2), (64, 64, 64), 2), ((4, 4, 4), (32, 32, 32), 3)], @@ -145,11 +303,18 @@ def test_fuse_bdv_h5_custom( mock_fuse_function = mocker.patch( "brainglobe_stitch.image_mosaic.ImageMosaic._fuse_to_bdv_h5", ) - file_name = "fused.h5" + file_path = image_mosaic.xml_path.parent / "fused.h5" - image_mosaic.fuse(file_name, downscale_factors, chunk_shape, pyramid_depth) + normalise_intensity = False + image_mosaic.fuse( + file_path, + normalise_intensity, + downscale_factors, + chunk_shape, + pyramid_depth, + ) mock_fuse_function.assert_called_once_with( - image_mosaic.xml_path.parent / file_name, + file_path, test_constants["EXPECTED_FUSED_SHAPE"], downscale_factors, pyramid_depth, @@ -158,16 +323,16 @@ def test_fuse_bdv_h5_custom( def test_fuse_zarr_file(image_mosaic, mocker, test_constants): - file_name = "fused.zarr" + file_path = image_mosaic.xml_path.parent / "fused.zarr" mock_fuse_to_zarr = mocker.patch( "brainglobe_stitch.image_mosaic.ImageMosaic._fuse_to_zarr" ) - image_mosaic.fuse(file_name) + image_mosaic.fuse(file_path) mock_fuse_to_zarr.assert_called_once_with( - image_mosaic.xml_path.parent / file_name, + file_path, test_constants["EXPECTED_FUSED_SHAPE"], test_constants["DEFAULT_DOWNSAMPLE_FACTORS"], test_constants["DEFAULT_PYRAMID_DEPTH"], @@ -198,10 +363,12 @@ def test_fuse_bdv_zarr_custom( mock_fuse_function = mocker.patch( "brainglobe_stitch.image_mosaic.ImageMosaic._fuse_to_zarr", ) - file_name = "fused.zarr" + file_path = image_mosaic.xml_path.parent / "fused.zarr" + normalise_intensity = False image_mosaic.fuse( - file_name, + file_path, + normalise_intensity, downscale_factors, chunk_shape, pyramid_depth, @@ -209,7 +376,7 @@ def test_fuse_bdv_zarr_custom( compression_level, ) mock_fuse_function.assert_called_once_with( - image_mosaic.xml_path.parent / file_name, + file_path, test_constants["EXPECTED_FUSED_SHAPE"], downscale_factors, pyramid_depth, diff --git a/tests/test_unit/test_stitching_widget.py b/tests/test_unit/test_stitching_widget.py index a475ca7..c106d86 100644 --- a/tests/test_unit/test_stitching_widget.py +++ b/tests/test_unit/test_stitching_widget.py @@ -376,21 +376,111 @@ def test_on_stitch_button_clicked( stitching_widget = stitching_widget_with_mosaic stitching_widget.imagej_path = test_constants["MOCK_IMAGEJ_EXEC_PATH"] - mock_stitch_function = mocker.patch( - "brainglobe_stitch.stitching_widget.ImageMosaic.stitch", + mock_create_worker = mocker.patch( + "brainglobe_stitch.stitching_widget.create_worker", autospec=True, ) stitching_widget._on_stitch_button_clicked() - mock_stitch_function.assert_called_once_with( - stitching_widget.image_mosaic, + mock_create_worker.assert_called_once_with( + stitching_widget_with_mosaic.image_mosaic.stitch, stitching_widget.imagej_path, resolution_level=2, selected_channel="", ) +def test_on_stitch_button_clicked_no_image_mosaic( + stitching_widget, test_constants, mocker +): + """ + Tests that the _on_stitch_button_clicked method correctly shows a warning + message to the user when the ImageMosaic object is not set. + """ + mock_show_warning = mocker.patch( + "brainglobe_stitch.stitching_widget.show_warning" + ) + mock_display_info = mocker.patch( + "brainglobe_stitch.stitching_widget.display_info", + autospec=True, + ) + error_message = "Open a mesoSPIM directory prior to stitching" + + stitching_widget._on_stitch_button_clicked() + + mock_show_warning.assert_called_once_with(error_message) + mock_display_info.assert_called_once_with( + stitching_widget, "Warning", error_message + ) + + +def test_on_stitch_button_clicked_no_imagej( + stitching_widget_with_mosaic, test_constants, mocker +): + """ + Tests that the _on_stitch_button_clicked method correctly shows a warning + message to the user when the imageJ path is not set. + """ + mock_show_warning = mocker.patch( + "brainglobe_stitch.stitching_widget.show_warning" + ) + mock_display_info = mocker.patch( + "brainglobe_stitch.stitching_widget.display_info", + autospec=True, + ) + error_message = "Select the ImageJ path prior to stitching" + + stitching_widget_with_mosaic._on_stitch_button_clicked() + + mock_show_warning.assert_called_once_with(error_message) + mock_display_info.assert_called_once_with( + stitching_widget_with_mosaic, "Warning", error_message + ) + + +def test_on_stitch_finished(stitching_widget_with_mosaic, mocker): + """ + Tests that the _on_stitch_finished method correctly sets the image_mosaic + attribute of the StitchingWidget to None and enables the create_pyramid + button. + """ + mock_show_info = mocker.patch( + "brainglobe_stitch.stitching_widget.show_info" + ) + stitching_widget = stitching_widget_with_mosaic + + stitching_widget._on_stitch_finished() + + mock_show_info.assert_called_once_with("Stitching complete") + + assert stitching_widget_with_mosaic.fuse_button.isEnabled() + assert stitching_widget_with_mosaic.stitch_button.isEnabled() + assert stitching_widget_with_mosaic.adjust_intensity_button.isEnabled() + + +def tests_on_adjust_intensity_button_clicked( + stitching_widget_with_mosaic, mocker +): + """ + Tests that the _on_adjust_intensity_button_clicked method correctly calls + the adjust_intensity method of the ImageMosaic object with the correct + arguments. + """ + mock_normalise_intensity = mocker.patch( + "brainglobe_stitch.stitching_widget.ImageMosaic.normalise_intensity", + autospec=True, + ) + + stitching_widget_with_mosaic._on_adjust_intensity_button_clicked() + + mock_normalise_intensity.assert_called_once_with( + stitching_widget_with_mosaic.image_mosaic, + resolution_level=3, + percentile=80, + ) + + def test_check_imagej_path_valid(stitching_widget): """ Creates a mock imageJ file in the home directory and sets it as the @@ -469,6 +559,42 @@ def test_update_tiles_from_mosaic( assert (tile.translate == test_data[1]).all() +def test_on_open_dialog_output_clicked(stitching_widget, mocker): + """ + Test that the on_open_dialog_output_clicked method. + The directory is provided by mocking the return of the + QFileDialog.getExistingDirectory method. + """ + test_dir = str(Path.home() / "test_dir") + mocker.patch( + "brainglobe_stitch.stitching_widget.QFileDialog.getSaveFileName", + return_value=[test_dir], + ) + + stitching_widget._on_open_file_dialog_output_clicked() + + assert stitching_widget.select_output_path_text_field.text() == test_dir + + +def test_on_open_dialog_output_clicked_cancelled(stitching_widget, mocker): + """ + Mocks the QFileDialog.getExistingDirectory method to return an empty string + to mimic the user cancelling the file dialog. + The select_output_path_text_field should retain its original value. + """ + original_value = stitching_widget.select_output_path_text_field.text() + mocker.patch( + "brainglobe_stitch.stitching_widget.QFileDialog.getExistingDirectory", + return_value="", + ) + + stitching_widget._on_open_file_dialog_clicked() + + assert ( + stitching_widget.select_output_path_text_field.text() == original_value + ) + + @pytest.mark.parametrize("file_name", ["fused_image.h5", "fused_image.zarr"]) def test_on_fuse_button_clicked( stitching_widget_with_mosaic, mocker, file_name @@ -484,11 +610,14 @@ def test_on_fuse_button_clicked( autospec=True, ) - stitching_widget.output_file_name_field.setText(file_name) + output_path = stitching_widget.working_directory / file_name + stitching_widget.select_output_path_text_field.setText(str(output_path)) stitching_widget._on_fuse_button_clicked() - mock_fuse.assert_called_once_with(stitching_widget.image_mosaic, file_name) + mock_fuse.assert_called_once_with( + stitching_widget.image_mosaic, output_path, normalise_intensity=False + ) mock_display_info.assert_called_once_with( stitching_widget, "Info", @@ -527,7 +656,7 @@ def test_on_fuse_button_clicked_wrong_suffix( ): stitching_widget = stitching_widget_with_mosaic - stitching_widget.output_file_name_field.setText("fused_image.tif") + stitching_widget.select_output_path_text_field.setText("fused_image.tif") error_message = "Output file name should end with .zarr, .h5" mock_show_warning = mocker.patch(