Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional prediction mask to pixel classification #67

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ public PixelClassification(File executableFilePath, File projectFileName, LogSer
}

public <T extends NativeType<T>> ImgPlus<T> classifyPixels(ImgPlus<? extends RealType<?>> rawInputImg,
ImgPlus<? extends RealType<?>> predictionMask,
PixelPredictionType pixelPredictionType) throws IOException {
return executeIlastik(rawInputImg, null, pixelPredictionType);
return executeIlastik(rawInputImg, predictionMask, pixelPredictionType);
}

@Override
Expand All @@ -36,6 +37,10 @@ else if (pixelPredictionType == PixelPredictionType.Probabilities) {
commandLine.add("--raw_data=" + tempFiles.get(rawInputTempFile));
commandLine.add("--output_filename_format=" + tempFiles.get(outputTempFile));

if(tempFiles.containsKey(secondInputTempFile)){
commandLine.add("--prediction_mask=" + tempFiles.get(secondInputTempFile));
}

return commandLine;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package org.ilastik.ilastik4ij.ui;

import ij.Macro;
import ij.plugin.frame.Recorder;
import net.imagej.Dataset;
import net.imagej.DatasetService;
import net.imagej.ImgPlus;
import net.imglib2.type.NativeType;
import org.ilastik.ilastik4ij.executors.PixelClassification;
Expand Down Expand Up @@ -31,46 +34,87 @@ public class IlastikPixelClassificationCommand implements Command {
public OptionsService optionsService;

@Parameter
public UIService uiService;
public DatasetService datasetService;

@Parameter(label = "Trained ilastik project file")
public File projectFileName;
@Parameter
public UIService uiService;

@Parameter(label = "Raw input image")
public Dataset inputImage;

@Parameter(label = "Output type", choices = {UiConstants.PIXEL_PREDICTION_TYPE_PROBABILITIES, UiConstants.PIXEL_PREDICTION_TYPE_SEGMENTATION}, style = "radioButtonHorizontal")
public String pixelClassificationType;

@Parameter(type = ItemIO.OUTPUT)
private ImgPlus<? extends NativeType<?>> predictions;

public IlastikOptions ilastikOptions;

/**
* Run method that calls ilastik
*/
private IlastikPixelClassificationModel createModel() {
IlastikPixelClassificationModel model = new IlastikPixelClassificationModel(logService);
model.setRawInput(inputImage);

String macroOptions = Macro.getOptions();

if (macroOptions != null) {
String projectFilePath = Macro.getValue(macroOptions, "projectfilename", "");
model.setIlastikProjectFile(new File(projectFilePath));
model.setOutputType(Macro.getValue(macroOptions, "pixelclassificationtype", ""));

String predictionMaskName = Macro.getValue(macroOptions, "predictionmask", "");
for (Dataset ds : datasetService.getDatasets()) {
if (ds.getName().equals(predictionMaskName)) {
model.setPredictionMask(ds);
break;
}
}
}

return model;
}

@Override
public void run() {

if (ilastikOptions == null)
ilastikOptions = optionsService.getOptions(IlastikOptions.class);

IlastikPixelClassificationModel model = createModel();
if (!model.isValid()) {
IlastikPixelClassificationDialog dialog = new IlastikPixelClassificationDialog(logService, uiService, datasetService, model);
model.fireInitialProperties();
dialog.setVisible(true);
if (dialog.wasCancelled()) {
return;
}
}

try {
runClassification();
runClassification(model);
if (Recorder.record) {
Recorder.recordOption("projectfilename", model.getIlastikProjectFile().getAbsolutePath());
Recorder.recordOption("pixelclassificationtype", model.getOutputType());
if (model.getPredictionMask() != null) {
Recorder.recordOption("predictionmask", model.getPredictionMask().getName());
}
}
} catch (IOException e) {
logService.error("Pixel classification command failed", e);
throw new RuntimeException(e);
}
}
}

private void runClassification() throws IOException {
private void runClassification(IlastikPixelClassificationModel model) throws IOException {
final PixelClassification pixelClassification = new PixelClassification(ilastikOptions.getExecutableFile(),
projectFileName, logService, statusService, ilastikOptions.getNumThreads(), ilastikOptions.getMaxRamMb());
model.getIlastikProjectFile(), logService, statusService, ilastikOptions.getNumThreads(), ilastikOptions.getMaxRamMb());

PixelPredictionType pixelPredictionType = PixelPredictionType.valueOf(pixelClassificationType);
this.predictions = pixelClassification.classifyPixels(inputImage.getImgPlus(), pixelPredictionType);
PixelPredictionType pixelPredictionType = PixelPredictionType.valueOf(model.getOutputType());

ImgPlus predMaskImg = null;
if (model.getPredictionMask() != null) {
predMaskImg = model.getPredictionMask().getImgPlus();
}

// DisplayUtils.showOutput(uiService, predictions);
this.predictions = pixelClassification.classifyPixels(
inputImage.getImgPlus(),
predMaskImg,
pixelPredictionType
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
package org.ilastik.ilastik4ij.ui;

import net.imagej.Dataset;
import net.imagej.DatasetService;
import org.scijava.log.LogService;
import org.scijava.ui.UIService;
import org.scijava.widget.FileWidget;

import javax.print.DocFlavor;
import javax.swing.*;
import javax.swing.border.Border;
import javax.swing.border.LineBorder;
import javax.swing.event.DocumentEvent;
import javax.swing.event.DocumentListener;
import java.awt.*;
import java.beans.PropertyChangeEvent;
import java.beans.PropertyChangeListener;
import java.io.File;

public class IlastikPixelClassificationDialog extends JDialog implements PropertyChangeListener {
private static final Border VALID_BORDER = new JTextField().getBorder();
private static final Border INVALID_BORDER = new LineBorder(Color.RED, 1);

private class DatasetComboboxEntry {
public final Dataset dataset;
public final String title;

public DatasetComboboxEntry(String title, Dataset dataset) {
this.title = title;
this.dataset = dataset;
}

@Override
public String toString() {
return this.title;
}
}
private class OutputTypeComboboxEntry {
public final String verboseName;
public final String type;

public OutputTypeComboboxEntry(String verboseName, String type) {
this.verboseName = verboseName;
this.type = type;
}

@Override
public String toString() {
return this.verboseName;
}
}

private boolean cancelled = true;
private final LogService logService;
private final UIService uiService;
private final DatasetService datasetService;
private final IlastikPixelClassificationModel model;

private final JPanel contentPanel = new JPanel();
private final JPanel controlPanel = new JPanel(new FlowLayout(FlowLayout.RIGHT));

private final JLabel ilpPathLabel = new JLabel("Path:");
private final JTextField ilpPath = new JTextField();
private final JButton ilpPathBrowse = new JButton("Browse");

private final JLabel outputTypeLabel = new JLabel("Output type:");
private final JComboBox<String> outputType = new JComboBox<>();

private final JLabel predictionMaskLabel = new JLabel("Prediction Mask:");
private final JComboBox<DatasetComboboxEntry> predictionMask = new JComboBox<>();

private final JButton predictBtn = new JButton("Predict");
private final JButton cancelBtn = new JButton("Cancel");

private void initializeComponentLayout() {
getContentPane().setLayout(new BorderLayout());
GroupLayout layout = new GroupLayout(contentPanel);
layout.setAutoCreateGaps(true);
layout.setAutoCreateContainerGaps(true);
contentPanel.setLayout(layout);
getContentPane().add(contentPanel, BorderLayout.PAGE_START);
getContentPane().add(controlPanel, BorderLayout.PAGE_END);
controlPanel.add(cancelBtn);
controlPanel.add(predictBtn);
ilpPath.setMinimumSize(new Dimension(400, 20));

layout.setHorizontalGroup(layout
.createSequentialGroup()
.addGroup(layout
.createParallelGroup(GroupLayout.Alignment.LEADING)
.addComponent(ilpPathLabel)
.addComponent(outputTypeLabel)
.addComponent(predictionMaskLabel)
)
.addGroup(layout
.createParallelGroup(GroupLayout.Alignment.LEADING)
.addGroup(layout
.createSequentialGroup()
.addComponent(ilpPath)
.addComponent(ilpPathBrowse)
)
.addComponent(outputType)
.addComponent(predictionMask)
)
);
layout.setVerticalGroup(layout
.createSequentialGroup()
.addGroup(
layout.createParallelGroup(GroupLayout.Alignment.BASELINE)
.addComponent(ilpPathLabel)
.addComponent(ilpPath)
.addComponent(ilpPathBrowse)
)
.addGroup(
layout.createParallelGroup(GroupLayout.Alignment.BASELINE)
.addComponent(outputTypeLabel)
.addComponent(outputType)
)
.addGroup(
layout.createParallelGroup(GroupLayout.Alignment.BASELINE)
.addComponent(predictionMaskLabel)
.addComponent(predictionMask)
)
);
layout.linkSize(SwingConstants.VERTICAL, ilpPathBrowse, ilpPath);
setResizable(true);
pack();
}

private void initIlpControl() {
ilpPathBrowse.addActionListener(actionEvent -> {
File parent = model.getIlastikProjectFile();
File result = uiService.chooseFile(parent, FileWidget.OPEN_STYLE);
if (result != null) {
model.setIlastikProjectFile(result);
}
});

ilpPath.getDocument().addDocumentListener(new DocumentListener() {
@Override
public void insertUpdate(DocumentEvent documentEvent) {
model.setIlastikProjectFile(new File(ilpPath.getText()));
}

@Override
public void removeUpdate(DocumentEvent documentEvent) {
model.setIlastikProjectFile(new File(ilpPath.getText()));
}

@Override
public void changedUpdate(DocumentEvent documentEvent) {
model.setIlastikProjectFile(new File(ilpPath.getText()));
}
});
}

private void initOutputTypeControl() {
outputType.addItem(UiConstants.PIXEL_PREDICTION_TYPE_PROBABILITIES);
outputType.addItem(UiConstants.PIXEL_PREDICTION_TYPE_SEGMENTATION);
outputType.addActionListener(evt -> {
String entry = (String)outputType.getSelectedItem();
if (entry != null) {
model.setOutputType(entry);
}
});
}

private void initPredictionTypeControl() {
this.predictionMask.addItem(new DatasetComboboxEntry("<none>", null));
for (Dataset ds : datasetService.getDatasets()) {
this.predictionMask.addItem(new DatasetComboboxEntry(ds.getName(), ds));
}
predictionMask.addActionListener(evt -> {
DatasetComboboxEntry entry = (DatasetComboboxEntry) predictionMask.getSelectedItem();
if (entry != null) {
model.setPredictionMask(entry.dataset);
}
});
}
public boolean wasCancelled() {
return this.cancelled;
}

public IlastikPixelClassificationDialog(LogService logService, UIService uiService, DatasetService datasetService, IlastikPixelClassificationModel model) {
this.setModalityType(ModalityType.APPLICATION_MODAL); // Block until dialog is closed
this.uiService = uiService;
this.logService = logService;
this.datasetService = datasetService;

this.model = model;
this.model.addPropertyChangeListener(this);

this.initIlpControl();
this.initOutputTypeControl();
this.initPredictionTypeControl();

this.cancelBtn.addActionListener(evt -> {
this.dispose();
});
this.predictBtn.addActionListener(evt -> {
if (model.isValid()) {
cancelled = false;
this.dispose();
}
});

this.initializeComponentLayout();
}

private static Border getLineBorder(boolean valid) {
if (valid) {
return VALID_BORDER;
} else {
return INVALID_BORDER;
}
}

@Override
public void propertyChange(PropertyChangeEvent evt) {
if (evt.getPropertyName().equals(IlastikPixelClassificationModel.PROPERTY_ILASTIK_PROJECT_FILE)) {
File newProjectFile = (File) evt.getNewValue();
if (newProjectFile != null && !this.ilpPath.equals(newProjectFile.getAbsolutePath())) {
this.ilpPath.setText(newProjectFile.getAbsolutePath());
}

ilpPath.setBorder(getLineBorder(model.isValidIlastikProjectFile()));
} else if (evt.getPropertyName().equals(IlastikPixelClassificationModel.PROPERTY_OUTPUT_TYPE)) {
String newType = (String) evt.getNewValue();
int selectedIdx = outputType.getSelectedIndex();
for (int i = 0; i < outputType.getItemCount(); i++) {
String entry = outputType.getItemAt(i);
if (entry.equals(newType) && selectedIdx != i) {
outputType.setSelectedIndex(i);
}
}

} else if (evt.getPropertyName().equals(IlastikPixelClassificationModel.PROPERTY_PREDICTION_MASK)) {
Dataset newPredMask = (Dataset) evt.getNewValue();
int selectedIdx = predictionMask.getSelectedIndex();
if (newPredMask != null) {
for (int i = 0; i < predictionMask.getItemCount(); i++) {
DatasetComboboxEntry entry = predictionMask.getItemAt(i);
if (newPredMask.getName().equals(entry.title) && selectedIdx != i) {
predictionMask.setSelectedIndex(i);
}
}
} else {
if (selectedIdx != 0) {
outputType.setSelectedIndex(0);
}
}
}
}

}
Loading