From 41fbece2d28982e01c2cdfb9c91b014e0bbf0742 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Thu, 16 Jan 2025 14:55:57 +0100 Subject: [PATCH] keeep iterating to manage to get a stable stardist --- .../modelrunner/model/Stardist2D.java | 2 +- .../modelrunner/model/Stardist3D.java | 272 ++----------- .../modelrunner/model/Stardist3D_old.java | 373 ++++++++++++++++++ .../modelrunner/model/StardistAbstract.java | 131 +++--- 4 files changed, 477 insertions(+), 301 deletions(-) create mode 100644 src/main/java/io/bioimage/modelrunner/model/Stardist3D_old.java diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java index 7c9c6020..e656a5cb 100644 --- a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java +++ b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java @@ -51,7 +51,7 @@ */ public class Stardist2D extends StardistAbstract { - private static String MODULE_NAME = "Stardist2D"; + private static String MODULE_NAME = "StarDist2D"; private Stardist2D(String modelName, String baseDir) throws IOException, ModelSpecsException { super(modelName, baseDir); diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist3D.java b/src/main/java/io/bioimage/modelrunner/model/Stardist3D.java index c1e5675f..28b429a9 100644 --- a/src/main/java/io/bioimage/modelrunner/model/Stardist3D.java +++ b/src/main/java/io/bioimage/modelrunner/model/Stardist3D.java @@ -22,140 +22,85 @@ import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; -import java.io.InputStream; import java.net.URISyntaxException; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.nio.file.StandardCopyOption; -import java.text.SimpleDateFormat; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Calendar; -import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.apache.commons.compress.archivers.ArchiveException; -import io.bioimage.modelrunner.apposed.appose.Mamba; import io.bioimage.modelrunner.apposed.appose.MambaInstallException; import io.bioimage.modelrunner.bioimageio.BioimageioRepo; import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory; import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException; -import io.bioimage.modelrunner.engine.installation.EngineInstall; import io.bioimage.modelrunner.exceptions.LoadEngineException; import io.bioimage.modelrunner.exceptions.LoadModelException; import io.bioimage.modelrunner.exceptions.RunModelException; -import io.bioimage.modelrunner.model.processing.Processing; -import io.bioimage.modelrunner.runmode.RunMode; -import io.bioimage.modelrunner.runmode.ops.GenericOp; -import io.bioimage.modelrunner.tensor.Tensor; -import io.bioimage.modelrunner.tensor.Utils; import io.bioimage.modelrunner.utils.Constants; -import io.bioimage.modelrunner.versionmanagement.InstalledEngines; import net.imglib2.RandomAccessibleInterval; import net.imglib2.img.array.ArrayImgs; import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.real.FloatType; -import net.imglib2.util.Cast; -import net.imglib2.view.Views; /** - * This class provides an implementation of an API to run Stardist 3D models, including pre- - * and post-processing, with minimal configuration. The class allows for loading pretrained - * models from Bioimage.io and also includes methods for running predictions using 3D images. - * - *

Stardist is a deep learning model specialized in detecting object shapes, particularly - * star-convex shapes in images, which is especially useful in biomedical imaging applications - * such as cell nuclei detection.

- * - *

The Stardist3D class includes methods for installing the necessary requirements (such as - * Python environments), running predictions, and handling Stardist-specific post-processing.

- * - *

TODO: add support for fine-tuning models and Mac ARM processors.

- * - * Example usage: - *
- * {@code
- * Stardist3D.installRequirements();
- * Stardist3D model = Stardist3D.fromPretained("StarDist Plant Nuclei 3D ResNet", false);
- * RandomAccessibleInterval img = ArrayImgs.floats(new long[] {116, 120, 66});
- * RandomAccessibleInterval res = model.predict(img);
- * }
- * 
- * - * @see io.bioimage.modelrunner.bioimageio.description.ModelDescriptor - * @see io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory - * @see io.bioimage.modelrunner.exceptions.LoadModelException - * @see io.bioimage.modelrunner.exceptions.RunModelException + * Implementation of an API to run Stardist 3D models out of the box with little configuration. * *TODO add fine tuning - *TODO add support for Mac arm * *@author Carlos Garcia */ -public class Stardist3D { - - private ModelDescriptor descriptor; - - private final int channels; - - private final float nms_threshold; +public class Stardist3D extends StardistAbstract { - private final float prob_threshold; + private static String MODULE_NAME = "StarDist3D"; - private static final List STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"}); - - private static final List STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"}); - - private static final String STARDIST3D_PATH_IN_RESOURCES = "ops/stardist_postprocessing/"; - - private static final String STARDIST3D_SCRIPT_NAME= "stardist_postprocessing_3D.py"; - - private static final String STARDIST3D_METHOD_NAME= "stardist_postprocessing"; - - private Stardist3D() { - this.channels = 1; - // TODO get from config?? - this.nms_threshold = 0; - this.prob_threshold = 0; + private Stardist3D(String modelName, String baseDir) throws IOException, ModelSpecsException { + super(modelName, baseDir); } - private Stardist3D(ModelDescriptor descriptor) { - this.descriptor = descriptor; - Map stardistMap = (Map) descriptor.getConfig().getSpecMap().get("stardist"); - Map stardistConfig = (Map) stardistMap.get("config"); - Map stardistThres = (Map) stardistMap.get("thresholds"); - this.channels = (int) stardistConfig.get("n_channel_in");; - this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue(); - this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue(); + private Stardist3D(ModelDescriptor descriptor) throws IOException, ModelSpecsException { + super(descriptor); + } + + @Override + protected String createImportsCode() { + return String.format(LOAD_MODEL_CODE_ABSTRACT, MODULE_NAME, MODULE_NAME, + MODULE_NAME, MODULE_NAME, MODULE_NAME, this.name, this.basedir); + } + + @Override + protected & NativeType> void checkInput(RandomAccessibleInterval image) { + if (image.dimensionsAsLongArray().length == 3 && this.nChannels != 1) + throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ"); + else if (image.dimensionsAsLongArray().length != 4 && this.nChannels != 1) + throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ"); + else if (image.dimensionsAsLongArray().length == 4 && image.dimensionsAsLongArray()[2] != nChannels) + throw new IllegalArgumentException("This Stardist3D model requires " + nChannels + " channels."); + else if (image.dimensionsAsLongArray().length > 4 || image.dimensionsAsLongArray().length < 2) + throw new IllegalArgumentException("Stardist3D model requires an image with dimensions XYCZ."); } /** - * Initialize a Stardist3D using the format of the Bioiamge.io model zoo. + * Initialize a Stardist2D using the format of the Bioiamge.io model zoo. * @param modelPath * path to the Bioimage.io model - * @return an instance of a Stardist3D model ready to be used - * @throws ModelSpecsException if the model configuration is incorrect in the specs file (rdf.yaml). - * @throws FileNotFoundException if the model file is not found in the specified path. - * @throws IOException if there is an issue reading the model file. + * @return an instance of a Stardist2D model ready to be used + * @throws ModelSpecsException If there is any error in the configuration of the specs rdf.yaml file of the Bioimage.io. + * @throws FileNotFoundException If the model file is not found. + * @throws IOException If there's an I/O error. */ public static Stardist3D fromBioimageioModel(String modelPath) throws ModelSpecsException, FileNotFoundException, IOException { ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(modelPath + File.separator + Constants.RDF_FNAME); return new Stardist3D(descriptor); } - + /** - * Initialize one of the "official" pretrained Stardist 3D models. + * Initialize one of the "official" pretrained Stardist 2D models. * By default, the model will be installed in the "models" folder inside the application * @param pretrainedModel * the name of the pretrained model. * @param forceInstall * whether to force the installation or to try to look if the model has already been installed before - * @return an instance of a pretrained Stardist3D model ready to be used + * @return an instance of a pretrained Stardist2D model ready to be used * @throws IOException if there is any error downloading the model, in the case it is needed * @throws InterruptedException if the download of the model is stopped * @throws ModelSpecsException if the model downloaded is not well specified in the config file @@ -194,153 +139,7 @@ public static Stardist3D fromPretained(String pretrainedModel, String installDir } } - private & NativeType> void checkInput(RandomAccessibleInterval image) { - if (image.dimensionsAsLongArray().length == 3 && this.channels != 1) - throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ"); - else if (image.dimensionsAsLongArray().length != 4 && this.channels != 1) - throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ"); - else if (image.dimensionsAsLongArray().length == 4 && image.dimensionsAsLongArray()[2] != channels) - throw new IllegalArgumentException("This Stardist3D model requires " + channels + " channels."); - else if (image.dimensionsAsLongArray().length > 4 || image.dimensionsAsLongArray().length < 2) - throw new IllegalArgumentException("Stardist3D model requires an image with dimensions XYCZ."); - } - /** - * Run the Stardist 3D model end to end, including pre- and post-processing. - * @param - * possible ImgLib2 data types of the input and output images - * @param image - * the input image that is going to be processed by Stardist3D - * @return the final output of Stardist2D including pre- and post-processing - * @throws ModelSpecsException if there is any error with the specs of the model - * @throws LoadModelException if there is any error loading the model in Tensorflow Java - * @throws LoadEngineException if there is any error loading Tensorflow Java engine - * @throws IOException if there is any error with the files that are required to run the model - * @throws RunModelException if there is any unexpected exception running the post-processing - * @throws InterruptedException if the inference or post-processing are interrupted unexpectedly - */ - public & NativeType> - RandomAccessibleInterval predict(RandomAccessibleInterval image) throws ModelSpecsException, LoadModelException, - LoadEngineException, IOException, - RunModelException, InterruptedException { - checkInput(image); - if (image.dimensionsAsLongArray().length == 3) { - image = Views.addDimension(image, 0, 0); - image = Utils.transpose(image); - image = Views.addDimension(image, 0, 0); - } else if (image.dimensionsAsLongArray().length == 4) { - image = Views.permute(image, 1, 2); - image = Views.addDimension(image, 0, 1); - image = Views.addDimension(image, 0, 0); - image = Utils.transpose(image); - } - - Tensor inputTensor = Tensor.build("input", "bzyxc", image); - Tensor outputTensor = Tensor.buildEmptyTensor("output", "bzyxc"); - - List> inputList = new ArrayList>(); - List> outputList = new ArrayList>(); - inputList.add(inputTensor); - outputList.add(outputTensor); - - Model model = Model.createBioimageioModel(this.descriptor.getModelPath()); - model.loadModel(); - Processing processing = Processing.init(descriptor); - inputList = processing.preprocess(inputList, false); - model.runModel(inputList, outputList); - - return Utils.transpose(Cast.unchecked(postProcessing(outputList.get(0).getData()))); - } - - /** - * Execute stardist post-processing on the raw output of a Stardist 3D model - * @param - * possible data type of the input image - * @param image - * the raw output of a Stardist 3D model - * @return the final output of a Stardist 3D model - * @throws IOException if there is any error running the post-processing - * @throws InterruptedException if the post-processing is interrupted - */ - public & NativeType> - RandomAccessibleInterval postProcessing(RandomAccessibleInterval image) throws IOException, InterruptedException { - Mamba mamba = new Mamba(); - String envPath = mamba.getEnvsDir() + File.separator + "stardist"; - String scriptPath = envPath + File.separator + STARDIST3D_SCRIPT_NAME; - - GenericOp op = GenericOp.create(envPath, scriptPath, STARDIST3D_METHOD_NAME, 1); - LinkedHashMap nMap = new LinkedHashMap(); - Calendar cal = Calendar.getInstance(); - SimpleDateFormat sdf = new SimpleDateFormat("ddMMYYYY_HHmmss"); - String dateString = sdf.format(cal.getTime()); - nMap.put("input_" + dateString, image); - nMap.put("nms_thresh", nms_threshold); - nMap.put("prob_thresh", prob_threshold); - op.setInputs(nMap); - - RunMode rm; - rm = RunMode.createRunMode(op); - Map resMap = rm.runOP(); - - List> rais = resMap.entrySet().stream() - .filter(e -> { - Object val = e.getValue(); - if (val instanceof RandomAccessibleInterval) return true; - return false; - }).map(e -> (RandomAccessibleInterval) e.getValue()).collect(Collectors.toList()); - - return rais.get(0); - } - - /** - * Check whether everything that is needed for Stardist 3D is installed or not - */ - public void checkRequirementsInstalled() { - // TODO - } - - /** - * Check whether the requirements needed to run Stardist 3D are satisfied or not. - * First checks if the corresponding Java DL engine is installed or not, then checks - * if the Python environment needed for Stardist3D post processing is fine too. - * - * If anything is not installed, this method also installs it - * - * @throws IOException if there is any error downloading the DL engine or installing the micromamba environment - * @throws InterruptedException if the installation is stopped - * @throws RuntimeException if there is any unexpected error in the micromamba environment installation - * @throws MambaInstallException if there is any error downloading or installing micromamba - * @throws ArchiveException if there is any error decompressing the micromamba installer - * @throws URISyntaxException if the URL to the micromamba installation is not correct - */ - public static void installRequirements() throws IOException, InterruptedException, - RuntimeException, MambaInstallException, - ArchiveException, URISyntaxException { - boolean installed = InstalledEngines.buildEnginesFinder() - .checkEngineWithArgsInstalledForOS("tensorflow", "1.15.0", null, null).size() != 0; - if (!installed) - EngineInstall.installEngineWithArgs("tensorflow", "1.15.0", true, true); - - Mamba mamba = new Mamba(); - boolean stardistPythonInstalled = false; - try { - stardistPythonInstalled = mamba.checkAllDependenciesInEnv("stardist", STARDIST_DEPS); - } catch (MambaInstallException e) { - mamba.installMicromamba(); - } - if (!stardistPythonInstalled) { - // TODO add logging for environment installation - mamba.create("stardist", true, STARDIST_CHANNELS, STARDIST_DEPS); - }; - String envPath = mamba.getEnvsDir() + File.separator + "stardist"; - String scriptPath = envPath + File.separator + STARDIST3D_SCRIPT_NAME; - if (!Paths.get(scriptPath).toFile().isFile()) { - try (InputStream scriptStream = Stardist3D.class.getClassLoader() - .getResourceAsStream(STARDIST3D_PATH_IN_RESOURCES + STARDIST3D_SCRIPT_NAME)){ - Files.copy(scriptStream, Paths.get(scriptPath), StandardCopyOption.REPLACE_EXISTING); - } - } - } /** * Main method to check functionality @@ -364,10 +163,11 @@ public static void main(String[] args) throws IOException, InterruptedException, URISyntaxException, LoadModelException { Stardist3D.installRequirements(); Stardist3D model = Stardist3D.fromPretained("StarDist Plant Nuclei 3D ResNet", false); - + RandomAccessibleInterval img = ArrayImgs.floats(new long[] {116, 120, 66}); - RandomAccessibleInterval res = model.predict(img); + Map> res = model.predict(img); + model.close(); System.out.println(true); } } diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist3D_old.java b/src/main/java/io/bioimage/modelrunner/model/Stardist3D_old.java new file mode 100644 index 00000000..03acbfca --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/Stardist3D_old.java @@ -0,0 +1,373 @@ +/*- + * #%L + * Use deep learning frameworks from Java in an agnostic and isolated way. + * %% + * Copyright (C) 2022 - 2024 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package io.bioimage.modelrunner.model; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Calendar; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.commons.compress.archivers.ArchiveException; + +import io.bioimage.modelrunner.apposed.appose.Mamba; +import io.bioimage.modelrunner.apposed.appose.MambaInstallException; +import io.bioimage.modelrunner.bioimageio.BioimageioRepo; +import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; +import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory; +import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException; +import io.bioimage.modelrunner.engine.installation.EngineInstall; +import io.bioimage.modelrunner.exceptions.LoadEngineException; +import io.bioimage.modelrunner.exceptions.LoadModelException; +import io.bioimage.modelrunner.exceptions.RunModelException; +import io.bioimage.modelrunner.model.processing.Processing; +import io.bioimage.modelrunner.runmode.RunMode; +import io.bioimage.modelrunner.runmode.ops.GenericOp; +import io.bioimage.modelrunner.tensor.Tensor; +import io.bioimage.modelrunner.tensor.Utils; +import io.bioimage.modelrunner.utils.Constants; +import io.bioimage.modelrunner.versionmanagement.InstalledEngines; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.array.ArrayImgs; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; +import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Cast; +import net.imglib2.view.Views; + +/** + * This class provides an implementation of an API to run Stardist 3D models, including pre- + * and post-processing, with minimal configuration. The class allows for loading pretrained + * models from Bioimage.io and also includes methods for running predictions using 3D images. + * + *

Stardist is a deep learning model specialized in detecting object shapes, particularly + * star-convex shapes in images, which is especially useful in biomedical imaging applications + * such as cell nuclei detection.

+ * + *

The Stardist3D class includes methods for installing the necessary requirements (such as + * Python environments), running predictions, and handling Stardist-specific post-processing.

+ * + *

TODO: add support for fine-tuning models and Mac ARM processors.

+ * + * Example usage: + *
+ * {@code
+ * Stardist3D.installRequirements();
+ * Stardist3D model = Stardist3D.fromPretained("StarDist Plant Nuclei 3D ResNet", false);
+ * RandomAccessibleInterval img = ArrayImgs.floats(new long[] {116, 120, 66});
+ * RandomAccessibleInterval res = model.predict(img);
+ * }
+ * 
+ * + * @see io.bioimage.modelrunner.bioimageio.description.ModelDescriptor + * @see io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory + * @see io.bioimage.modelrunner.exceptions.LoadModelException + * @see io.bioimage.modelrunner.exceptions.RunModelException + * + *TODO add fine tuning + *TODO add support for Mac arm + * + *@author Carlos Garcia + */ +public class Stardist3D_old { + + private ModelDescriptor descriptor; + + private final int channels; + + private final float nms_threshold; + + private final float prob_threshold; + + private static final List STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"}); + + private static final List STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"}); + + private static final String STARDIST3D_PATH_IN_RESOURCES = "ops/stardist_postprocessing/"; + + private static final String STARDIST3D_SCRIPT_NAME= "stardist_postprocessing_3D.py"; + + private static final String STARDIST3D_METHOD_NAME= "stardist_postprocessing"; + + private Stardist3D_old() { + this.channels = 1; + // TODO get from config?? + this.nms_threshold = 0; + this.prob_threshold = 0; + } + + private Stardist3D_old(ModelDescriptor descriptor) { + this.descriptor = descriptor; + Map stardistMap = (Map) descriptor.getConfig().getSpecMap().get("stardist"); + Map stardistConfig = (Map) stardistMap.get("config"); + Map stardistThres = (Map) stardistMap.get("thresholds"); + this.channels = (int) stardistConfig.get("n_channel_in");; + this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue(); + this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue(); + } + + /** + * Initialize a Stardist3D using the format of the Bioiamge.io model zoo. + * @param modelPath + * path to the Bioimage.io model + * @return an instance of a Stardist3D model ready to be used + * @throws ModelSpecsException if the model configuration is incorrect in the specs file (rdf.yaml). + * @throws FileNotFoundException if the model file is not found in the specified path. + * @throws IOException if there is an issue reading the model file. + */ + public static Stardist3D_old fromBioimageioModel(String modelPath) throws ModelSpecsException, FileNotFoundException, IOException { + ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(modelPath + File.separator + Constants.RDF_FNAME); + return new Stardist3D_old(descriptor); + } + + /** + * Initialize one of the "official" pretrained Stardist 3D models. + * By default, the model will be installed in the "models" folder inside the application + * @param pretrainedModel + * the name of the pretrained model. + * @param forceInstall + * whether to force the installation or to try to look if the model has already been installed before + * @return an instance of a pretrained Stardist3D model ready to be used + * @throws IOException if there is any error downloading the model, in the case it is needed + * @throws InterruptedException if the download of the model is stopped + * @throws ModelSpecsException if the model downloaded is not well specified in the config file + */ + public static Stardist3D_old fromPretained(String pretrainedModel, boolean forceInstall) throws IOException, InterruptedException, ModelSpecsException { + return fromPretained(pretrainedModel, new File("models").getAbsolutePath(), forceInstall); + } + + /** + * Initialize one of the "official" pretrained Stardist 3D models + * @param pretrainedModel + * the name of the pretrained model. + * @param installDir + * the directory where the model wants to be installed + * @param forceInstall + * whether to force the installation or to try to look if the model has already been installed before + * @return an instance of a pretrained Stardist3D model ready to be used + * @throws IOException if there is any error downloading the model, in the case it is needed + * @throws InterruptedException if the download of the model is stopped + * @throws ModelSpecsException if the model downloaded is not well specified in the config file + */ + public static Stardist3D_old fromPretained(String pretrainedModel, String installDir, boolean forceInstall) throws IOException, + InterruptedException, + ModelSpecsException { + if (pretrainedModel.equals("StarDist Plant Nuclei 3D ResNet") && !forceInstall) { + ModelDescriptor md = ModelDescriptorFactory.getModelsAtLocalRepo().stream() + .filter(mm ->mm.getName().equals(pretrainedModel)).findFirst().orElse(null); + if (md != null) return new Stardist3D_old(md); + String path = BioimageioRepo.connect().downloadByName("StarDist Plant Nuclei 3D ResNet", installDir); + return Stardist3D_old.fromBioimageioModel(path); + } else if (pretrainedModel.equals("StarDist Plant Nuclei 3D ResNet")) { + String path = BioimageioRepo.connect().downloadByName("StarDist Plant Nuclei 3D ResNet", installDir); + return Stardist3D_old.fromBioimageioModel(path); + } else { + throw new IllegalArgumentException("There is no Stardist3D model called: " + pretrainedModel); + } + } + + private & NativeType> void checkInput(RandomAccessibleInterval image) { + if (image.dimensionsAsLongArray().length == 3 && this.channels != 1) + throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ"); + else if (image.dimensionsAsLongArray().length != 4 && this.channels != 1) + throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYCZ"); + else if (image.dimensionsAsLongArray().length == 4 && image.dimensionsAsLongArray()[2] != channels) + throw new IllegalArgumentException("This Stardist3D model requires " + channels + " channels."); + else if (image.dimensionsAsLongArray().length > 4 || image.dimensionsAsLongArray().length < 2) + throw new IllegalArgumentException("Stardist3D model requires an image with dimensions XYCZ."); + } + + /** + * Run the Stardist 3D model end to end, including pre- and post-processing. + * @param + * possible ImgLib2 data types of the input and output images + * @param image + * the input image that is going to be processed by Stardist3D + * @return the final output of Stardist2D including pre- and post-processing + * @throws ModelSpecsException if there is any error with the specs of the model + * @throws LoadModelException if there is any error loading the model in Tensorflow Java + * @throws LoadEngineException if there is any error loading Tensorflow Java engine + * @throws IOException if there is any error with the files that are required to run the model + * @throws RunModelException if there is any unexpected exception running the post-processing + * @throws InterruptedException if the inference or post-processing are interrupted unexpectedly + */ + public & NativeType> + RandomAccessibleInterval predict(RandomAccessibleInterval image) throws ModelSpecsException, LoadModelException, + LoadEngineException, IOException, + RunModelException, InterruptedException { + checkInput(image); + if (image.dimensionsAsLongArray().length == 3) { + image = Views.addDimension(image, 0, 0); + image = Utils.transpose(image); + image = Views.addDimension(image, 0, 0); + } else if (image.dimensionsAsLongArray().length == 4) { + image = Views.permute(image, 1, 2); + image = Views.addDimension(image, 0, 1); + image = Views.addDimension(image, 0, 0); + image = Utils.transpose(image); + } + + Tensor inputTensor = Tensor.build("input", "bzyxc", image); + Tensor outputTensor = Tensor.buildEmptyTensor("output", "bzyxc"); + + List> inputList = new ArrayList>(); + List> outputList = new ArrayList>(); + inputList.add(inputTensor); + outputList.add(outputTensor); + + Model model = Model.createBioimageioModel(this.descriptor.getModelPath()); + model.loadModel(); + Processing processing = Processing.init(descriptor); + inputList = processing.preprocess(inputList, false); + model.runModel(inputList, outputList); + + return Utils.transpose(Cast.unchecked(postProcessing(outputList.get(0).getData()))); + } + + /** + * Execute stardist post-processing on the raw output of a Stardist 3D model + * @param + * possible data type of the input image + * @param image + * the raw output of a Stardist 3D model + * @return the final output of a Stardist 3D model + * @throws IOException if there is any error running the post-processing + * @throws InterruptedException if the post-processing is interrupted + */ + public & NativeType> + RandomAccessibleInterval postProcessing(RandomAccessibleInterval image) throws IOException, InterruptedException { + Mamba mamba = new Mamba(); + String envPath = mamba.getEnvsDir() + File.separator + "stardist"; + String scriptPath = envPath + File.separator + STARDIST3D_SCRIPT_NAME; + + GenericOp op = GenericOp.create(envPath, scriptPath, STARDIST3D_METHOD_NAME, 1); + LinkedHashMap nMap = new LinkedHashMap(); + Calendar cal = Calendar.getInstance(); + SimpleDateFormat sdf = new SimpleDateFormat("ddMMYYYY_HHmmss"); + String dateString = sdf.format(cal.getTime()); + nMap.put("input_" + dateString, image); + nMap.put("nms_thresh", nms_threshold); + nMap.put("prob_thresh", prob_threshold); + op.setInputs(nMap); + + RunMode rm; + rm = RunMode.createRunMode(op); + Map resMap = rm.runOP(); + + List> rais = resMap.entrySet().stream() + .filter(e -> { + Object val = e.getValue(); + if (val instanceof RandomAccessibleInterval) return true; + return false; + }).map(e -> (RandomAccessibleInterval) e.getValue()).collect(Collectors.toList()); + + return rais.get(0); + } + + /** + * Check whether everything that is needed for Stardist 3D is installed or not + */ + public void checkRequirementsInstalled() { + // TODO + } + + /** + * Check whether the requirements needed to run Stardist 3D are satisfied or not. + * First checks if the corresponding Java DL engine is installed or not, then checks + * if the Python environment needed for Stardist3D post processing is fine too. + * + * If anything is not installed, this method also installs it + * + * @throws IOException if there is any error downloading the DL engine or installing the micromamba environment + * @throws InterruptedException if the installation is stopped + * @throws RuntimeException if there is any unexpected error in the micromamba environment installation + * @throws MambaInstallException if there is any error downloading or installing micromamba + * @throws ArchiveException if there is any error decompressing the micromamba installer + * @throws URISyntaxException if the URL to the micromamba installation is not correct + */ + public static void installRequirements() throws IOException, InterruptedException, + RuntimeException, MambaInstallException, + ArchiveException, URISyntaxException { + boolean installed = InstalledEngines.buildEnginesFinder() + .checkEngineWithArgsInstalledForOS("tensorflow", "1.15.0", null, null).size() != 0; + if (!installed) + EngineInstall.installEngineWithArgs("tensorflow", "1.15.0", true, true); + + Mamba mamba = new Mamba(); + boolean stardistPythonInstalled = false; + try { + stardistPythonInstalled = mamba.checkAllDependenciesInEnv("stardist", STARDIST_DEPS); + } catch (MambaInstallException e) { + mamba.installMicromamba(); + } + if (!stardistPythonInstalled) { + // TODO add logging for environment installation + mamba.create("stardist", true, STARDIST_CHANNELS, STARDIST_DEPS); + }; + String envPath = mamba.getEnvsDir() + File.separator + "stardist"; + String scriptPath = envPath + File.separator + STARDIST3D_SCRIPT_NAME; + if (!Paths.get(scriptPath).toFile().isFile()) { + try (InputStream scriptStream = Stardist3D_old.class.getClassLoader() + .getResourceAsStream(STARDIST3D_PATH_IN_RESOURCES + STARDIST3D_SCRIPT_NAME)){ + Files.copy(scriptStream, Paths.get(scriptPath), StandardCopyOption.REPLACE_EXISTING); + } + } + } + + /** + * Main method to check functionality + * @param args + * nothing + * @throws IOException nothing + * @throws InterruptedException nothing + * @throws RuntimeException nothing + * @throws MambaInstallException nothing + * @throws ModelSpecsException nothing + * @throws LoadEngineException nothing + * @throws RunModelException nothing + * @throws ArchiveException nothing + * @throws URISyntaxException nothing + * @throws LoadModelException nothing + */ + public static void main(String[] args) throws IOException, InterruptedException, + RuntimeException, MambaInstallException, + ModelSpecsException, LoadEngineException, + RunModelException, ArchiveException, + URISyntaxException, LoadModelException { + Stardist3D_old.installRequirements(); + Stardist3D_old model = Stardist3D_old.fromPretained("StarDist Plant Nuclei 3D ResNet", false); + + RandomAccessibleInterval img = ArrayImgs.floats(new long[] {116, 120, 66}); + + RandomAccessibleInterval res = model.predict(img); + System.out.println(true); + } +} diff --git a/src/main/java/io/bioimage/modelrunner/model/StardistAbstract.java b/src/main/java/io/bioimage/modelrunner/model/StardistAbstract.java index df09346b..3765e316 100644 --- a/src/main/java/io/bioimage/modelrunner/model/StardistAbstract.java +++ b/src/main/java/io/bioimage/modelrunner/model/StardistAbstract.java @@ -81,17 +81,13 @@ public abstract class StardistAbstract implements Closeable { private static final List STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"}); - private static final String COORDS_DTYPE_KEY = "coords_dtype"; + private static final String SHM_NAME_KEY = "_shm_name"; - private static final String COORDS_SHAPE_KEY = "coords_shape"; + private static final String DTYPE_KEY = "_dtype"; - private static final String POINTS_DTYPE_KEY = "points_dtype"; + private static final String SHAPE_KEY = "_shape"; - private static final String POINTS_SHAPE_KEY = "points_shape"; - - private static final String POINTS_KEY = "points"; - - private static final String COORDS_KEY = "coords"; + private static final String KEYS_KEY = "keys"; protected static final String LOAD_MODEL_CODE_ABSTRACT = "" + "if '%s' not in globals().keys():" + System.lineSeparator() @@ -106,41 +102,67 @@ public abstract class StardistAbstract implements Closeable { + "if 'shared_memory' not in globals().keys():" + System.lineSeparator() + " from multiprocessing import shared_memory" + System.lineSeparator() + " globals()['shared_memory'] = shared_memory" + System.lineSeparator() + + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"" + System.lineSeparator() + "model = %s(None, name='%s', basedir='%s')" + System.lineSeparator() - + "globals()['model'] = model"; + + "globals()['model'] = model" + System.lineSeparator(); private static final String RUN_MODEL_CODE = "" + "output = model.predict_instances(im, return_predict=False)" + System.lineSeparator() + //+ "print(output)" + System.lineSeparator() + + "print(type(output))" + System.lineSeparator() + + "if type(output) == np.ndarray:" + System.lineSeparator() + + " im[:] = output" + System.lineSeparator() + + " im[:] = output" + System.lineSeparator() + + " if os.name == 'nt':" + System.lineSeparator() + + " im_shm.close()" + System.lineSeparator() + + " im_shm.unlink()" + System.lineSeparator() + + "if type(output) != list and type(output) != tuple:" + System.lineSeparator() + + " raise TypeError('StarDist output should be a list of a np.ndarray')" + System.lineSeparator() + + "if type(output[0]) != np.ndarray:" + System.lineSeparator() + + " raise TypeError('If the StarDist output is a list, the first entry should be a np.ndarray')" + System.lineSeparator() + "im[:] = output[0]" + System.lineSeparator() - + "if output[1]['" + POINTS_KEY + "'].nbytes == 0:" + System.lineSeparator() - + " task.outputs['" + POINTS_SHAPE_KEY + "'] = None" + System.lineSeparator() - + "else:" + System.lineSeparator() - + " task.outputs['" + POINTS_SHAPE_KEY + "'] = output[1]['" + POINTS_KEY + "'].shape" + System.lineSeparator() - + " task.outputs['"+ POINTS_DTYPE_KEY + "'] = output[1]['" + POINTS_KEY + "'].dtype" + System.lineSeparator() - + " points_shm = " - + " shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['" + POINTS_KEY + "'].nbytes)" + System.lineSeparator() - + " shared_points = np.ndarray(output[1]['" + POINTS_KEY + "'].shape, dtype=output[1]['" + POINTS_KEY + "'].dtype, buffer=points_shm.buf)" + System.lineSeparator() - + " globals()['shared_points'] = shared_points" + System.lineSeparator() - + "if output[1]['" + COORDS_KEY + "'].nbytes == 0:" + System.lineSeparator() - + " task.outputs['" + COORDS_SHAPE_KEY + "'] = None" + System.lineSeparator() - + "else:" + System.lineSeparator() - + " task.outputs['" + COORDS_SHAPE_KEY + "'] = output[1]['" + COORDS_KEY + "'].shape" + System.lineSeparator() - + " task.outputs['" + COORDS_DTYPE_KEY + "'] = output[1]['" + COORDS_KEY + "'].dtype" + System.lineSeparator() - + " coords_shm = " - + " shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['" + COORDS_KEY + "'].nbytes)" + System.lineSeparator() - + " shared_coords = np.ndarray(output[1]['" + COORDS_KEY + "'].shape, dtype=output[1]['" + COORDS_KEY + "'].dtype, buffer=coords_shm.buf)" + System.lineSeparator() - + " globals()['shared_coords'] = shared_coords" + System.lineSeparator() + + "if len(output) > 1 and type(output[1]) != dict:" + System.lineSeparator() + + " raise TypeError('If the StarDist output is a list, the second entry needs to be a dict.')" + System.lineSeparator() + + "task.outputs['" + KEYS_KEY + "'] = list(output[1].keys())" + System.lineSeparator() + + "shm_list = []" + System.lineSeparator() + + "np_list = []" + System.lineSeparator() + + "for kk, vv in output[1].items():" + System.lineSeparator() + + " print(kk)" + System.lineSeparator() + + " if type(vv) != np.ndarray:" + System.lineSeparator() + + " task.update('Output ' + kk + ' is not a np.ndarray. Only np.ndarrays supported.')" + System.lineSeparator() + + " print(type(vv))" + System.lineSeparator() + + " continue" + System.lineSeparator() + + " if output[1][kk].nbytes == 0:" + System.lineSeparator() + + " task.outputs[kk] = None" + System.lineSeparator() + + " else:" + System.lineSeparator() + + " task.outputs[kk + '" + SHAPE_KEY + "'] = output[1][kk].shape" + System.lineSeparator() + + " print(type(output[1][kk].shape))" + System.lineSeparator() + + " task.outputs[kk + '"+ DTYPE_KEY + "'] = str(output[1][kk].dtype)" + System.lineSeparator() + + " print(type(output[1][kk].dtype))" + System.lineSeparator() + + " shm = shared_memory.SharedMemory(create=True, size=output[1][kk].nbytes)" + System.lineSeparator() + + " task.outputs[kk + '"+ SHM_NAME_KEY + "'] = shm.name" + System.lineSeparator() + + " print(type(shm.name))" + System.lineSeparator() + + " shm_list.append(shm)" + System.lineSeparator() + + " aa = np.ndarray(output[1][kk].shape, dtype=output[1][kk].dtype, buffer=shm.buf)" + System.lineSeparator() + + " aa[:] = output[1][kk]" + System.lineSeparator() + + " np_list.append(aa)" + System.lineSeparator() + + "print('dd')" + System.lineSeparator() + + "globals()['shm_list'] = shm_list" + System.lineSeparator() + + "globals()['np_list'] = np_list" + System.lineSeparator() + + + "if os.name == 'nt':" + System.lineSeparator() + " im_shm.close()" + System.lineSeparator() + " im_shm.unlink()" + System.lineSeparator(); private static final String CLOSE_SHM_CODE = "" - + "if 'points_shm' in globals().keys():" + System.lineSeparator() - + " points_shm.close()" + System.lineSeparator() - + " points_shm.unlink()" + System.lineSeparator() - + "if 'coords_shm' in globals().keys():" + System.lineSeparator() - + " coords_shm.close()" + System.lineSeparator() - + " coords_shm.unlink()" + System.lineSeparator(); + + "if 'np_list' in globals().keys():" + System.lineSeparator() + + " for a in np_list:" + System.lineSeparator() + + " del a" + System.lineSeparator() + + "if 'shm_list' in globals().keys():" + System.lineSeparator() + + " for s in shm_list:" + System.lineSeparator() + + " s.unlink()" + System.lineSeparator() + + " del s" + System.lineSeparator(); protected abstract String createImportsCode(); @@ -265,11 +287,14 @@ Map> reconstructOutputs(Task task, String sh // TODO I do not understand why is complaining when the types align perfectly RandomAccessibleInterval maskCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(shma.getSharedRAI()), Util.getTypeFromInterval(Cast.unchecked(shma.getSharedRAI()))); + shma.close(); outs.put("mask", maskCopy); - outs.put(POINTS_KEY, reconstructPoints(task, shm_points_id)); - outs.put(COORDS_KEY, reconstructCoord(task, shm_coords_id)); - shma.close(); + if (task.outputs.get(KEYS_KEY) != null) { + for (String kk : (List) task.outputs.get(KEYS_KEY)) { + outs.put("", reconstruct(task, kk)); + } + } if (PlatformDetection.isWindows()) { Task closeSHMTask = python.task(CLOSE_SHM_CODE); @@ -279,17 +304,18 @@ Map> reconstructOutputs(Task task, String sh } private & NativeType> - RandomAccessibleInterval reconstructCoord(Task task, String shm_coords_id) throws IOException { - - String coords_dtype = (String) task.outputs.get(COORDS_DTYPE_KEY); - List coords_shape = (List) task.outputs.get(COORDS_SHAPE_KEY); + RandomAccessibleInterval reconstruct(Task task, String key) throws IOException { + + String shm_name = (String) task.outputs.get(key + SHM_NAME_KEY); + String coords_dtype = (String) task.outputs.get(key + DTYPE_KEY); + List coords_shape = (List) task.outputs.get(key + SHAPE_KEY); if (coords_shape == null) return null; long[] coordsSh = new long[coords_shape.size()]; for (int i = 0; i < coordsSh.length; i ++) coordsSh[i] = coords_shape.get(i).longValue(); - SharedMemoryArray shmCoords = SharedMemoryArray.readOrCreate(shm_coords_id, coordsSh, + SharedMemoryArray shmCoords = SharedMemoryArray.readOrCreate(shm_name, coordsSh, Cast.unchecked(CommonUtils.getImgLib2DataType(coords_dtype)), false, false); Map> outs = new HashMap>(); @@ -304,29 +330,6 @@ RandomAccessibleInterval reconstructCoord(Task task, String shm_coords_id) th return coordsCopy; } - private & NativeType> - RandomAccessibleInterval reconstructPoints(Task task, String shm_points_id) throws IOException { - - String points_dtype = (String) task.outputs.get(POINTS_DTYPE_KEY); - List points_shape = (List) task.outputs.get(POINTS_SHAPE_KEY); - if (points_shape == null) - return null; - - - long[] pointsSh = new long[points_shape.size()]; - for (int i = 0; i < pointsSh.length; i ++) - pointsSh[i] = points_shape.get(i).longValue(); - SharedMemoryArray shmPoints = SharedMemoryArray.readOrCreate(shm_points_id, pointsSh, - Cast.unchecked(CommonUtils.getImgLib2DataType(points_dtype)), false, false); - - // TODO I do not understand why is complaining when the types align perfectly - RandomAccessibleInterval pointsRAI = shmPoints.getSharedRAI(); - RandomAccessibleInterval pointsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(pointsRAI), - Util.getTypeFromInterval(Cast.unchecked(pointsRAI))); - shmPoints.close(); - return pointsCopy; - } - /** * Check whether everything that is needed for Stardist 2D is installed or not */