Skip to content

Commit

Permalink
move towards full python api
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jan 15, 2025
1 parent 6acc18e commit 6b0d46f
Show file tree
Hide file tree
Showing 2 changed files with 511 additions and 128 deletions.
243 changes: 115 additions & 128 deletions src/main/java/io/bioimage/modelrunner/model/Stardist2D.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,14 @@

import org.apache.commons.compress.archivers.ArchiveException;

import ai.nets.samj.install.EfficientSamEnvManager;
import ai.nets.samj.models.PythonMethods;
import io.bioimage.modelrunner.apposed.appose.Environment;
import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.apposed.appose.Service;
import io.bioimage.modelrunner.apposed.appose.Service.Task;
import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus;
import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
Expand All @@ -54,6 +60,8 @@
import io.bioimage.modelrunner.runmode.ops.GenericOp;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.utils.JSONUtils;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
Expand All @@ -77,75 +85,135 @@ public class Stardist2D {

private String modelDir;

private ModelDescriptor descriptor;
private final String name;

private final int channels;
private final String basedir;

private Float nms_threshold;
private boolean loaded = false;

private Float prob_threshold;
private SharedMemoryArray shma;

private static final List<String> STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"});
private ModelDescriptor descriptor;

private static final List<String> STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"});
private final int channels;

private static final String STARDIST2D_PATH_IN_RESOURCES = "ops/stardist_postprocessing/";
private Float nms_threshold;

private static final String STARDIST2D_SCRIPT_NAME= "stardist_postprocessing.py";
private Float prob_threshold;

private static final String STARDIST2D_METHOD_NAME= "stardist_postprocessing";
private Environment env;

private static final String THRES_FNAME = "thresholds.json";
private Service python;

private static final String PROB_THRES_KEY = "thres";
private static final List<String> STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"});

private static final String NMS_THRES_KEY = "thres";
private static final List<String> STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"});

private static final float DEFAULT_NMS_THRES = (float) 0.4;
private static final String LOAD_MODEL_CODE = ""
+ "if 'StarDist2D' not in globals().keys():" + System.lineSeparator()
+ " from stardist.models import StarDist2D" + System.lineSeparator()
+ " globals()['StarDist2D'] = StarDist2D" + System.lineSeparator()
+ "if 'np' not in globals().keys():" + System.lineSeparator()
+ " import numpy as np" + System.lineSeparator()
+ " globals()['np'] = np" + System.lineSeparator()
+ "if 'shared_memory' not in globals().keys():" + System.lineSeparator()
+ " from multiprocessing import shared_memory" + System.lineSeparator()
+ " globals()['shared_memory'] = shared_memory" + System.lineSeparator()
+ "model = StarDist2D(None, name='%s', basedir='%s')" + System.lineSeparator()
+ "globals()['model'] = model";

private static final float DEFAULT_PROB_THRES = (float) 0.5;
private static final String RUN_MODEL_CODE = ""
+ "shm_coords_id = task.inputs['shm_coords_id']" + System.lineSeparator()
+ "shm_points_id = task.inputs['shm_points_id']" + System.lineSeparator()
+ "output = model.predict_instances(im, returnPredcit=False)" + System.lineSeparator()
+ "im[:] = output[0]" + System.lineSeparator()
+ "task.outputs['coords_shape'] = output[1]['coords'].shape" + System.lineSeparator()
+ "task.outputs['coords_dtype'] = output[1]['coords'].dtype" + System.lineSeparator()
+ "task.outputs['points_shape'] = output[1]['points'].shape" + System.lineSeparator()
+ "task.outputs['points_dtype'] = output[1]['points'].dtype" + System.lineSeparator()
+ "";

private Stardist2D(StardistConfig config, String modelName, String baseDir) {
private Stardist2D(String modelName, String baseDir) {
this.name = modelName;
this.basedir = baseDir;
modelDir = new File(baseDir, modelName).getAbsolutePath();
findWeights();
findThresholds();
if (new File(modelDir, "config.json").isFile() == false && new File(modelDir, Constants.RDF_FNAME).isFile() == false)
throw new IllegalArgumentException("No 'config.json' file found in the model directory");
else if (new File(modelDir, "config.json").isFile() == false)
createConfigFromBioimageio();
Map<String, Object> stardistMap = (Map<String, Object>) descriptor.getConfig().getSpecMap().get("stardist");
Map<String, Object> stardistConfig = (Map<String, Object>) stardistMap.get("config");
Map<String, Object> stardistThres = (Map<String, Object>) stardistMap.get("thresholds");
this.channels = (int) stardistConfig.get("n_channel_in");;
this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue();
this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue();

this.channels = 1;
}

private void findWeights() {
private void createConfigFromBioimageio() {

}

private void findThresholds() {
if (new File(modelDir, THRES_FNAME).isFile()) {
try {
Map<String, Object> json = JSONUtils.load(modelDir + File.separator + THRES_FNAME);
if (json.get(PROB_THRES_KEY) != null && json.get(PROB_THRES_KEY) instanceof Number)
prob_threshold = ((Number) json.get(PROB_THRES_KEY)).floatValue();
if (json.get(NMS_THRES_KEY) != null && json.get(NMS_THRES_KEY) instanceof Number)
nms_threshold = ((Number) json.get(NMS_THRES_KEY)).floatValue();
} catch (IOException e) {
}
}
if (nms_threshold == null) {
System.out.println("Nms threshold not defined, using default value: " + DEFAULT_NMS_THRES);
nms_threshold = DEFAULT_NMS_THRES;
}
if (prob_threshold == null) {
System.out.println("Probability threshold not defined, using default value: " + DEFAULT_PROB_THRES);
prob_threshold = DEFAULT_NMS_THRES;
}
private void loadModel() throws IOException, InterruptedException {
if (loaded)
return;
String code = String.format(LOAD_MODEL_CODE, this.name, this.basedir);
Task task = python.task(code);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException("Task canceled");
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException(task.error);
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException(task.error);
loaded = true;
}

private Stardist2D(ModelDescriptor descriptor) {
this.descriptor = descriptor;
Map<String, Object> stardistMap = (Map<String, Object>) descriptor.getConfig().getSpecMap().get("stardist");
Map<String, Object> stardistConfig = (Map<String, Object>) stardistMap.get("config");
Map<String, Object> stardistThres = (Map<String, Object>) stardistMap.get("thresholds");
this.channels = (int) stardistConfig.get("n_channel_in");;
this.nms_threshold = new Double((double) stardistThres.get("nms")).floatValue();
this.prob_threshold = new Double((double) stardistThres.get("prob")).floatValue();

protected String createEncodeImageScript() {
String code = "";
// This line wants to recreate the original numpy array. Should look like:
// input0_appose_shm = shared_memory.SharedMemory(name=input0)
// input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64])
code += "im_shm = shared_memory.SharedMemory(name='"
+ shma.getNameForPython() + "', size=" + shma.getSize()
+ ")" + System.lineSeparator();
code += "im = np.ndarray(" + shma.getSize() + ", dtype='" + CommonUtils.getDataTypeFromRAI(Cast.unchecked(shma.getSharedRAI()))
+ "', buffer=im_shm.buf).reshape([";
for (int i = 0; i < shma.getOriginalShape().length; i ++)
code += shma.getOriginalShape()[i] + ", ";
code += "])" + System.lineSeparator();
return code;
}

public void close() {
if (!loaded)
return;
python.close();
}

public <T extends RealType<T> & NativeType<T>> void run(RandomAccessibleInterval<T> img) throws IOException, InterruptedException {

shma = SharedMemoryArray.createSHMAFromRAI(img);
String code = "";
if (!loaded) {
code += String.format(LOAD_MODEL_CODE, this.name, this.basedir) + System.lineSeparator();
}

code += createEncodeImageScript() + System.lineSeparator();
code += RUN_MODEL_CODE + System.lineSeparator();

Task task = python.task(code);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException("Task canceled");
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException(task.error);
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException(task.error);
task.outputs.get("");


}

/**
Expand All @@ -159,7 +227,7 @@ private Stardist2D(ModelDescriptor descriptor) {
*/
public static Stardist2D fromBioimageioModel(String modelPath) throws ModelSpecsException, FileNotFoundException, IOException {
ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(modelPath + File.separator + Constants.RDF_FNAME);
return new Stardist2D(descriptor);
return new Stardist2D(modelPath);
}

/**
Expand Down Expand Up @@ -233,87 +301,6 @@ else if (image.dimensionsAsLongArray().length > 3 || image.dimensionsAsLongArray
throw new IllegalArgumentException("Stardist2D model requires an image with dimensions XYC.");
}

/**
* Run the Stardist 2D model end to end, including pre- and post-processing.
* @param <T>
* possible ImgLib2 data types of the input and output images
* @param image
* the input image that is going to be processed by Stardist2D
* @return the final output of Stardist2D including pre- and post-processing
* @throws ModelSpecsException if there is any error with the specs of the model
* @throws LoadModelException if there is any error loading the model in Tensorflow Java
* @throws LoadEngineException if there is any error loading Tensorflow Java engine
* @throws IOException if there is any error with the files that are required to run the model
* @throws RunModelException if there is any unexpected exception running the post-processing
* @throws InterruptedException if the inference or post-processing are interrupted unexpectedly
*/
public <T extends RealType<T> & NativeType<T>>
RandomAccessibleInterval<T> predict(RandomAccessibleInterval<T> image) throws ModelSpecsException, LoadModelException,
LoadEngineException, IOException,
RunModelException, InterruptedException {
checkInput(image);
if (image.dimensionsAsLongArray().length == 2) image = Views.addDimension(image, 0, 0);
image = Views.permute(image, 0, 2);
image = Views.addDimension(image, 0, 0);
image = Views.permute(image, 0, 3);

Tensor<T> inputTensor = Tensor.build("input", "byxc", image);
Tensor<T> outputTensor = Tensor.buildEmptyTensor("output", "byxc");

List<Tensor<T>> inputList = new ArrayList<Tensor<T>>();
List<Tensor<T>> outputList = new ArrayList<Tensor<T>>();
inputList.add(inputTensor);
outputList.add(outputTensor);

Model model = Model.createBioimageioModel(this.descriptor.getModelPath());
model.loadModel();
Processing processing = Processing.init(descriptor);
inputList = processing.preprocess(inputList, false);
model.runModel(inputList, outputList);

return Utils.transpose(Cast.unchecked(postProcessing(outputList.get(0).getData())));
}

/**
* Execute stardist post-processing on the raw output of a Stardist 2D model
* @param <T>
* possible data type of the input image
* @param image
* the raw output of a Stardist 2D model
* @return the final output of a Stardist 2D model
* @throws IOException if there is any error running the post-processing
* @throws InterruptedException if the post-processing is interrupted
*/
public <T extends RealType<T> & NativeType<T>>
RandomAccessibleInterval<T> postProcessing(RandomAccessibleInterval<T> image) throws IOException, InterruptedException {
Mamba mamba = new Mamba();
String envPath = mamba.getEnvsDir() + File.separator + "stardist";
String scriptPath = envPath + File.separator + STARDIST2D_SCRIPT_NAME;

GenericOp op = GenericOp.create(envPath, scriptPath, STARDIST2D_METHOD_NAME, 1);
LinkedHashMap<String, Object> nMap = new LinkedHashMap<String, Object>();
Calendar cal = Calendar.getInstance();
SimpleDateFormat sdf = new SimpleDateFormat("ddMMYYYY_HHmmss");
String dateString = sdf.format(cal.getTime());
nMap.put("input_" + dateString, image);
nMap.put("nms_thresh", nms_threshold);
nMap.put("prob_thresh", prob_threshold);
op.setInputs(nMap);

RunMode rm;
rm = RunMode.createRunMode(op);
Map<String, Object> resMap = rm.runOP();

List<RandomAccessibleInterval<T>> rais = resMap.entrySet().stream()
.filter(e -> {
Object val = e.getValue();
if (val instanceof RandomAccessibleInterval) return true;
return false;
}).map(e -> (RandomAccessibleInterval<T>) e.getValue()).collect(Collectors.toList());

return rais.get(0);
}

/**
* Check whether everything that is needed for Stardist 2D is installed or not
*/
Expand Down
Loading

0 comments on commit 6b0d46f

Please sign in to comment.