diff --git a/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java b/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java index f36e2863..898482d3 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java @@ -27,6 +27,7 @@ import java.util.Map.Entry; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.LongStream; import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; import io.bioimage.modelrunner.bioimageio.description.TensorSpec; @@ -50,7 +51,11 @@ public class PatchGridCalculator & NativeType> /** * MAp containing the {@link PatchSpec} for each of the tensors defined in the rdf.yaml specs file */ - private LinkedHashMap psMap; + private LinkedHashMap inputTilesSpecs; + /** + * MAp containing the {@link PatchSpec} for each of the tensors defined in the rdf.yaml specs file + */ + private LinkedHashMap outputTilesSpecs; /** * Class to calculate the patch specifications given a series of inputs @@ -131,10 +136,10 @@ PatchGridCalculator build(ModelDescriptor model, List> inputImagesL * @throws IllegalArgumentException if one tensor that allows tiling needs more patches * in any given axis than the others */ - public LinkedHashMap get() throws IllegalArgumentException + public LinkedHashMap getInputTensorsTileSpecs() throws IllegalArgumentException { - if (psMap != null) - return psMap; + if (this.inputTilesSpecs != null) + return inputTilesSpecs; List inputTensors = findInputImageTensorSpec(); List outputTensors = findOutputImageTensorSpec(); List> inputImages = inputTensors.stream() @@ -144,11 +149,10 @@ public LinkedHashMap get() throws IllegalArgumentException throw new IllegalArgumentException("No inputs have been provided that match the " + "specified input tensors specified in the rdf.yaml file."); LinkedHashMap specsMap = computePatchSpecsForEveryTensor(inputTensors, inputImages); - LinkedHashMap outSpecsMap = computePatchSpecsForEveryOutputTensor(outputTensors, specsMap); // Check that the obtained patch specs are not going to cause errors checkPatchSpecs(specsMap); - psMap = specsMap; - return psMap; + inputTilesSpecs = specsMap; + return inputTilesSpecs; } /** @@ -231,25 +235,21 @@ private LinkedHashMap computePatchSpecsForEveryTensor(List computePatchSpecsForEveryOutputTensor(List outTensors, - Map inSpecs){ + public LinkedHashMap getOutputTensorsTileSpecs(List outTensors, + Map inSpecs) throws IllegalArgumentException { + if (this.outputTilesSpecs != null) + return outputTilesSpecs; LinkedHashMap patchInfoList = new LinkedHashMap(); for (int i = 0; i < outTensors.size(); i ++) { String refTensor = outTensors.get(i).getShape().getReferenceInput(); PatchSpec refSpec = refTensor == null ? inSpecs.values().stream().findFirst().get() : inSpecs.get(refTensor); + if (refSpec == null) + throw new IllegalArgumentException("Please first calculate the tile specs for th einput tensors. Call: " + + "getInputTensorsTileSpecs()"); patchInfoList.put(outTensors.get(i).getName(), computePatchSpecsForOutputTensor(outTensors.get(i), refSpec)); } - return patchInfoList; + outputTilesSpecs = patchInfoList; + return outputTilesSpecs; } /** @@ -315,13 +315,11 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval } long[] shapeLong = rai.dimensionsAsLongArray(); - int[] shapeInt = new int[shapeLong.length]; - for (int i = 0; i < shapeInt.length; i ++) {shapeInt[i] = (int) shapeLong[i];} int[] patchGridSize = new int[shapeLong.length]; for (int i = 0; i < patchGridSize.length; i ++) patchGridSize[i] = 1; if (descriptor.isTilingAllowed()) { patchGridSize = IntStream.range(0, tileSize.length) - .map(i -> (int) Math.ceil((double) shapeInt[i] / ((double) tileSize[i] - halo[i] * 2))) + .map(i -> (int) Math.ceil((double) shapeLong[i] / ((double) tileSize[i] - halo[i] * 2))) .toArray(); } // For the cases when the patch is bigger than the image size, share the @@ -329,57 +327,38 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval paddingSize[0] = IntStream.range(0, tileSize.length) .map(i -> (int) Math.max(paddingSize[0][i], - Math.ceil( (double) (tileSize[i] - shapeInt[i]) / 2)) + Math.ceil( (double) (tileSize[i] - shapeLong[i]) / 2)) ).toArray(); paddingSize[1] = IntStream.range(0, tileSize.length) .map(i -> (int) Math.max( paddingSize[1][i], - tileSize[i] - shapeInt[i] - paddingSize[0][i])).toArray(); + tileSize[i] - shapeLong[i] - paddingSize[0][i])).toArray(); return PatchSpec.create(spec.getName(), tileSize, patchGridSize, paddingSize, rai.dimensionsAsLongArray()); } private PatchSpec computePatchSpecsForOutputTensor(TensorSpec spec, PatchSpec refSpec) { - String processingAxesOrder = spec.getAxesOrder(); - int[] inputPatchSize = arrayToWantedAxesOrderAddOnes(tileSize, spec.getAxesOrder(), - processingAxesOrder); - int[][] paddingSize = new int[2][5]; + int[] inputTileGrid = refSpec.getPatchGridSize(); // REgard that the input halo represents the output halo + offset // and must be divisible by 0.5. - float[] halo = arrayToWantedAxesOrderAddZeros(spec.getHalo(), - spec.getAxesOrder(), - processingAxesOrder); - if (!descriptor.isPyramidal() && spec.getTiling()) { - // In the case that padding is asymmetrical, the left upper padding has the extra pixel - for (int i = 0; i < halo.length; i ++) {paddingSize[0][i] = (int) Math.ceil(halo[i]);} - // In the case that padding is asymmetrical, the right bottom padding has one pixel less - for (int i = 0; i < halo.length; i ++) {paddingSize[1][i] = (int) Math.floor(halo[i]);} - + int[][] paddingSize = refSpec.getPatchPaddingSize(); + int[] tileSize; + long[] shapeLong; + if (spec.getShape().getReferenceInput() == null) { + tileSize = spec.getShape().getPatchRecomendedSize(); + shapeLong = LongStream.range(0, spec.getAxesOrder().length()) + .map(i -> (tileSize[(int) i] - paddingSize[0][(int) i] - paddingSize[0][(int) i]) * inputTileGrid[(int) i]) + .toArray(); + } else { + tileSize = IntStream.range(0, spec.getAxesOrder().length()) + .map(i -> (int) (refSpec.getPatchInputSize()[i] * spec.getShape().getScale()[i] + 2 * spec.getShape().getOffset()[i])) + .toArray(); + shapeLong = LongStream.range(0, spec.getAxesOrder().length()) + .map(i -> (int) (refSpec.getPatchInputSize()[(int) i] * spec.getShape().getScale()[(int) i] + + 2 * spec.getShape().getOffset()[(int) i])).toArray(); } - long[] shapeLong = rai.dimensionsAsLongArray(); - int[] shapeInt = new int[shapeLong.length]; - for (int i = 0; i < shapeInt.length; i ++) {shapeInt[i] = (int) shapeLong[i];} - int[] inputSequenceSize = arrayToWantedAxesOrderAddOnes(shapeInt, - spec.getAxesOrder(), - processingAxesOrder); - int[] patchGridSize = new int[] {1, 1, 1, 1, 1}; - if (descriptor.isTilingAllowed()) { - patchGridSize = IntStream.range(0, inputPatchSize.length) - .map(i -> (int) Math.ceil((double) inputSequenceSize[i] / ((double) inputPatchSize[i] - halo[i] * 2))) - .toArray(); - } - // For the cases when the patch is bigger than the image size, share the - // padding between both sides of the image - paddingSize[0] = IntStream.range(0, inputPatchSize.length) - .map(i -> - (int) Math.max(paddingSize[0][i], - Math.ceil( (double) (inputPatchSize[i] - inputSequenceSize[i]) / 2)) - ).toArray(); - paddingSize[1] = IntStream.range(0, inputPatchSize.length) - .map(i -> (int) Math.max( paddingSize[1][i], - inputPatchSize[i] - inputSequenceSize[i] - paddingSize[0][i])).toArray(); - return PatchSpec.create(spec.getName(), inputPatchSize, patchGridSize, paddingSize, rai.dimensionsAsLongArray()); + return PatchSpec.create(spec.getName(), tileSize, inputTileGrid, paddingSize, shapeLong); } /**