Skip to content

Commit

Permalink
Improve input validation #21
Browse files Browse the repository at this point in the history
* 3D networks can now run 2D images
* check for fixed dimension sizes (like channel dim 2 for IsoNet)
* introducing InputValidator, this still has to be improved
  • Loading branch information
frauzufall committed Nov 28, 2018
1 parent 5012d68 commit 1aec0d6
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 73 deletions.
18 changes: 0 additions & 18 deletions src/main/java/org/csbdeep/commands/GenericIsotropicNetwork.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,24 +158,6 @@ public List<RandomAccessibleInterval<FloatType>> run(final Dataset input, int nu

}

@Override
public void run() {
tryToInitialize();
boolean validInput = DatasetHelper.validate(getInput(),
"4D image with size order X-Y-C-Z and two channels, checking if the order is X-Y-Z-C (also valid)",
isHeadless(), OptionalLong.empty(), OptionalLong.empty(), OptionalLong.of(2), OptionalLong.empty());
if(!validInput) {
validInput = DatasetHelper.validate(getInput(),
"4D image with size order X-Y-Z-Y and two channels",
isHeadless(), OptionalLong.empty(), OptionalLong.empty(), OptionalLong.empty(), OptionalLong.of(2));
}
if(!validInput) return;

final AxisType[] mapping = { Axes.Z, Axes.Y, Axes.X, Axes.CHANNEL };
setMapping(mapping);
super.run();
}

