diff --git a/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java b/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java index 898482d3..9f91a65a 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java @@ -141,7 +141,6 @@ public LinkedHashMap getInputTensorsTileSpecs() throws Illega if (this.inputTilesSpecs != null) return inputTilesSpecs; List inputTensors = findInputImageTensorSpec(); - List outputTensors = findOutputImageTensorSpec(); List> inputImages = inputTensors.stream() .filter(k -> this.inputValuesMap.get(k.getName()) != null) .map(k -> this.inputValuesMap.get(k.getName())).collect(Collectors.toList()); @@ -235,17 +234,17 @@ private LinkedHashMap computePatchSpecsForEveryTensor(List getOutputTensorsTileSpecs(List outTensors, - Map inSpecs) throws IllegalArgumentException { + public LinkedHashMap getOutputTensorsTileSpecs() throws IllegalArgumentException { + if (this.inputTilesSpecs == null) + throw new IllegalArgumentException("Please first calculate the tile specs for the input tensors. Call: " + + "getInputTensorsTileSpecs()"); if (this.outputTilesSpecs != null) return outputTilesSpecs; + List outTensors = findOutputImageTensorSpec(); LinkedHashMap patchInfoList = new LinkedHashMap(); for (int i = 0; i < outTensors.size(); i ++) { String refTensor = outTensors.get(i).getShape().getReferenceInput(); - PatchSpec refSpec = refTensor == null ? inSpecs.values().stream().findFirst().get() : inSpecs.get(refTensor); - if (refSpec == null) - throw new IllegalArgumentException("Please first calculate the tile specs for th einput tensors. Call: " - + "getInputTensorsTileSpecs()"); + PatchSpec refSpec = refTensor == null ? inputTilesSpecs.values().stream().findFirst().get() : inputTilesSpecs.get(refTensor); patchInfoList.put(outTensors.get(i).getName(), computePatchSpecsForOutputTensor(outTensors.get(i), refSpec)); } outputTilesSpecs = patchInfoList; @@ -336,29 +335,29 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval return PatchSpec.create(spec.getName(), tileSize, patchGridSize, paddingSize, rai.dimensionsAsLongArray()); } - private PatchSpec computePatchSpecsForOutputTensor(TensorSpec spec, PatchSpec refSpec) + private PatchSpec computePatchSpecsForOutputTensor(TensorSpec tensorSpec, PatchSpec refTilesSpec) { - int[] inputTileGrid = refSpec.getPatchGridSize(); + int[] inputTileGrid = refTilesSpec.getPatchGridSize(); // REgard that the input halo represents the output halo + offset // and must be divisible by 0.5. - int[][] paddingSize = refSpec.getPatchPaddingSize(); + int[][] paddingSize = refTilesSpec.getPatchPaddingSize(); int[] tileSize; long[] shapeLong; - if (spec.getShape().getReferenceInput() == null) { - tileSize = spec.getShape().getPatchRecomendedSize(); - shapeLong = LongStream.range(0, spec.getAxesOrder().length()) + if (tensorSpec.getShape().getReferenceInput() == null) { + tileSize = tensorSpec.getShape().getPatchRecomendedSize(); + shapeLong = LongStream.range(0, tensorSpec.getAxesOrder().length()) .map(i -> (tileSize[(int) i] - paddingSize[0][(int) i] - paddingSize[0][(int) i]) * inputTileGrid[(int) i]) .toArray(); } else { - tileSize = IntStream.range(0, spec.getAxesOrder().length()) - .map(i -> (int) (refSpec.getPatchInputSize()[i] * spec.getShape().getScale()[i] + 2 * spec.getShape().getOffset()[i])) + tileSize = IntStream.range(0, tensorSpec.getAxesOrder().length()) + .map(i -> (int) (refTilesSpec.getPatchInputSize()[i] * tensorSpec.getShape().getScale()[i] + 2 * tensorSpec.getShape().getOffset()[i])) .toArray(); - shapeLong = LongStream.range(0, spec.getAxesOrder().length()) - .map(i -> (int) (refSpec.getPatchInputSize()[(int) i] * spec.getShape().getScale()[(int) i] - + 2 * spec.getShape().getOffset()[(int) i])).toArray(); + shapeLong = LongStream.range(0, tensorSpec.getAxesOrder().length()) + .map(i -> (int) (refTilesSpec.getTensorDims()[(int) i] * tensorSpec.getShape().getScale()[(int) i] + + 2 * tensorSpec.getShape().getOffset()[(int) i])).toArray(); } - return PatchSpec.create(spec.getName(), tileSize, inputTileGrid, paddingSize, shapeLong); + return PatchSpec.create(tensorSpec.getName(), tileSize, inputTileGrid, paddingSize, shapeLong); } /** diff --git a/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java b/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java index c6ec50f7..5ae48f5e 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java @@ -62,9 +62,11 @@ private TileGrid() /** */ - public static TileGrid create(PatchSpec tileSpecs, long[] imageDims) + public static TileGrid create(PatchSpec tileSpecs) { TileGrid ps = new TileGrid(); + ps.tensorName = tileSpecs.getTensorName(); + long[] imageDims = tileSpecs.getTensorDims(); int[] gridSize = tileSpecs.getPatchGridSize(); ps.tileSize = tileSpecs.getPatchInputSize(); int tileCount = Arrays.stream(gridSize).reduce(1, (a, b) -> a * b); @@ -87,6 +89,30 @@ public static TileGrid create(PatchSpec tileSpecs, long[] imageDims) } return ps; } + + public String getTensorName() { + return tensorName; + } + + public int[] getTileSize() { + return this.tileSize; + } + + public int[] getRoiSize() { + return this.roiSize; + } + + public List getTilePostionsInImage() { + return this.tilePostionsInImage; + } + + public List getRoiPositionsInTile() { + return this.roiPositionsInTile; + } + + public List getRoiPostionsInImage() { + return this.roiPositionsInImage; + } @Override public String toString()