From 6b0d46f73404fe1c4e490917e4b2a488c7c8577d Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Wed, 15 Jan 2025 18:57:56 +0100 Subject: [PATCH] move towards full python api --- .../modelrunner/model/Stardist2D.java | 243 +++++------ .../modelrunner/model/Stardist2D_old.java | 396 ++++++++++++++++++ 2 files changed, 511 insertions(+), 128 deletions(-) create mode 100644 src/main/java/io/bioimage/modelrunner/model/Stardist2D_old.java diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java index c6009713..65f7ec51 100644 --- a/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java +++ b/src/main/java/io/bioimage/modelrunner/model/Stardist2D.java @@ -38,8 +38,14 @@ import org.apache.commons.compress.archivers.ArchiveException; +import ai.nets.samj.install.EfficientSamEnvManager; +import ai.nets.samj.models.PythonMethods; +import io.bioimage.modelrunner.apposed.appose.Environment; import io.bioimage.modelrunner.apposed.appose.Mamba; import io.bioimage.modelrunner.apposed.appose.MambaInstallException; +import io.bioimage.modelrunner.apposed.appose.Service; +import io.bioimage.modelrunner.apposed.appose.Service.Task; +import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus; import io.bioimage.modelrunner.bioimageio.BioimageioRepo; import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory; @@ -54,6 +60,8 @@ import io.bioimage.modelrunner.runmode.ops.GenericOp; import io.bioimage.modelrunner.tensor.Tensor; import io.bioimage.modelrunner.tensor.Utils; +import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; +import io.bioimage.modelrunner.utils.CommonUtils; import io.bioimage.modelrunner.utils.Constants; import io.bioimage.modelrunner.utils.JSONUtils; import io.bioimage.modelrunner.versionmanagement.InstalledEngines; @@ -77,75 +85,135 @@ public class Stardist2D { private String modelDir; - private ModelDescriptor descriptor; + private final String name; - private final int channels; + private final String basedir; - private Float nms_threshold; + private boolean loaded = false; - private Float prob_threshold; + private SharedMemoryArray shma; - private static final List STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"}); + private ModelDescriptor descriptor; - private static final List STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"}); + private final int channels; - private static final String STARDIST2D_PATH_IN_RESOURCES = "ops/stardist_postprocessing/"; + private Float nms_threshold; - private static final String STARDIST2D_SCRIPT_NAME= "stardist_postprocessing.py"; + private Float prob_threshold; - private static final String STARDIST2D_METHOD_NAME= "stardist_postprocessing"; + private Environment env; - private static final String THRES_FNAME = "thresholds.json"; + private Service python; - private static final String PROB_THRES_KEY = "thres"; + private static final List STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"}); - private static final String NMS_THRES_KEY = "thres"; + private static final List STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"}); - private static final float DEFAULT_NMS_THRES = (float) 0.4; + private static final String LOAD_MODEL_CODE = "" + + "if 'StarDist2D' not in globals().keys():" + System.lineSeparator() + + " from stardist.models import StarDist2D" + System.lineSeparator() + + " globals()['StarDist2D'] = StarDist2D" + System.lineSeparator() + + "if 'np' not in globals().keys():" + System.lineSeparator() + + " import numpy as np" + System.lineSeparator() + + " globals()['np'] = np" + System.lineSeparator() + + "if 'shared_memory' not in globals().keys():" + System.lineSeparator() + + " from multiprocessing import shared_memory" + System.lineSeparator() + + " globals()['shared_memory'] = shared_memory" + System.lineSeparator() + + "model = StarDist2D(None, name='%s', basedir='%s')" + System.lineSeparator() + + "globals()['model'] = model"; - private static final float DEFAULT_PROB_THRES = (float) 0.5; + private static final String RUN_MODEL_CODE = "" + + "shm_coords_id = task.inputs['shm_coords_id']" + System.lineSeparator() + + "shm_points_id = task.inputs['shm_points_id']" + System.lineSeparator() + + "output = model.predict_instances(im, returnPredcit=False)" + System.lineSeparator() + + "im[:] = output[0]" + System.lineSeparator() + + "task.outputs['coords_shape'] = output[1]['coords'].shape" + System.lineSeparator() + + "task.outputs['coords_dtype'] = output[1]['coords'].dtype" + System.lineSeparator() + + "task.outputs['points_shape'] = output[1]['points'].shape" + System.lineSeparator() + + "task.outputs['points_dtype'] = output[1]['points'].dtype" + System.lineSeparator() + + ""; - private Stardist2D(StardistConfig config, String modelName, String baseDir) { + private Stardist2D(String modelName, String baseDir) { + this.name = modelName; + this.basedir = baseDir; modelDir = new File(baseDir, modelName).getAbsolutePath(); - findWeights(); - findThresholds(); + if (new File(modelDir, "config.json").isFile() == false && new File(modelDir, Constants.RDF_FNAME).isFile() == false) + throw new IllegalArgumentException("No 'config.json' file found in the model directory"); + else if (new File(modelDir, "config.json").isFile() == false) + createConfigFromBioimageio(); + Map stardistMap = (Map) descriptor.getConfig().getSpecMap().get("stardist"); + Map stardistConfig = (Map) stardistMap.get("config"); + Map stardistThres = (Map) stardistMap.get("thresholds"); + this.channels = (int) stardistConfig.get("n_channel_in");; + this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue(); + this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue(); - this.channels = 1; } - private void findWeights() { + private void createConfigFromBioimageio() { } - private void findThresholds() { - if (new File(modelDir, THRES_FNAME).isFile()) { - try { - Map json = JSONUtils.load(modelDir + File.separator + THRES_FNAME); - if (json.get(PROB_THRES_KEY) != null && json.get(PROB_THRES_KEY) instanceof Number) - prob_threshold = ((Number) json.get(PROB_THRES_KEY)).floatValue(); - if (json.get(NMS_THRES_KEY) != null && json.get(NMS_THRES_KEY) instanceof Number) - nms_threshold = ((Number) json.get(NMS_THRES_KEY)).floatValue(); - } catch (IOException e) { - } - } - if (nms_threshold == null) { - System.out.println("Nms threshold not defined, using default value: " + DEFAULT_NMS_THRES); - nms_threshold = DEFAULT_NMS_THRES; - } - if (prob_threshold == null) { - System.out.println("Probability threshold not defined, using default value: " + DEFAULT_PROB_THRES); - prob_threshold = DEFAULT_NMS_THRES; - } + private void loadModel() throws IOException, InterruptedException { + if (loaded) + return; + String code = String.format(LOAD_MODEL_CODE, this.name, this.basedir); + Task task = python.task(code); + task.waitFor(); + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException("Task canceled"); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(task.error); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(task.error); + loaded = true; } - private Stardist2D(ModelDescriptor descriptor) { - this.descriptor = descriptor; - Map stardistMap = (Map) descriptor.getConfig().getSpecMap().get("stardist"); - Map stardistConfig = (Map) stardistMap.get("config"); - Map stardistThres = (Map) stardistMap.get("thresholds"); - this.channels = (int) stardistConfig.get("n_channel_in");; - this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue(); - this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue(); + + protected String createEncodeImageScript() { + String code = ""; + // This line wants to recreate the original numpy array. Should look like: + // input0_appose_shm = shared_memory.SharedMemory(name=input0) + // input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64]) + code += "im_shm = shared_memory.SharedMemory(name='" + + shma.getNameForPython() + "', size=" + shma.getSize() + + ")" + System.lineSeparator(); + code += "im = np.ndarray(" + shma.getSize() + ", dtype='" + CommonUtils.getDataTypeFromRAI(Cast.unchecked(shma.getSharedRAI())) + + "', buffer=im_shm.buf).reshape(["; + for (int i = 0; i < shma.getOriginalShape().length; i ++) + code += shma.getOriginalShape()[i] + ", "; + code += "])" + System.lineSeparator(); + return code; + } + + public void close() { + if (!loaded) + return; + python.close(); + } + + public & NativeType> void run(RandomAccessibleInterval img) throws IOException, InterruptedException { + + shma = SharedMemoryArray.createSHMAFromRAI(img); + String code = ""; + if (!loaded) { + code += String.format(LOAD_MODEL_CODE, this.name, this.basedir) + System.lineSeparator(); + } + + code += createEncodeImageScript() + System.lineSeparator(); + code += RUN_MODEL_CODE + System.lineSeparator(); + + Task task = python.task(code); + task.waitFor(); + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException("Task canceled"); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(task.error); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(task.error); + task.outputs.get(""); + + } /** @@ -159,7 +227,7 @@ private Stardist2D(ModelDescriptor descriptor) { */ public static Stardist2D fromBioimageioModel(String modelPath) throws ModelSpecsException, FileNotFoundException, IOException { ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(modelPath + File.separator + Constants.RDF_FNAME); - return new Stardist2D(descriptor); + return new Stardist2D(modelPath); } /** @@ -233,87 +301,6 @@ else if (image.dimensionsAsLongArray().length > 3 || image.dimensionsAsLongArray throw new IllegalArgumentException("Stardist2D model requires an image with dimensions XYC."); } - /** - * Run the Stardist 2D model end to end, including pre- and post-processing. - * @param - * possible ImgLib2 data types of the input and output images - * @param image - * the input image that is going to be processed by Stardist2D - * @return the final output of Stardist2D including pre- and post-processing - * @throws ModelSpecsException if there is any error with the specs of the model - * @throws LoadModelException if there is any error loading the model in Tensorflow Java - * @throws LoadEngineException if there is any error loading Tensorflow Java engine - * @throws IOException if there is any error with the files that are required to run the model - * @throws RunModelException if there is any unexpected exception running the post-processing - * @throws InterruptedException if the inference or post-processing are interrupted unexpectedly - */ - public & NativeType> - RandomAccessibleInterval predict(RandomAccessibleInterval image) throws ModelSpecsException, LoadModelException, - LoadEngineException, IOException, - RunModelException, InterruptedException { - checkInput(image); - if (image.dimensionsAsLongArray().length == 2) image = Views.addDimension(image, 0, 0); - image = Views.permute(image, 0, 2); - image = Views.addDimension(image, 0, 0); - image = Views.permute(image, 0, 3); - - Tensor inputTensor = Tensor.build("input", "byxc", image); - Tensor outputTensor = Tensor.buildEmptyTensor("output", "byxc"); - - List> inputList = new ArrayList>(); - List> outputList = new ArrayList>(); - inputList.add(inputTensor); - outputList.add(outputTensor); - - Model model = Model.createBioimageioModel(this.descriptor.getModelPath()); - model.loadModel(); - Processing processing = Processing.init(descriptor); - inputList = processing.preprocess(inputList, false); - model.runModel(inputList, outputList); - - return Utils.transpose(Cast.unchecked(postProcessing(outputList.get(0).getData()))); - } - - /** - * Execute stardist post-processing on the raw output of a Stardist 2D model - * @param - * possible data type of the input image - * @param image - * the raw output of a Stardist 2D model - * @return the final output of a Stardist 2D model - * @throws IOException if there is any error running the post-processing - * @throws InterruptedException if the post-processing is interrupted - */ - public & NativeType> - RandomAccessibleInterval postProcessing(RandomAccessibleInterval image) throws IOException, InterruptedException { - Mamba mamba = new Mamba(); - String envPath = mamba.getEnvsDir() + File.separator + "stardist"; - String scriptPath = envPath + File.separator + STARDIST2D_SCRIPT_NAME; - - GenericOp op = GenericOp.create(envPath, scriptPath, STARDIST2D_METHOD_NAME, 1); - LinkedHashMap nMap = new LinkedHashMap(); - Calendar cal = Calendar.getInstance(); - SimpleDateFormat sdf = new SimpleDateFormat("ddMMYYYY_HHmmss"); - String dateString = sdf.format(cal.getTime()); - nMap.put("input_" + dateString, image); - nMap.put("nms_thresh", nms_threshold); - nMap.put("prob_thresh", prob_threshold); - op.setInputs(nMap); - - RunMode rm; - rm = RunMode.createRunMode(op); - Map resMap = rm.runOP(); - - List> rais = resMap.entrySet().stream() - .filter(e -> { - Object val = e.getValue(); - if (val instanceof RandomAccessibleInterval) return true; - return false; - }).map(e -> (RandomAccessibleInterval) e.getValue()).collect(Collectors.toList()); - - return rais.get(0); - } - /** * Check whether everything that is needed for Stardist 2D is installed or not */ diff --git a/src/main/java/io/bioimage/modelrunner/model/Stardist2D_old.java b/src/main/java/io/bioimage/modelrunner/model/Stardist2D_old.java new file mode 100644 index 00000000..cd3d6395 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/Stardist2D_old.java @@ -0,0 +1,396 @@ +/*- + * #%L + * Use deep learning frameworks from Java in an agnostic and isolated way. + * %% + * Copyright (C) 2022 - 2024 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package io.bioimage.modelrunner.model; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Calendar; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.commons.compress.archivers.ArchiveException; + +import io.bioimage.modelrunner.apposed.appose.Mamba; +import io.bioimage.modelrunner.apposed.appose.MambaInstallException; +import io.bioimage.modelrunner.bioimageio.BioimageioRepo; +import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; +import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory; +import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException; +import io.bioimage.modelrunner.engine.installation.EngineInstall; +import io.bioimage.modelrunner.exceptions.LoadEngineException; +import io.bioimage.modelrunner.exceptions.LoadModelException; +import io.bioimage.modelrunner.exceptions.RunModelException; +import io.bioimage.modelrunner.model.processing.Processing; +import io.bioimage.modelrunner.model.stardist_java_deprecate.StardistConfig; +import io.bioimage.modelrunner.runmode.RunMode; +import io.bioimage.modelrunner.runmode.ops.GenericOp; +import io.bioimage.modelrunner.tensor.Tensor; +import io.bioimage.modelrunner.tensor.Utils; +import io.bioimage.modelrunner.utils.Constants; +import io.bioimage.modelrunner.utils.JSONUtils; +import io.bioimage.modelrunner.versionmanagement.InstalledEngines; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.array.ArrayImgs; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; +import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Cast; +import net.imglib2.view.Views; + +/** + * @deprecated + * Implementation of an API to run Stardist 2D models out of the box with little configuration. + * + *TODO add fine tuning + *TODO add support for Mac arm + * + *@author Carlos Garcia + */ +public class Stardist2D_old { + + private String modelDir; + + private ModelDescriptor descriptor; + + private final int channels; + + private Float nms_threshold; + + private Float prob_threshold; + + private static final List STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"}); + + private static final List STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"}); + + private static final String STARDIST2D_PATH_IN_RESOURCES = "ops/stardist_postprocessing/"; + + private static final String STARDIST2D_SCRIPT_NAME= "stardist_postprocessing.py"; + + private static final String STARDIST2D_METHOD_NAME= "stardist_postprocessing"; + + private static final String THRES_FNAME = "thresholds.json"; + + private static final String PROB_THRES_KEY = "thres"; + + private static final String NMS_THRES_KEY = "thres"; + + private static final float DEFAULT_NMS_THRES = (float) 0.4; + + private static final float DEFAULT_PROB_THRES = (float) 0.5; + + private Stardist2D_old(StardistConfig config, String modelName, String baseDir) { + modelDir = new File(baseDir, modelName).getAbsolutePath(); + findWeights(); + findThresholds(); + + this.channels = 1; + } + + private void findWeights() { + + } + + private void findThresholds() { + if (new File(modelDir, THRES_FNAME).isFile()) { + try { + Map json = JSONUtils.load(modelDir + File.separator + THRES_FNAME); + if (json.get(PROB_THRES_KEY) != null && json.get(PROB_THRES_KEY) instanceof Number) + prob_threshold = ((Number) json.get(PROB_THRES_KEY)).floatValue(); + if (json.get(NMS_THRES_KEY) != null && json.get(NMS_THRES_KEY) instanceof Number) + nms_threshold = ((Number) json.get(NMS_THRES_KEY)).floatValue(); + } catch (IOException e) { + } + } + if (nms_threshold == null) { + System.out.println("Nms threshold not defined, using default value: " + DEFAULT_NMS_THRES); + nms_threshold = DEFAULT_NMS_THRES; + } + if (prob_threshold == null) { + System.out.println("Probability threshold not defined, using default value: " + DEFAULT_PROB_THRES); + prob_threshold = DEFAULT_NMS_THRES; + } + } + + private Stardist2D_old(ModelDescriptor descriptor) { + this.descriptor = descriptor; + Map stardistMap = (Map) descriptor.getConfig().getSpecMap().get("stardist"); + Map stardistConfig = (Map) stardistMap.get("config"); + Map stardistThres = (Map) stardistMap.get("thresholds"); + this.channels = (int) stardistConfig.get("n_channel_in");; + this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue(); + this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue(); + } + + /** + * Initialize a Stardist2D using the format of the Bioiamge.io model zoo. + * @param modelPath + * path to the Bioimage.io model + * @return an instance of a Stardist2D model ready to be used + * @throws ModelSpecsException If there is any error in the configuration of the specs rdf.yaml file of the Bioimage.io. + * @throws FileNotFoundException If the model file is not found. + * @throws IOException If there's an I/O error. + */ + public static Stardist2D_old fromBioimageioModel(String modelPath) throws ModelSpecsException, FileNotFoundException, IOException { + ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(modelPath + File.separator + Constants.RDF_FNAME); + return new Stardist2D_old(descriptor); + } + + /** + * Initialize one of the "official" pretrained Stardist 2D models. + * By default, the model will be installed in the "models" folder inside the application + * @param pretrainedModel + * the name of the pretrained model. + * @param forceInstall + * whether to force the installation or to try to look if the model has already been installed before + * @return an instance of a pretrained Stardist2D model ready to be used + * @throws IOException if there is any error downloading the model, in the case it is needed + * @throws InterruptedException if the download of the model is stopped + * @throws ModelSpecsException if the model downloaded is not well specified in the config file + */ + public static Stardist2D_old fromPretained(String pretrainedModel, boolean forceInstall) throws IOException, InterruptedException, ModelSpecsException { + return fromPretained(pretrainedModel, new File("models").getAbsolutePath(), forceInstall); + } + + /** + * TODO add support for 2D_paper_dsb2018 + * Initialize one of the "official" pretrained Stardist 2D models + * @param pretrainedModel + * the name of the pretrained model. + * @param installDir + * the directory where the model wants to be installed + * @param forceInstall + * whether to force the installation or to try to look if the model has already been installed before + * @return an instance of a pretrained Stardist2D model ready to be used + * @throws IOException if there is any error downloading the model, in the case it is needed + * @throws InterruptedException if the download of the model is stopped + * @throws ModelSpecsException if the model downloaded is not well specified in the config file + */ + public static Stardist2D_old fromPretained(String pretrainedModel, String installDir, boolean forceInstall) throws IOException, + InterruptedException, + ModelSpecsException { + if ((pretrainedModel.equals("StarDist H&E Nuclei Segmentation") + || pretrainedModel.equals("2D_versatile_he")) && !forceInstall) { + ModelDescriptor md = ModelDescriptorFactory.getModelsAtLocalRepo().stream() + .filter(mm ->mm.getName().equals("StarDist H&E Nuclei Segmentation")).findFirst().orElse(null); + if (md != null) return new Stardist2D_old(md); + String path = BioimageioRepo.connect().downloadByName("StarDist H&E Nuclei Segmentation", installDir); + return Stardist2D_old.fromBioimageioModel(path); + } else if (pretrainedModel.equals("StarDist H&E Nuclei Segmentation") + || pretrainedModel.equals("2D_versatile_he")) { + String path = BioimageioRepo.connect().downloadByName("StarDist H&E Nuclei Segmentation", installDir); + return Stardist2D_old.fromBioimageioModel(path); + } else if ((pretrainedModel.equals("StarDist Fluorescence Nuclei Segmentation") + || pretrainedModel.equals("2D_versatile_fluo")) && !forceInstall) { + ModelDescriptor md = ModelDescriptorFactory.getModelsAtLocalRepo().stream() + .filter(mm ->mm.getName().equals("StarDist Fluorescence Nuclei Segmentation")).findFirst().orElse(null); + if (md != null) return new Stardist2D_old(md); + String path = BioimageioRepo.connect().downloadByName("StarDist Fluorescence Nuclei Segmentation", installDir); + return Stardist2D_old.fromBioimageioModel(path); + } else if (pretrainedModel.equals("StarDist Fluorescence Nuclei Segmentation") + || pretrainedModel.equals("2D_versatile_fluo")) { + String path = BioimageioRepo.connect().downloadByName("StarDist Fluorescence Nuclei Segmentation", installDir); + return Stardist2D_old.fromBioimageioModel(path); + } else { + throw new IllegalArgumentException("There is no Stardist2D model called: " + pretrainedModel); + } + } + + private & NativeType> void checkInput(RandomAccessibleInterval image) { + if (image.dimensionsAsLongArray().length == 2 && this.channels != 1) + throw new IllegalArgumentException("Stardist2D needs an image with three dimensions: XYC"); + else if (image.dimensionsAsLongArray().length != 3 && this.channels != 1) + throw new IllegalArgumentException("Stardist2D needs an image with three dimensions: XYC"); + else if (image.dimensionsAsLongArray().length != 2 && image.dimensionsAsLongArray()[2] != channels) + throw new IllegalArgumentException("This Stardist2D model requires " + channels + " channels."); + else if (image.dimensionsAsLongArray().length > 3 || image.dimensionsAsLongArray().length < 2) + throw new IllegalArgumentException("Stardist2D model requires an image with dimensions XYC."); + } + + /** + * Run the Stardist 2D model end to end, including pre- and post-processing. + * @param + * possible ImgLib2 data types of the input and output images + * @param image + * the input image that is going to be processed by Stardist2D + * @return the final output of Stardist2D including pre- and post-processing + * @throws ModelSpecsException if there is any error with the specs of the model + * @throws LoadModelException if there is any error loading the model in Tensorflow Java + * @throws LoadEngineException if there is any error loading Tensorflow Java engine + * @throws IOException if there is any error with the files that are required to run the model + * @throws RunModelException if there is any unexpected exception running the post-processing + * @throws InterruptedException if the inference or post-processing are interrupted unexpectedly + */ + public & NativeType> + RandomAccessibleInterval predict(RandomAccessibleInterval image) throws ModelSpecsException, LoadModelException, + LoadEngineException, IOException, + RunModelException, InterruptedException { + checkInput(image); + if (image.dimensionsAsLongArray().length == 2) image = Views.addDimension(image, 0, 0); + image = Views.permute(image, 0, 2); + image = Views.addDimension(image, 0, 0); + image = Views.permute(image, 0, 3); + + Tensor inputTensor = Tensor.build("input", "byxc", image); + Tensor outputTensor = Tensor.buildEmptyTensor("output", "byxc"); + + List> inputList = new ArrayList>(); + List> outputList = new ArrayList>(); + inputList.add(inputTensor); + outputList.add(outputTensor); + + Model model = Model.createBioimageioModel(this.descriptor.getModelPath()); + model.loadModel(); + Processing processing = Processing.init(descriptor); + inputList = processing.preprocess(inputList, false); + model.runModel(inputList, outputList); + + return Utils.transpose(Cast.unchecked(postProcessing(outputList.get(0).getData()))); + } + + /** + * Execute stardist post-processing on the raw output of a Stardist 2D model + * @param + * possible data type of the input image + * @param image + * the raw output of a Stardist 2D model + * @return the final output of a Stardist 2D model + * @throws IOException if there is any error running the post-processing + * @throws InterruptedException if the post-processing is interrupted + */ + public & NativeType> + RandomAccessibleInterval postProcessing(RandomAccessibleInterval image) throws IOException, InterruptedException { + Mamba mamba = new Mamba(); + String envPath = mamba.getEnvsDir() + File.separator + "stardist"; + String scriptPath = envPath + File.separator + STARDIST2D_SCRIPT_NAME; + + GenericOp op = GenericOp.create(envPath, scriptPath, STARDIST2D_METHOD_NAME, 1); + LinkedHashMap nMap = new LinkedHashMap(); + Calendar cal = Calendar.getInstance(); + SimpleDateFormat sdf = new SimpleDateFormat("ddMMYYYY_HHmmss"); + String dateString = sdf.format(cal.getTime()); + nMap.put("input_" + dateString, image); + nMap.put("nms_thresh", nms_threshold); + nMap.put("prob_thresh", prob_threshold); + op.setInputs(nMap); + + RunMode rm; + rm = RunMode.createRunMode(op); + Map resMap = rm.runOP(); + + List> rais = resMap.entrySet().stream() + .filter(e -> { + Object val = e.getValue(); + if (val instanceof RandomAccessibleInterval) return true; + return false; + }).map(e -> (RandomAccessibleInterval) e.getValue()).collect(Collectors.toList()); + + return rais.get(0); + } + + /** + * Check whether everything that is needed for Stardist 2D is installed or not + */ + public void checkRequirementsInstalled() { + // TODO + } + + /** + * Check whether the requirements needed to run Stardist 2D are satisfied or not. + * First checks if the corresponding Java DL engine is installed or not, then checks + * if the Python environment needed for Stardist2D post processing is fine too. + * + * If anything is not installed, this method also installs it + * + * @throws IOException if there is any error downloading the DL engine or installing the micromamba environment + * @throws InterruptedException if the installation is stopped + * @throws RuntimeException if there is any unexpected error in the micromamba environment installation + * @throws MambaInstallException if there is any error downloading or installing micromamba + * @throws ArchiveException if there is any error decompressing the micromamba installer + * @throws URISyntaxException if the URL to the micromamba installation is not correct + */ + public static void installRequirements() throws IOException, InterruptedException, + RuntimeException, MambaInstallException, + ArchiveException, URISyntaxException { + boolean installed = InstalledEngines.buildEnginesFinder() + .checkEngineWithArgsInstalledForOS("tensorflow", "1.15.0", null, null).size() != 0; + if (!installed) + EngineInstall.installEngineWithArgs("tensorflow", "1.15.0", true, true); + + Mamba mamba = new Mamba(); + boolean stardistPythonInstalled = false; + try { + stardistPythonInstalled = mamba.checkAllDependenciesInEnv("stardist", STARDIST_DEPS); + } catch (MambaInstallException e) { + mamba.installMicromamba(); + } + if (!stardistPythonInstalled) { + // TODO add logging for environment installation + mamba.create("stardist", true, STARDIST_CHANNELS, STARDIST_DEPS); + }; + String envPath = mamba.getEnvsDir() + File.separator + "stardist"; + String scriptPath = envPath + File.separator + STARDIST2D_SCRIPT_NAME; + if (!Paths.get(scriptPath).toFile().isFile()) { + try (InputStream scriptStream = Stardist2D_old.class.getClassLoader() + .getResourceAsStream(STARDIST2D_PATH_IN_RESOURCES + STARDIST2D_SCRIPT_NAME)){ + Files.copy(scriptStream, Paths.get(scriptPath), StandardCopyOption.REPLACE_EXISTING); + } + } + } + + /** + * Main method to check functionality + * @param args + * nothing + * @throws IOException nothing + * @throws InterruptedException nothing + * @throws RuntimeException nothing + * @throws MambaInstallException nothing + * @throws ModelSpecsException nothing + * @throws LoadEngineException nothing + * @throws RunModelException nothing + * @throws ArchiveException nothing + * @throws URISyntaxException nothing + * @throws LoadModelException nothing + */ + public static void main(String[] args) throws IOException, InterruptedException, + RuntimeException, MambaInstallException, + ModelSpecsException, LoadEngineException, + RunModelException, ArchiveException, + URISyntaxException, LoadModelException { + Stardist2D_old.installRequirements(); + Stardist2D_old model = Stardist2D_old.fromPretained("2D_versatile_fluo", false); + + RandomAccessibleInterval img = ArrayImgs.floats(new long[] {512, 512}); + + RandomAccessibleInterval res = model.predict(img); + System.out.println(true); + } +}