From 41fbece2d28982e01c2cdfb9c91b014e0bbf0742 Mon Sep 17 00:00:00 2001
From: carlosuc3m <100329787@alumnos.uc3m.es>
Date: Thu, 16 Jan 2025 14:55:57 +0100
Subject: [PATCH] keeep iterating to manage to get a stable stardist
---
.../modelrunner/model/Stardist2D.java | 2 +-
.../modelrunner/model/Stardist3D.java | 272 ++-----------
.../modelrunner/model/Stardist3D_old.java | 373 ++++++++++++++++++
.../modelrunner/model/StardistAbstract.java | 131 +++---
4 files changed, 477 insertions(+), 301 deletions(-)
create mode 100644 src/main/java/io/bioimage/modelrunner/model/Stardist3D_old.java
diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java
index 7c9c6020..e656a5cb 100644
--- a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java
+++ b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java
@@ -51,7 +51,7 @@
*/
public class Stardist2D extends StardistAbstract {
- private static String MODULE_NAME = "Stardist2D";
+ private static String MODULE_NAME = "StarDist2D";
private Stardist2D(String modelName, String baseDir) throws IOException, ModelSpecsException {
super(modelName, baseDir);
diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist3D.java b/src/main/java/io/bioimage/modelrunner/model/Stardist3D.java
index c1e5675f..28b429a9 100644
--- a/src/main/java/io/bioimage/modelrunner/model/Stardist3D.java
+++ b/src/main/java/io/bioimage/modelrunner/model/Stardist3D.java
@@ -22,140 +22,85 @@
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
-import java.io.InputStream;
import java.net.URISyntaxException;
-import java.nio.file.Files;
-import java.nio.file.Paths;
-import java.nio.file.StandardCopyOption;
-import java.text.SimpleDateFormat;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Calendar;
-import java.util.LinkedHashMap;
-import java.util.List;
import java.util.Map;
-import java.util.stream.Collectors;
import org.apache.commons.compress.archivers.ArchiveException;
-import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException;
-import io.bioimage.modelrunner.engine.installation.EngineInstall;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
-import io.bioimage.modelrunner.model.processing.Processing;
-import io.bioimage.modelrunner.runmode.RunMode;
-import io.bioimage.modelrunner.runmode.ops.GenericOp;
-import io.bioimage.modelrunner.tensor.Tensor;
-import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.utils.Constants;
-import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
-import net.imglib2.util.Cast;
-import net.imglib2.view.Views;
/**
- * This class provides an implementation of an API to run Stardist 3D models, including pre-
- * and post-processing, with minimal configuration. The class allows for loading pretrained
- * models from Bioimage.io and also includes methods for running predictions using 3D images.
- *
- *
Stardist is a deep learning model specialized in detecting object shapes, particularly
- * star-convex shapes in images, which is especially useful in biomedical imaging applications
- * such as cell nuclei detection.
- *
- * The Stardist3D class includes methods for installing the necessary requirements (such as
- * Python environments), running predictions, and handling Stardist-specific post-processing.
- *
- * TODO: add support for fine-tuning models and Mac ARM processors.
- *
- * Example usage:
- *
- * {@code
- * Stardist3D.installRequirements();
- * Stardist3D model = Stardist3D.fromPretained("StarDist Plant Nuclei 3D ResNet", false);
- * RandomAccessibleInterval img = ArrayImgs.floats(new long[] {116, 120, 66});
- * RandomAccessibleInterval res = model.predict(img);
- * }
- *
- *
- * @see io.bioimage.modelrunner.bioimageio.description.ModelDescriptor
- * @see io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory
- * @see io.bioimage.modelrunner.exceptions.LoadModelException
- * @see io.bioimage.modelrunner.exceptions.RunModelException
+ * Implementation of an API to run Stardist 3D models out of the box with little configuration.
*
*TODO add fine tuning
- *TODO add support for Mac arm
*
*@author Carlos Garcia
*/
-public class Stardist3D {
-
- private ModelDescriptor descriptor;
-
- private final int channels;
-
- private final float nms_threshold;
+public class Stardist3D extends StardistAbstract {
- private final float prob_threshold;
+ private static String MODULE_NAME = "StarDist3D";
- private static final List STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"});
-
- private static final List STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"});
-
- private static final String STARDIST3D_PATH_IN_RESOURCES = "ops/stardist_postprocessing/";
-
- private static final String STARDIST3D_SCRIPT_NAME= "stardist_postprocessing_3D.py";
-
- private static final String STARDIST3D_METHOD_NAME= "stardist_postprocessing";
-
- private Stardist3D() {
- this.channels = 1;
- // TODO get from config??
- this.nms_threshold = 0;
- this.prob_threshold = 0;
+ private Stardist3D(String modelName, String baseDir) throws IOException, ModelSpecsException {
+ super(modelName, baseDir);
}
- private Stardist3D(ModelDescriptor descriptor) {
- this.descriptor = descriptor;
- Map stardistMap = (Map) descriptor.getConfig().getSpecMap().get("stardist");
- Map stardistConfig = (Map) stardistMap.get("config");
- Map stardistThres = (Map) stardistMap.get("thresholds");
- this.channels = (int) stardistConfig.get("n_channel_in");;
- this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue();
- this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue();
+ private Stardist3D(ModelDescriptor descriptor) throws IOException, ModelSpecsException {
+ super(descriptor);
+ }
+
+ @Override
+ protected String createImportsCode() {
+ return String.format(LOAD_MODEL_CODE_ABSTRACT, MODULE_NAME, MODULE_NAME,
+ MODULE_NAME, MODULE_NAME, MODULE_NAME, this.name, this.basedir);
+ }
+
+ @Override
+ protected & NativeType> void checkInput(RandomAccessibleInterval image) {
+ if (image.dimensionsAsLongArray().length == 3 && this.nChannels != 1)
+ throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ");
+ else if (image.dimensionsAsLongArray().length != 4 && this.nChannels != 1)
+ throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ");
+ else if (image.dimensionsAsLongArray().length == 4 && image.dimensionsAsLongArray()[2] != nChannels)
+ throw new IllegalArgumentException("This Stardist3D model requires " + nChannels + " channels.");
+ else if (image.dimensionsAsLongArray().length > 4 || image.dimensionsAsLongArray().length < 2)
+ throw new IllegalArgumentException("Stardist3D model requires an image with dimensions XYCZ.");
}
/**
- * Initialize a Stardist3D using the format of the Bioiamge.io model zoo.
+ * Initialize a Stardist2D using the format of the Bioiamge.io model zoo.
* @param modelPath
* path to the Bioimage.io model
- * @return an instance of a Stardist3D model ready to be used
- * @throws ModelSpecsException if the model configuration is incorrect in the specs file (rdf.yaml).
- * @throws FileNotFoundException if the model file is not found in the specified path.
- * @throws IOException if there is an issue reading the model file.
+ * @return an instance of a Stardist2D model ready to be used
+ * @throws ModelSpecsException If there is any error in the configuration of the specs rdf.yaml file of the Bioimage.io.
+ * @throws FileNotFoundException If the model file is not found.
+ * @throws IOException If there's an I/O error.
*/
public static Stardist3D fromBioimageioModel(String modelPath) throws ModelSpecsException, FileNotFoundException, IOException {
ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(modelPath + File.separator + Constants.RDF_FNAME);
return new Stardist3D(descriptor);
}
-
+
/**
- * Initialize one of the "official" pretrained Stardist 3D models.
+ * Initialize one of the "official" pretrained Stardist 2D models.
* By default, the model will be installed in the "models" folder inside the application
* @param pretrainedModel
* the name of the pretrained model.
* @param forceInstall
* whether to force the installation or to try to look if the model has already been installed before
- * @return an instance of a pretrained Stardist3D model ready to be used
+ * @return an instance of a pretrained Stardist2D model ready to be used
* @throws IOException if there is any error downloading the model, in the case it is needed
* @throws InterruptedException if the download of the model is stopped
* @throws ModelSpecsException if the model downloaded is not well specified in the config file
@@ -194,153 +139,7 @@ public static Stardist3D fromPretained(String pretrainedModel, String installDir
}
}
- private & NativeType> void checkInput(RandomAccessibleInterval image) {
- if (image.dimensionsAsLongArray().length == 3 && this.channels != 1)
- throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ");
- else if (image.dimensionsAsLongArray().length != 4 && this.channels != 1)
- throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ");
- else if (image.dimensionsAsLongArray().length == 4 && image.dimensionsAsLongArray()[2] != channels)
- throw new IllegalArgumentException("This Stardist3D model requires " + channels + " channels.");
- else if (image.dimensionsAsLongArray().length > 4 || image.dimensionsAsLongArray().length < 2)
- throw new IllegalArgumentException("Stardist3D model requires an image with dimensions XYCZ.");
- }
- /**
- * Run the Stardist 3D model end to end, including pre- and post-processing.
- * @param
- * possible ImgLib2 data types of the input and output images
- * @param image
- * the input image that is going to be processed by Stardist3D
- * @return the final output of Stardist2D including pre- and post-processing
- * @throws ModelSpecsException if there is any error with the specs of the model
- * @throws LoadModelException if there is any error loading the model in Tensorflow Java
- * @throws LoadEngineException if there is any error loading Tensorflow Java engine
- * @throws IOException if there is any error with the files that are required to run the model
- * @throws RunModelException if there is any unexpected exception running the post-processing
- * @throws InterruptedException if the inference or post-processing are interrupted unexpectedly
- */
- public & NativeType>
- RandomAccessibleInterval predict(RandomAccessibleInterval image) throws ModelSpecsException, LoadModelException,
- LoadEngineException, IOException,
- RunModelException, InterruptedException {
- checkInput(image);
- if (image.dimensionsAsLongArray().length == 3) {
- image = Views.addDimension(image, 0, 0);
- image = Utils.transpose(image);
- image = Views.addDimension(image, 0, 0);
- } else if (image.dimensionsAsLongArray().length == 4) {
- image = Views.permute(image, 1, 2);
- image = Views.addDimension(image, 0, 1);
- image = Views.addDimension(image, 0, 0);
- image = Utils.transpose(image);
- }
-
- Tensor inputTensor = Tensor.build("input", "bzyxc", image);
- Tensor outputTensor = Tensor.buildEmptyTensor("output", "bzyxc");
-
- List> inputList = new ArrayList>();
- List> outputList = new ArrayList>();
- inputList.add(inputTensor);
- outputList.add(outputTensor);
-
- Model model = Model.createBioimageioModel(this.descriptor.getModelPath());
- model.loadModel();
- Processing processing = Processing.init(descriptor);
- inputList = processing.preprocess(inputList, false);
- model.runModel(inputList, outputList);
-
- return Utils.transpose(Cast.unchecked(postProcessing(outputList.get(0).getData())));
- }
-
- /**
- * Execute stardist post-processing on the raw output of a Stardist 3D model
- * @param
- * possible data type of the input image
- * @param image
- * the raw output of a Stardist 3D model
- * @return the final output of a Stardist 3D model
- * @throws IOException if there is any error running the post-processing
- * @throws InterruptedException if the post-processing is interrupted
- */
- public & NativeType>
- RandomAccessibleInterval postProcessing(RandomAccessibleInterval image) throws IOException, InterruptedException {
- Mamba mamba = new Mamba();
- String envPath = mamba.getEnvsDir() + File.separator + "stardist";
- String scriptPath = envPath + File.separator + STARDIST3D_SCRIPT_NAME;
-
- GenericOp op = GenericOp.create(envPath, scriptPath, STARDIST3D_METHOD_NAME, 1);
- LinkedHashMap nMap = new LinkedHashMap();
- Calendar cal = Calendar.getInstance();
- SimpleDateFormat sdf = new SimpleDateFormat("ddMMYYYY_HHmmss");
- String dateString = sdf.format(cal.getTime());
- nMap.put("input_" + dateString, image);
- nMap.put("nms_thresh", nms_threshold);
- nMap.put("prob_thresh", prob_threshold);
- op.setInputs(nMap);
-
- RunMode rm;
- rm = RunMode.createRunMode(op);
- Map resMap = rm.runOP();
-
- List> rais = resMap.entrySet().stream()
- .filter(e -> {
- Object val = e.getValue();
- if (val instanceof RandomAccessibleInterval) return true;
- return false;
- }).map(e -> (RandomAccessibleInterval) e.getValue()).collect(Collectors.toList());
-
- return rais.get(0);
- }
-
- /**
- * Check whether everything that is needed for Stardist 3D is installed or not
- */
- public void checkRequirementsInstalled() {
- // TODO
- }
-
- /**
- * Check whether the requirements needed to run Stardist 3D are satisfied or not.
- * First checks if the corresponding Java DL engine is installed or not, then checks
- * if the Python environment needed for Stardist3D post processing is fine too.
- *
- * If anything is not installed, this method also installs it
- *
- * @throws IOException if there is any error downloading the DL engine or installing the micromamba environment
- * @throws InterruptedException if the installation is stopped
- * @throws RuntimeException if there is any unexpected error in the micromamba environment installation
- * @throws MambaInstallException if there is any error downloading or installing micromamba
- * @throws ArchiveException if there is any error decompressing the micromamba installer
- * @throws URISyntaxException if the URL to the micromamba installation is not correct
- */
- public static void installRequirements() throws IOException, InterruptedException,
- RuntimeException, MambaInstallException,
- ArchiveException, URISyntaxException {
- boolean installed = InstalledEngines.buildEnginesFinder()
- .checkEngineWithArgsInstalledForOS("tensorflow", "1.15.0", null, null).size() != 0;
- if (!installed)
- EngineInstall.installEngineWithArgs("tensorflow", "1.15.0", true, true);
-
- Mamba mamba = new Mamba();
- boolean stardistPythonInstalled = false;
- try {
- stardistPythonInstalled = mamba.checkAllDependenciesInEnv("stardist", STARDIST_DEPS);
- } catch (MambaInstallException e) {
- mamba.installMicromamba();
- }
- if (!stardistPythonInstalled) {
- // TODO add logging for environment installation
- mamba.create("stardist", true, STARDIST_CHANNELS, STARDIST_DEPS);
- };
- String envPath = mamba.getEnvsDir() + File.separator + "stardist";
- String scriptPath = envPath + File.separator + STARDIST3D_SCRIPT_NAME;
- if (!Paths.get(scriptPath).toFile().isFile()) {
- try (InputStream scriptStream = Stardist3D.class.getClassLoader()
- .getResourceAsStream(STARDIST3D_PATH_IN_RESOURCES + STARDIST3D_SCRIPT_NAME)){
- Files.copy(scriptStream, Paths.get(scriptPath), StandardCopyOption.REPLACE_EXISTING);
- }
- }
- }
/**
* Main method to check functionality
@@ -364,10 +163,11 @@ public static void main(String[] args) throws IOException, InterruptedException,
URISyntaxException, LoadModelException {
Stardist3D.installRequirements();
Stardist3D model = Stardist3D.fromPretained("StarDist Plant Nuclei 3D ResNet", false);
-
+
RandomAccessibleInterval img = ArrayImgs.floats(new long[] {116, 120, 66});
- RandomAccessibleInterval res = model.predict(img);
+ Map> res = model.predict(img);
+ model.close();
System.out.println(true);
}
}
diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist3D_old.java b/src/main/java/io/bioimage/modelrunner/model/Stardist3D_old.java
new file mode 100644
index 00000000..03acbfca
--- /dev/null
+++ b/src/main/java/io/bioimage/modelrunner/model/Stardist3D_old.java
@@ -0,0 +1,373 @@
+/*-
+ * #%L
+ * Use deep learning frameworks from Java in an agnostic and isolated way.
+ * %%
+ * Copyright (C) 2022 - 2024 Institut Pasteur and BioImage.IO developers.
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+package io.bioimage.modelrunner.model;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URISyntaxException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.nio.file.StandardCopyOption;
+import java.text.SimpleDateFormat;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Calendar;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import org.apache.commons.compress.archivers.ArchiveException;
+
+import io.bioimage.modelrunner.apposed.appose.Mamba;
+import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
+import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
+import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
+import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
+import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException;
+import io.bioimage.modelrunner.engine.installation.EngineInstall;
+import io.bioimage.modelrunner.exceptions.LoadEngineException;
+import io.bioimage.modelrunner.exceptions.LoadModelException;
+import io.bioimage.modelrunner.exceptions.RunModelException;
+import io.bioimage.modelrunner.model.processing.Processing;
+import io.bioimage.modelrunner.runmode.RunMode;
+import io.bioimage.modelrunner.runmode.ops.GenericOp;
+import io.bioimage.modelrunner.tensor.Tensor;
+import io.bioimage.modelrunner.tensor.Utils;
+import io.bioimage.modelrunner.utils.Constants;
+import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
+import net.imglib2.RandomAccessibleInterval;
+import net.imglib2.img.array.ArrayImgs;
+import net.imglib2.type.NativeType;
+import net.imglib2.type.numeric.RealType;
+import net.imglib2.type.numeric.real.FloatType;
+import net.imglib2.util.Cast;
+import net.imglib2.view.Views;
+
+/**
+ * This class provides an implementation of an API to run Stardist 3D models, including pre-
+ * and post-processing, with minimal configuration. The class allows for loading pretrained
+ * models from Bioimage.io and also includes methods for running predictions using 3D images.
+ *
+ * Stardist is a deep learning model specialized in detecting object shapes, particularly
+ * star-convex shapes in images, which is especially useful in biomedical imaging applications
+ * such as cell nuclei detection.
+ *
+ * The Stardist3D class includes methods for installing the necessary requirements (such as
+ * Python environments), running predictions, and handling Stardist-specific post-processing.
+ *
+ * TODO: add support for fine-tuning models and Mac ARM processors.
+ *
+ * Example usage:
+ *
+ * {@code
+ * Stardist3D.installRequirements();
+ * Stardist3D model = Stardist3D.fromPretained("StarDist Plant Nuclei 3D ResNet", false);
+ * RandomAccessibleInterval img = ArrayImgs.floats(new long[] {116, 120, 66});
+ * RandomAccessibleInterval res = model.predict(img);
+ * }
+ *
+ *
+ * @see io.bioimage.modelrunner.bioimageio.description.ModelDescriptor
+ * @see io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory
+ * @see io.bioimage.modelrunner.exceptions.LoadModelException
+ * @see io.bioimage.modelrunner.exceptions.RunModelException
+ *
+ *TODO add fine tuning
+ *TODO add support for Mac arm
+ *
+ *@author Carlos Garcia
+ */
+public class Stardist3D_old {
+
+ private ModelDescriptor descriptor;
+
+ private final int channels;
+
+ private final float nms_threshold;
+
+ private final float prob_threshold;
+
+ private static final List STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"});
+
+ private static final List STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"});
+
+ private static final String STARDIST3D_PATH_IN_RESOURCES = "ops/stardist_postprocessing/";
+
+ private static final String STARDIST3D_SCRIPT_NAME= "stardist_postprocessing_3D.py";
+
+ private static final String STARDIST3D_METHOD_NAME= "stardist_postprocessing";
+
+ private Stardist3D_old() {
+ this.channels = 1;
+ // TODO get from config??
+ this.nms_threshold = 0;
+ this.prob_threshold = 0;
+ }
+
+ private Stardist3D_old(ModelDescriptor descriptor) {
+ this.descriptor = descriptor;
+ Map stardistMap = (Map) descriptor.getConfig().getSpecMap().get("stardist");
+ Map stardistConfig = (Map) stardistMap.get("config");
+ Map stardistThres = (Map) stardistMap.get("thresholds");
+ this.channels = (int) stardistConfig.get("n_channel_in");;
+ this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue();
+ this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue();
+ }
+
+ /**
+ * Initialize a Stardist3D using the format of the Bioiamge.io model zoo.
+ * @param modelPath
+ * path to the Bioimage.io model
+ * @return an instance of a Stardist3D model ready to be used
+ * @throws ModelSpecsException if the model configuration is incorrect in the specs file (rdf.yaml).
+ * @throws FileNotFoundException if the model file is not found in the specified path.
+ * @throws IOException if there is an issue reading the model file.
+ */
+ public static Stardist3D_old fromBioimageioModel(String modelPath) throws ModelSpecsException, FileNotFoundException, IOException {
+ ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(modelPath + File.separator + Constants.RDF_FNAME);
+ return new Stardist3D_old(descriptor);
+ }
+
+ /**
+ * Initialize one of the "official" pretrained Stardist 3D models.
+ * By default, the model will be installed in the "models" folder inside the application
+ * @param pretrainedModel
+ * the name of the pretrained model.
+ * @param forceInstall
+ * whether to force the installation or to try to look if the model has already been installed before
+ * @return an instance of a pretrained Stardist3D model ready to be used
+ * @throws IOException if there is any error downloading the model, in the case it is needed
+ * @throws InterruptedException if the download of the model is stopped
+ * @throws ModelSpecsException if the model downloaded is not well specified in the config file
+ */
+ public static Stardist3D_old fromPretained(String pretrainedModel, boolean forceInstall) throws IOException, InterruptedException, ModelSpecsException {
+ return fromPretained(pretrainedModel, new File("models").getAbsolutePath(), forceInstall);
+ }
+
+ /**
+ * Initialize one of the "official" pretrained Stardist 3D models
+ * @param pretrainedModel
+ * the name of the pretrained model.
+ * @param installDir
+ * the directory where the model wants to be installed
+ * @param forceInstall
+ * whether to force the installation or to try to look if the model has already been installed before
+ * @return an instance of a pretrained Stardist3D model ready to be used
+ * @throws IOException if there is any error downloading the model, in the case it is needed
+ * @throws InterruptedException if the download of the model is stopped
+ * @throws ModelSpecsException if the model downloaded is not well specified in the config file
+ */
+ public static Stardist3D_old fromPretained(String pretrainedModel, String installDir, boolean forceInstall) throws IOException,
+ InterruptedException,
+ ModelSpecsException {
+ if (pretrainedModel.equals("StarDist Plant Nuclei 3D ResNet") && !forceInstall) {
+ ModelDescriptor md = ModelDescriptorFactory.getModelsAtLocalRepo().stream()
+ .filter(mm ->mm.getName().equals(pretrainedModel)).findFirst().orElse(null);
+ if (md != null) return new Stardist3D_old(md);
+ String path = BioimageioRepo.connect().downloadByName("StarDist Plant Nuclei 3D ResNet", installDir);
+ return Stardist3D_old.fromBioimageioModel(path);
+ } else if (pretrainedModel.equals("StarDist Plant Nuclei 3D ResNet")) {
+ String path = BioimageioRepo.connect().downloadByName("StarDist Plant Nuclei 3D ResNet", installDir);
+ return Stardist3D_old.fromBioimageioModel(path);
+ } else {
+ throw new IllegalArgumentException("There is no Stardist3D model called: " + pretrainedModel);
+ }
+ }
+
+ private & NativeType> void checkInput(RandomAccessibleInterval image) {
+ if (image.dimensionsAsLongArray().length == 3 && this.channels != 1)
+ throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ");
+ else if (image.dimensionsAsLongArray().length != 4 && this.channels != 1)
+ throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ");
+ else if (image.dimensionsAsLongArray().length == 4 && image.dimensionsAsLongArray()[2] != channels)
+ throw new IllegalArgumentException("This Stardist3D model requires " + channels + " channels.");
+ else if (image.dimensionsAsLongArray().length > 4 || image.dimensionsAsLongArray().length < 2)
+ throw new IllegalArgumentException("Stardist3D model requires an image with dimensions XYCZ.");
+ }
+
+ /**
+ * Run the Stardist 3D model end to end, including pre- and post-processing.
+ * @param
+ * possible ImgLib2 data types of the input and output images
+ * @param image
+ * the input image that is going to be processed by Stardist3D
+ * @return the final output of Stardist2D including pre- and post-processing
+ * @throws ModelSpecsException if there is any error with the specs of the model
+ * @throws LoadModelException if there is any error loading the model in Tensorflow Java
+ * @throws LoadEngineException if there is any error loading Tensorflow Java engine
+ * @throws IOException if there is any error with the files that are required to run the model
+ * @throws RunModelException if there is any unexpected exception running the post-processing
+ * @throws InterruptedException if the inference or post-processing are interrupted unexpectedly
+ */
+ public & NativeType>
+ RandomAccessibleInterval predict(RandomAccessibleInterval image) throws ModelSpecsException, LoadModelException,
+ LoadEngineException, IOException,
+ RunModelException, InterruptedException {
+ checkInput(image);
+ if (image.dimensionsAsLongArray().length == 3) {
+ image = Views.addDimension(image, 0, 0);
+ image = Utils.transpose(image);
+ image = Views.addDimension(image, 0, 0);
+ } else if (image.dimensionsAsLongArray().length == 4) {
+ image = Views.permute(image, 1, 2);
+ image = Views.addDimension(image, 0, 1);
+ image = Views.addDimension(image, 0, 0);
+ image = Utils.transpose(image);
+ }
+
+ Tensor inputTensor = Tensor.build("input", "bzyxc", image);
+ Tensor outputTensor = Tensor.buildEmptyTensor("output", "bzyxc");
+
+ List> inputList = new ArrayList>();
+ List> outputList = new ArrayList>();
+ inputList.add(inputTensor);
+ outputList.add(outputTensor);
+
+ Model model = Model.createBioimageioModel(this.descriptor.getModelPath());
+ model.loadModel();
+ Processing processing = Processing.init(descriptor);
+ inputList = processing.preprocess(inputList, false);
+ model.runModel(inputList, outputList);
+
+ return Utils.transpose(Cast.unchecked(postProcessing(outputList.get(0).getData())));
+ }
+
+ /**
+ * Execute stardist post-processing on the raw output of a Stardist 3D model
+ * @param
+ * possible data type of the input image
+ * @param image
+ * the raw output of a Stardist 3D model
+ * @return the final output of a Stardist 3D model
+ * @throws IOException if there is any error running the post-processing
+ * @throws InterruptedException if the post-processing is interrupted
+ */
+ public & NativeType>
+ RandomAccessibleInterval postProcessing(RandomAccessibleInterval image) throws IOException, InterruptedException {
+ Mamba mamba = new Mamba();
+ String envPath = mamba.getEnvsDir() + File.separator + "stardist";
+ String scriptPath = envPath + File.separator + STARDIST3D_SCRIPT_NAME;
+
+ GenericOp op = GenericOp.create(envPath, scriptPath, STARDIST3D_METHOD_NAME, 1);
+ LinkedHashMap nMap = new LinkedHashMap();
+ Calendar cal = Calendar.getInstance();
+ SimpleDateFormat sdf = new SimpleDateFormat("ddMMYYYY_HHmmss");
+ String dateString = sdf.format(cal.getTime());
+ nMap.put("input_" + dateString, image);
+ nMap.put("nms_thresh", nms_threshold);
+ nMap.put("prob_thresh", prob_threshold);
+ op.setInputs(nMap);
+
+ RunMode rm;
+ rm = RunMode.createRunMode(op);
+ Map resMap = rm.runOP();
+
+ List> rais = resMap.entrySet().stream()
+ .filter(e -> {
+ Object val = e.getValue();
+ if (val instanceof RandomAccessibleInterval) return true;
+ return false;
+ }).map(e -> (RandomAccessibleInterval) e.getValue()).collect(Collectors.toList());
+
+ return rais.get(0);
+ }
+
+ /**
+ * Check whether everything that is needed for Stardist 3D is installed or not
+ */
+ public void checkRequirementsInstalled() {
+ // TODO
+ }
+
+ /**
+ * Check whether the requirements needed to run Stardist 3D are satisfied or not.
+ * First checks if the corresponding Java DL engine is installed or not, then checks
+ * if the Python environment needed for Stardist3D post processing is fine too.
+ *
+ * If anything is not installed, this method also installs it
+ *
+ * @throws IOException if there is any error downloading the DL engine or installing the micromamba environment
+ * @throws InterruptedException if the installation is stopped
+ * @throws RuntimeException if there is any unexpected error in the micromamba environment installation
+ * @throws MambaInstallException if there is any error downloading or installing micromamba
+ * @throws ArchiveException if there is any error decompressing the micromamba installer
+ * @throws URISyntaxException if the URL to the micromamba installation is not correct
+ */
+ public static void installRequirements() throws IOException, InterruptedException,
+ RuntimeException, MambaInstallException,
+ ArchiveException, URISyntaxException {
+ boolean installed = InstalledEngines.buildEnginesFinder()
+ .checkEngineWithArgsInstalledForOS("tensorflow", "1.15.0", null, null).size() != 0;
+ if (!installed)
+ EngineInstall.installEngineWithArgs("tensorflow", "1.15.0", true, true);
+
+ Mamba mamba = new Mamba();
+ boolean stardistPythonInstalled = false;
+ try {
+ stardistPythonInstalled = mamba.checkAllDependenciesInEnv("stardist", STARDIST_DEPS);
+ } catch (MambaInstallException e) {
+ mamba.installMicromamba();
+ }
+ if (!stardistPythonInstalled) {
+ // TODO add logging for environment installation
+ mamba.create("stardist", true, STARDIST_CHANNELS, STARDIST_DEPS);
+ };
+ String envPath = mamba.getEnvsDir() + File.separator + "stardist";
+ String scriptPath = envPath + File.separator + STARDIST3D_SCRIPT_NAME;
+ if (!Paths.get(scriptPath).toFile().isFile()) {
+ try (InputStream scriptStream = Stardist3D_old.class.getClassLoader()
+ .getResourceAsStream(STARDIST3D_PATH_IN_RESOURCES + STARDIST3D_SCRIPT_NAME)){
+ Files.copy(scriptStream, Paths.get(scriptPath), StandardCopyOption.REPLACE_EXISTING);
+ }
+ }
+ }
+
+ /**
+ * Main method to check functionality
+ * @param args
+ * nothing
+ * @throws IOException nothing
+ * @throws InterruptedException nothing
+ * @throws RuntimeException nothing
+ * @throws MambaInstallException nothing
+ * @throws ModelSpecsException nothing
+ * @throws LoadEngineException nothing
+ * @throws RunModelException nothing
+ * @throws ArchiveException nothing
+ * @throws URISyntaxException nothing
+ * @throws LoadModelException nothing
+ */
+ public static void main(String[] args) throws IOException, InterruptedException,
+ RuntimeException, MambaInstallException,
+ ModelSpecsException, LoadEngineException,
+ RunModelException, ArchiveException,
+ URISyntaxException, LoadModelException {
+ Stardist3D_old.installRequirements();
+ Stardist3D_old model = Stardist3D_old.fromPretained("StarDist Plant Nuclei 3D ResNet", false);
+
+ RandomAccessibleInterval img = ArrayImgs.floats(new long[] {116, 120, 66});
+
+ RandomAccessibleInterval res = model.predict(img);
+ System.out.println(true);
+ }
+}
diff --git a/src/main/java/io/bioimage/modelrunner/model/StardistAbstract.java b/src/main/java/io/bioimage/modelrunner/model/StardistAbstract.java
index df09346b..3765e316 100644
--- a/src/main/java/io/bioimage/modelrunner/model/StardistAbstract.java
+++ b/src/main/java/io/bioimage/modelrunner/model/StardistAbstract.java
@@ -81,17 +81,13 @@ public abstract class StardistAbstract implements Closeable {
private static final List STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"});
- private static final String COORDS_DTYPE_KEY = "coords_dtype";
+ private static final String SHM_NAME_KEY = "_shm_name";
- private static final String COORDS_SHAPE_KEY = "coords_shape";
+ private static final String DTYPE_KEY = "_dtype";
- private static final String POINTS_DTYPE_KEY = "points_dtype";
+ private static final String SHAPE_KEY = "_shape";
- private static final String POINTS_SHAPE_KEY = "points_shape";
-
- private static final String POINTS_KEY = "points";
-
- private static final String COORDS_KEY = "coords";
+ private static final String KEYS_KEY = "keys";
protected static final String LOAD_MODEL_CODE_ABSTRACT = ""
+ "if '%s' not in globals().keys():" + System.lineSeparator()
@@ -106,41 +102,67 @@ public abstract class StardistAbstract implements Closeable {
+ "if 'shared_memory' not in globals().keys():" + System.lineSeparator()
+ " from multiprocessing import shared_memory" + System.lineSeparator()
+ " globals()['shared_memory'] = shared_memory" + System.lineSeparator()
+ + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"" + System.lineSeparator()
+ "model = %s(None, name='%s', basedir='%s')" + System.lineSeparator()
- + "globals()['model'] = model";
+ + "globals()['model'] = model" + System.lineSeparator();
private static final String RUN_MODEL_CODE = ""
+ "output = model.predict_instances(im, return_predict=False)" + System.lineSeparator()
+ //+ "print(output)" + System.lineSeparator()
+ + "print(type(output))" + System.lineSeparator()
+ + "if type(output) == np.ndarray:" + System.lineSeparator()
+ + " im[:] = output" + System.lineSeparator()
+ + " im[:] = output" + System.lineSeparator()
+ + " if os.name == 'nt':" + System.lineSeparator()
+ + " im_shm.close()" + System.lineSeparator()
+ + " im_shm.unlink()" + System.lineSeparator()
+ + "if type(output) != list and type(output) != tuple:" + System.lineSeparator()
+ + " raise TypeError('StarDist output should be a list of a np.ndarray')" + System.lineSeparator()
+ + "if type(output[0]) != np.ndarray:" + System.lineSeparator()
+ + " raise TypeError('If the StarDist output is a list, the first entry should be a np.ndarray')" + System.lineSeparator()
+ "im[:] = output[0]" + System.lineSeparator()
- + "if output[1]['" + POINTS_KEY + "'].nbytes == 0:" + System.lineSeparator()
- + " task.outputs['" + POINTS_SHAPE_KEY + "'] = None" + System.lineSeparator()
- + "else:" + System.lineSeparator()
- + " task.outputs['" + POINTS_SHAPE_KEY + "'] = output[1]['" + POINTS_KEY + "'].shape" + System.lineSeparator()
- + " task.outputs['"+ POINTS_DTYPE_KEY + "'] = output[1]['" + POINTS_KEY + "'].dtype" + System.lineSeparator()
- + " points_shm = "
- + " shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['" + POINTS_KEY + "'].nbytes)" + System.lineSeparator()
- + " shared_points = np.ndarray(output[1]['" + POINTS_KEY + "'].shape, dtype=output[1]['" + POINTS_KEY + "'].dtype, buffer=points_shm.buf)" + System.lineSeparator()
- + " globals()['shared_points'] = shared_points" + System.lineSeparator()
- + "if output[1]['" + COORDS_KEY + "'].nbytes == 0:" + System.lineSeparator()
- + " task.outputs['" + COORDS_SHAPE_KEY + "'] = None" + System.lineSeparator()
- + "else:" + System.lineSeparator()
- + " task.outputs['" + COORDS_SHAPE_KEY + "'] = output[1]['" + COORDS_KEY + "'].shape" + System.lineSeparator()
- + " task.outputs['" + COORDS_DTYPE_KEY + "'] = output[1]['" + COORDS_KEY + "'].dtype" + System.lineSeparator()
- + " coords_shm = "
- + " shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['" + COORDS_KEY + "'].nbytes)" + System.lineSeparator()
- + " shared_coords = np.ndarray(output[1]['" + COORDS_KEY + "'].shape, dtype=output[1]['" + COORDS_KEY + "'].dtype, buffer=coords_shm.buf)" + System.lineSeparator()
- + " globals()['shared_coords'] = shared_coords" + System.lineSeparator()
+ + "if len(output) > 1 and type(output[1]) != dict:" + System.lineSeparator()
+ + " raise TypeError('If the StarDist output is a list, the second entry needs to be a dict.')" + System.lineSeparator()
+ + "task.outputs['" + KEYS_KEY + "'] = list(output[1].keys())" + System.lineSeparator()
+ + "shm_list = []" + System.lineSeparator()
+ + "np_list = []" + System.lineSeparator()
+ + "for kk, vv in output[1].items():" + System.lineSeparator()
+ + " print(kk)" + System.lineSeparator()
+ + " if type(vv) != np.ndarray:" + System.lineSeparator()
+ + " task.update('Output ' + kk + ' is not a np.ndarray. Only np.ndarrays supported.')" + System.lineSeparator()
+ + " print(type(vv))" + System.lineSeparator()
+ + " continue" + System.lineSeparator()
+ + " if output[1][kk].nbytes == 0:" + System.lineSeparator()
+ + " task.outputs[kk] = None" + System.lineSeparator()
+ + " else:" + System.lineSeparator()
+ + " task.outputs[kk + '" + SHAPE_KEY + "'] = output[1][kk].shape" + System.lineSeparator()
+ + " print(type(output[1][kk].shape))" + System.lineSeparator()
+ + " task.outputs[kk + '"+ DTYPE_KEY + "'] = str(output[1][kk].dtype)" + System.lineSeparator()
+ + " print(type(output[1][kk].dtype))" + System.lineSeparator()
+ + " shm = shared_memory.SharedMemory(create=True, size=output[1][kk].nbytes)" + System.lineSeparator()
+ + " task.outputs[kk + '"+ SHM_NAME_KEY + "'] = shm.name" + System.lineSeparator()
+ + " print(type(shm.name))" + System.lineSeparator()
+ + " shm_list.append(shm)" + System.lineSeparator()
+ + " aa = np.ndarray(output[1][kk].shape, dtype=output[1][kk].dtype, buffer=shm.buf)" + System.lineSeparator()
+ + " aa[:] = output[1][kk]" + System.lineSeparator()
+ + " np_list.append(aa)" + System.lineSeparator()
+ + "print('dd')" + System.lineSeparator()
+ + "globals()['shm_list'] = shm_list" + System.lineSeparator()
+ + "globals()['np_list'] = np_list" + System.lineSeparator()
+
+
+ "if os.name == 'nt':" + System.lineSeparator()
+ " im_shm.close()" + System.lineSeparator()
+ " im_shm.unlink()" + System.lineSeparator();
private static final String CLOSE_SHM_CODE = ""
- + "if 'points_shm' in globals().keys():" + System.lineSeparator()
- + " points_shm.close()" + System.lineSeparator()
- + " points_shm.unlink()" + System.lineSeparator()
- + "if 'coords_shm' in globals().keys():" + System.lineSeparator()
- + " coords_shm.close()" + System.lineSeparator()
- + " coords_shm.unlink()" + System.lineSeparator();
+ + "if 'np_list' in globals().keys():" + System.lineSeparator()
+ + " for a in np_list:" + System.lineSeparator()
+ + " del a" + System.lineSeparator()
+ + "if 'shm_list' in globals().keys():" + System.lineSeparator()
+ + " for s in shm_list:" + System.lineSeparator()
+ + " s.unlink()" + System.lineSeparator()
+ + " del s" + System.lineSeparator();
protected abstract String createImportsCode();
@@ -265,11 +287,14 @@ Map> reconstructOutputs(Task task, String sh
// TODO I do not understand why is complaining when the types align perfectly
RandomAccessibleInterval maskCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(shma.getSharedRAI()),
Util.getTypeFromInterval(Cast.unchecked(shma.getSharedRAI())));
+ shma.close();
outs.put("mask", maskCopy);
- outs.put(POINTS_KEY, reconstructPoints(task, shm_points_id));
- outs.put(COORDS_KEY, reconstructCoord(task, shm_coords_id));
- shma.close();
+ if (task.outputs.get(KEYS_KEY) != null) {
+ for (String kk : (List) task.outputs.get(KEYS_KEY)) {
+ outs.put("", reconstruct(task, kk));
+ }
+ }
if (PlatformDetection.isWindows()) {
Task closeSHMTask = python.task(CLOSE_SHM_CODE);
@@ -279,17 +304,18 @@ Map> reconstructOutputs(Task task, String sh
}
private & NativeType>
- RandomAccessibleInterval reconstructCoord(Task task, String shm_coords_id) throws IOException {
-
- String coords_dtype = (String) task.outputs.get(COORDS_DTYPE_KEY);
- List coords_shape = (List) task.outputs.get(COORDS_SHAPE_KEY);
+ RandomAccessibleInterval reconstruct(Task task, String key) throws IOException {
+
+ String shm_name = (String) task.outputs.get(key + SHM_NAME_KEY);
+ String coords_dtype = (String) task.outputs.get(key + DTYPE_KEY);
+ List coords_shape = (List) task.outputs.get(key + SHAPE_KEY);
if (coords_shape == null)
return null;
long[] coordsSh = new long[coords_shape.size()];
for (int i = 0; i < coordsSh.length; i ++)
coordsSh[i] = coords_shape.get(i).longValue();
- SharedMemoryArray shmCoords = SharedMemoryArray.readOrCreate(shm_coords_id, coordsSh,
+ SharedMemoryArray shmCoords = SharedMemoryArray.readOrCreate(shm_name, coordsSh,
Cast.unchecked(CommonUtils.getImgLib2DataType(coords_dtype)), false, false);
Map> outs = new HashMap>();
@@ -304,29 +330,6 @@ RandomAccessibleInterval reconstructCoord(Task task, String shm_coords_id) th
return coordsCopy;
}
- private & NativeType>
- RandomAccessibleInterval reconstructPoints(Task task, String shm_points_id) throws IOException {
-
- String points_dtype = (String) task.outputs.get(POINTS_DTYPE_KEY);
- List points_shape = (List) task.outputs.get(POINTS_SHAPE_KEY);
- if (points_shape == null)
- return null;
-
-
- long[] pointsSh = new long[points_shape.size()];
- for (int i = 0; i < pointsSh.length; i ++)
- pointsSh[i] = points_shape.get(i).longValue();
- SharedMemoryArray shmPoints = SharedMemoryArray.readOrCreate(shm_points_id, pointsSh,
- Cast.unchecked(CommonUtils.getImgLib2DataType(points_dtype)), false, false);
-
- // TODO I do not understand why is complaining when the types align perfectly
- RandomAccessibleInterval pointsRAI = shmPoints.getSharedRAI();
- RandomAccessibleInterval pointsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(pointsRAI),
- Util.getTypeFromInterval(Cast.unchecked(pointsRAI)));
- shmPoints.close();
- return pointsCopy;
- }
-
/**
* Check whether everything that is needed for Stardist 2D is installed or not
*/