@Override
public void dispose() {
super.dispose();
Expand Down
61 changes: 44 additions & 17 deletions src/main/java/org/csbdeep/commands/GenericNetwork.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

import javax.swing.*;

import net.imglib2.exception.IncompatibleTypeException;
import org.csbdeep.io.DefaultInputProcessor;
import org.csbdeep.io.DefaultOutputProcessor;
import org.csbdeep.io.InputProcessor;
Expand Down Expand Up @@ -120,6 +121,7 @@ public class GenericNetwork implements

private boolean modelNeedsInitialization = false;
private boolean networkInitialized;
private boolean networkAndInputCompatible;

public enum NetworkInputSourceType { UNSET, FILE, URL }

Expand Down Expand Up @@ -172,6 +174,7 @@ public enum NetworkInputSourceType { UNSET, FILE, URL }
protected Tiling tiling;

protected InputProcessor inputProcessor;
protected InputValidator inputValidator;
protected InputMapper inputMapper;
protected InputNormalizer inputNormalizer;
protected InputTiler inputTiler;
Expand Down Expand Up @@ -326,6 +329,7 @@ protected boolean initNetwork() {
}

protected void initTasks() {
inputValidator = initInputValidator();
inputMapper = initInputMapper();
inputProcessor = initInputProcessor();
inputNormalizer = initInputNormalizer();
Expand All @@ -339,14 +343,18 @@ protected void initTasks() {
protected void initTaskManager() {
final TaskForceManager tfm = new TaskForceManager(isHeadless() || !showProgressDialog, log);
tfm.initialize();
tfm.createTaskForce("Preprocessing", modelLoader, inputMapper,
tfm.createTaskForce("Preprocessing", modelLoader, inputValidator, inputMapper,
inputProcessor, inputNormalizer);
tfm.createTaskForce("Tiling", inputTiler);
tfm.createTaskForce("Execution", modelExecutor);
tfm.createTaskForce("Postprocessing", outputTiler, outputProcessor);
taskManager = tfm;
}

protected InputValidator initInputValidator() {
return new DefaultInputValidator();
}

protected InputMapper initInputMapper() {
return new DefaultInputMapper();
}
Expand Down Expand Up @@ -405,14 +413,14 @@ public void run() {
protected void mainThread() throws OutOfMemoryError {

tryToInitialize();
taskManager.finalizeSetup();
solveModelSource();
initiateModelIfNeeded();
if(!networkAndInputCompatible) return;

updateCacheName();
savePreferences();

taskManager.finalizeSetup();

try {
tryToPrepareInputAndNetwork();
}
Expand Down Expand Up @@ -474,20 +482,29 @@ protected boolean doInputNormalization() {
return normalizeInput;
}

protected void tryToPrepareInputAndNetwork() throws MissingResourceException, FileNotFoundException {
protected void tryToPrepareInputAndNetwork() throws FileNotFoundException {

networkAndInputCompatible = false;

modelName = cacheName;

if(!networkInitialized)
initNetwork();

if(modelFileUrl.isEmpty()) {
throw new MissingResourceException("Trained model file / URL is missing or unavailable", this.getClass().getSimpleName(), modelFileUrl);
taskManager.logError("Trained model file / URL is missing or unavailable");
}
modelLoader.run(modelName, network, modelFileUrl, getInput());
if(modelLoader.isFailed()) return;
inputMapper.run(getInput(), network);

try {
inputValidator.run(getInput(), network);
}
catch(IncompatibleTypeException e) {
taskManager.logError(e.getMessage());
return;
}
inputMapper.run(getInput(), network);
networkAndInputCompatible = !inputMapper.isFailed();
}

private void savePreferences() {
Expand Down Expand Up @@ -529,8 +546,9 @@ protected List<AdvancedTiledView<FloatType>> tryToTileAndRunNetwork(

while (isOutOfMemory && canHandleOutOfMemory) {
try {
AxisType[] finalInputAxes = getAxesArray(getInput());
final List<AdvancedTiledView> tiledInput = inputTiler.run(
normalizedInput, getAxesArray(getInput()), tiling,
normalizedInput, finalInputAxes, tiling,
getTilingActions());
nTiles = tiling.getTilesNum();
if(tiledInput == null) return null;
Expand Down Expand Up @@ -569,24 +587,33 @@ protected AxisType[] getAxesArray(Dataset input, int size) {
}

public Tiling.TilingAction[] getTilingActions() {
return getTilingActionsForNode(network.getInputNode());
return getTilingActionsForNode(network.getInputNode(), network.getInputNode().getMappingIndices());
}

public static Tiling.TilingAction[] getTilingActionsForNode(ImageTensor node) {
public static Tiling.TilingAction[] getTilingActionsForNode(ImageTensor node, int[] mapping) {

if(node.getNodeShape().length == 0) return null;
Tiling.TilingAction[] actions = new Tiling.TilingAction[node.getNodeShape().length];
Arrays.fill(actions, Tiling.TilingAction.NO_TILING);
Integer indexFirst = node.getDatasetDimIndexByNodeIndex(0);
Integer indexLast = node.getDatasetDimIndexByNodeIndex(node.getNodeShape().length-1);
if(indexFirst != null) actions[indexFirst] = Tiling.TilingAction.TILE_WITHOUT_PADDING; // img batch dimension
if(indexLast != null) actions[indexLast] = Tiling.TilingAction.NO_TILING; // channel dimension
actions[0] = Tiling.TilingAction.TILE_WITHOUT_PADDING; // img batch dimension
for (int i = 1; i < node.getNodeShape().length-1; i++) {
if(node.getNodeShape()[i] < 0) {
Integer imgIndex = node.getDatasetDimIndexByNodeIndex(i);
if(imgIndex != null) actions[imgIndex] = Tiling.TilingAction.TILE_WITH_PADDING;
actions[i] = Tiling.TilingAction.TILE_WITH_PADDING;
}
}
return actions;
//permute
Tiling.TilingAction[] imgActions = new Tiling.TilingAction[node.getNodeShape().length];
for (int i = 0; i < actions.length; i++) {
imgActions[i] = actions[mapping[i]];
}
return imgActions;
}

private static int indexOf(Object[] array, Object item) {
for (int i = 0; i < array.length; i++) {
if(array[i].equals(item)) return i;
}
return -1;
}

public void setMapping(final AxisType[] mapping) {
Expand Down
39 changes: 39 additions & 0 deletions src/main/java/org/csbdeep/network/DefaultInputValidator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

package org.csbdeep.network;

import java.util.*;

import net.imagej.axis.AxisType;
import net.imglib2.exception.IncompatibleTypeException;
import org.csbdeep.network.model.Network;
import org.csbdeep.task.DefaultTask;

import net.imagej.Dataset;
import net.imagej.axis.Axes;

public class DefaultInputValidator extends DefaultTask implements InputValidator {

@Override
public void run(final Dataset input, final Network network) throws IncompatibleTypeException {

setStarted();

for (int i = 0; i < network.getInputNode().getNodeShape().length; i++) {
AxisType axis = network.getInputNode().getNodeAxis(i);
long size = network.getInputNode().getNodeShape()[i];
if(size > 1) {
if(!input.axis(axis).isPresent()) {
throw new IncompatibleTypeException(input, "Input should have axis of type " + axis.getLabel());
}
if(input.dimension(axis) != size) {
throw new IncompatibleTypeException(input, "Input axis of type " + axis.getLabel() + " should have size " + size);
}
}
}

setFinished();

}


}
14 changes: 14 additions & 0 deletions src/main/java/org/csbdeep/network/InputValidator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

package org.csbdeep.network;

import net.imglib2.exception.IncompatibleTypeException;
import org.csbdeep.network.model.Network;
import org.csbdeep.task.Task;

import net.imagej.Dataset;

public interface InputValidator extends Task {

void run(Dataset input, Network network) throws IncompatibleTypeException;

}
1 change: 1 addition & 0 deletions src/main/java/org/csbdeep/tiling/DefaultTiling.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ public AdvancedTiledView<T> preprocess(RandomAccessibleInterval<T> input, AxisTy
tilesNum = (int) arrayProduct(tiling);
long[] padding = getPadding(tiling);
computeBatching(input, tiling, tilingActions);
parent.log("Tiling actions: " + Arrays.toString(tilingActions));
parent.log("Dividing image into " + arrayProduct(tiling) + " tile(s)..");

RandomAccessibleInterval<T> expandedInput = expandToFitBatchSize(input,
Expand Down
70 changes: 45 additions & 25 deletions src/test/java/org/csbdeep/commands/GenericIsotropicNetworkTest.java
Original file line number Diff line number Diff line change
@@ -1,53 +1,73 @@

package org.csbdeep.commands;

import net.imagej.Dataset;
import net.imagej.axis.Axes;
import net.imagej.axis.AxisType;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.type.numeric.real.FloatType;
import org.csbdeep.CSBDeepTest;
import org.junit.Assert;
import org.junit.Test;
import org.scijava.command.CommandModule;
import org.scijava.module.Module;
import static junit.framework.TestCase.assertNotNull;
import static org.junit.Assert.*;

import java.io.File;
import java.net.URL;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

import static junit.framework.TestCase.assertNotNull;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import org.csbdeep.CSBDeepTest;
import org.junit.Assert;
import org.junit.Test;
import org.scijava.command.CommandModule;
import org.scijava.module.Module;

import net.imagej.Dataset;
import net.imagej.axis.Axes;
import net.imagej.axis.AxisType;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;

public class GenericIsotropicNetworkTest extends CSBDeepTest {

@Test
public void testGenericIsotropicNetwork() {
public void testCompatibleInput() {
launchImageJ();
testDataset(new FloatType(), new long[] { 10, 10, 10, 2 }, new AxisType[] {
Axes.X, Axes.Y, Axes.Z, Axes.CHANNEL });

}

@Test
public void testIncompatibleInput() {
launchImageJ();

URL networkUrl = this.getClass().getResource("isoNet/model.zip");
final Dataset input = createDataset(new FloatType(), new long[] { 10, 10, 10 }, new AxisType[] {
Axes.X, Axes.Y, Axes.Z, Axes.CHANNEL });
boolean noException = true;
try {
final Module module = ij.command().run(GenericIsotropicNetwork.class, false,
"modelFile", new File(networkUrl.getPath()), "input", input, "scale", 1.5).get();
assertNotNull(module);
assertNull(module.getOutput("output"));
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
}
}

public <T extends RealType<T> & NativeType<T>> void testDataset(final T type,
final long[] dims, final AxisType[] axes) {

URL networkUrl = this.getClass().getResource("isoNet/model.zip");
final Dataset input = createDataset(type, dims, axes);
final Future<CommandModule> future = ij.command().run(GenericIsotropicNetwork.class, false,
"modelFile", new File(networkUrl.getPath()), "input", input, "scale", 1.5);
assertNotEquals(null, future);
final Module module = ij.module().waitFor(future);
final Dataset output = (Dataset) module.getOutput("output");
try {
final Module module = ij.command().run(GenericIsotropicNetwork.class, false,
"modelFile", new File(networkUrl.getPath()), "input", input, "scale", 1.5).get();
assertNotEquals(null, module);
final Dataset output = (Dataset) module.getOutput("output");

Assert.assertNotNull(output);
testResultAxesAndSize(input, output);
Assert.assertNotNull(output);
testResultAxesAndSize(input, output);
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import java.net.URL;

import org.csbdeep.CSBDeepTest;
import org.junit.Ignore;
import org.junit.Test;
import org.scijava.module.Module;

Expand Down Expand Up @@ -44,7 +43,6 @@ public void test2DNetworkWith2DInputImage() {
}

@Test
@Ignore //TODO this crashes without catchable errors
public void test3DNetworkWith2DInputImage() {

launchImageJ();
Expand Down
10 changes: 0 additions & 10 deletions src/test/java/org/csbdeep/commands/NetTubulinTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,4 @@ public <T extends RealType<T> & NativeType<T>> void testDataset(final T type,
testResultAxesAndSize(input, output);
}

@Test
public void testUnfittingDataset()
{
launchImageJ();
final Dataset input = createDataset(new FloatType(), new long[] { 3, 4, 5 }, new AxisType[] {
Axes.Y, Axes.Z, Axes.TIME });
final Dataset output = runPlugin(NetTubulin.class, input);
assertNull(output);
}

}
3 changes: 2 additions & 1 deletion src/test/java/org/csbdeep/tasks/TilingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ public void testNetworkTiling() {
node.setNodeShape(nodeShape);
node.setMapping(new AxisType[]{Axes.TIME, Axes.Y, Axes.X, Axes.CHANNEL});

Tiling.TilingAction[] actions = GenericNetwork.getTilingActionsForNode(node);
Tiling.TilingAction[] actions = GenericNetwork.getTilingActionsForNode(
node, new int[]{2,1,0,3});

System.out.println(Arrays.toString(actions));

Expand Down

0 comments on commit 1aec0d6

Please sign in to comment.