Skip to content

Commit

Permalink
improve the specification of tiles
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 25, 2023
1 parent f62c4ad commit 0dde1d1
Showing 1 changed file with 40 additions and 61 deletions.
101 changes: 40 additions & 61 deletions src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Map.Entry;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
Expand All @@ -50,7 +51,11 @@ public class PatchGridCalculator <T extends RealType<T> & NativeType<T>>
/**
* MAp containing the {@link PatchSpec} for each of the tensors defined in the rdf.yaml specs file
*/
private LinkedHashMap<String, PatchSpec> psMap;
private LinkedHashMap<String, PatchSpec> inputTilesSpecs;
/**
* MAp containing the {@link PatchSpec} for each of the tensors defined in the rdf.yaml specs file
*/
private LinkedHashMap<String, PatchSpec> outputTilesSpecs;

/**
* Class to calculate the patch specifications given a series of inputs
Expand Down Expand Up @@ -131,10 +136,10 @@ PatchGridCalculator<T> build(ModelDescriptor model, List<Tensor<T>> inputImagesL
* @throws IllegalArgumentException if one tensor that allows tiling needs more patches
* in any given axis than the others
*/
public LinkedHashMap<String, PatchSpec> get() throws IllegalArgumentException
public LinkedHashMap<String, PatchSpec> getInputTensorsTileSpecs() throws IllegalArgumentException
{
if (psMap != null)
return psMap;
if (this.inputTilesSpecs != null)
return inputTilesSpecs;
List<TensorSpec> inputTensors = findInputImageTensorSpec();
List<TensorSpec> outputTensors = findOutputImageTensorSpec();
List<Tensor<T>> inputImages = inputTensors.stream()
Expand All @@ -144,11 +149,10 @@ public LinkedHashMap<String, PatchSpec> get() throws IllegalArgumentException
throw new IllegalArgumentException("No inputs have been provided that match the "
+ "specified input tensors specified in the rdf.yaml file.");
LinkedHashMap<String, PatchSpec> specsMap = computePatchSpecsForEveryTensor(inputTensors, inputImages);
LinkedHashMap<String, PatchSpec> outSpecsMap = computePatchSpecsForEveryOutputTensor(outputTensors, specsMap);
// Check that the obtained patch specs are not going to cause errors
checkPatchSpecs(specsMap);
psMap = specsMap;
return psMap;
inputTilesSpecs = specsMap;
return inputTilesSpecs;
}

/**
Expand Down Expand Up @@ -231,25 +235,21 @@ private LinkedHashMap<String, PatchSpec> computePatchSpecsForEveryTensor(List<Te
return patchInfoList;
}

/**
* Create list of patch specifications for every tensor aking into account the
* corresponding image
* @param tensors
* the tensor information
* @param images
* the images corresponding to each tensor
* @return the LinkedHashMap where the key corresponds to the name of the tensor and the value is its
* patch specifications
*/
private LinkedHashMap<String, PatchSpec> computePatchSpecsForEveryOutputTensor(List<TensorSpec> outTensors,
Map<String, PatchSpec> inSpecs){
public LinkedHashMap<String, PatchSpec> getOutputTensorsTileSpecs(List<TensorSpec> outTensors,
Map<String, PatchSpec> inSpecs) throws IllegalArgumentException {
if (this.outputTilesSpecs != null)
return outputTilesSpecs;
LinkedHashMap<String, PatchSpec> patchInfoList = new LinkedHashMap<String, PatchSpec>();
for (int i = 0; i < outTensors.size(); i ++) {
String refTensor = outTensors.get(i).getShape().getReferenceInput();
PatchSpec refSpec = refTensor == null ? inSpecs.values().stream().findFirst().get() : inSpecs.get(refTensor);
if (refSpec == null)
throw new IllegalArgumentException("Please first calculate the tile specs for th einput tensors. Call: "
+ "getInputTensorsTileSpecs()");
patchInfoList.put(outTensors.get(i).getName(), computePatchSpecsForOutputTensor(outTensors.get(i), refSpec));
}
return patchInfoList;
outputTilesSpecs = patchInfoList;
return outputTilesSpecs;
}

/**
Expand Down Expand Up @@ -315,71 +315,50 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval<T>

}
long[] shapeLong = rai.dimensionsAsLongArray();
int[] shapeInt = new int[shapeLong.length];
for (int i = 0; i < shapeInt.length; i ++) {shapeInt[i] = (int) shapeLong[i];}
int[] patchGridSize = new int[shapeLong.length];
for (int i = 0; i < patchGridSize.length; i ++) patchGridSize[i] = 1;
if (descriptor.isTilingAllowed()) {
patchGridSize = IntStream.range(0, tileSize.length)
.map(i -> (int) Math.ceil((double) shapeInt[i] / ((double) tileSize[i] - halo[i] * 2)))
.map(i -> (int) Math.ceil((double) shapeLong[i] / ((double) tileSize[i] - halo[i] * 2)))
.toArray();
}
// For the cases when the patch is bigger than the image size, share the
// padding between both sides of the image
paddingSize[0] = IntStream.range(0, tileSize.length)
.map(i ->
(int) Math.max(paddingSize[0][i],
Math.ceil( (double) (tileSize[i] - shapeInt[i]) / 2))
Math.ceil( (double) (tileSize[i] - shapeLong[i]) / 2))
).toArray();
paddingSize[1] = IntStream.range(0, tileSize.length)
.map(i -> (int) Math.max( paddingSize[1][i],
tileSize[i] - shapeInt[i] - paddingSize[0][i])).toArray();
tileSize[i] - shapeLong[i] - paddingSize[0][i])).toArray();

return PatchSpec.create(spec.getName(), tileSize, patchGridSize, paddingSize, rai.dimensionsAsLongArray());
}

private PatchSpec computePatchSpecsForOutputTensor(TensorSpec spec, PatchSpec refSpec)
{
String processingAxesOrder = spec.getAxesOrder();
int[] inputPatchSize = arrayToWantedAxesOrderAddOnes(tileSize, spec.getAxesOrder(),
processingAxesOrder);
int[][] paddingSize = new int[2][5];
int[] inputTileGrid = refSpec.getPatchGridSize();
// REgard that the input halo represents the output halo + offset
// and must be divisible by 0.5.
float[] halo = arrayToWantedAxesOrderAddZeros(spec.getHalo(),
spec.getAxesOrder(),
processingAxesOrder);
if (!descriptor.isPyramidal() && spec.getTiling()) {
// In the case that padding is asymmetrical, the left upper padding has the extra pixel
for (int i = 0; i < halo.length; i ++) {paddingSize[0][i] = (int) Math.ceil(halo[i]);}
// In the case that padding is asymmetrical, the right bottom padding has one pixel less
for (int i = 0; i < halo.length; i ++) {paddingSize[1][i] = (int) Math.floor(halo[i]);}

int[][] paddingSize = refSpec.getPatchPaddingSize();
int[] tileSize;
long[] shapeLong;
if (spec.getShape().getReferenceInput() == null) {
tileSize = spec.getShape().getPatchRecomendedSize();
shapeLong = LongStream.range(0, spec.getAxesOrder().length())
.map(i -> (tileSize[(int) i] - paddingSize[0][(int) i] - paddingSize[0][(int) i]) * inputTileGrid[(int) i])
.toArray();
} else {
tileSize = IntStream.range(0, spec.getAxesOrder().length())
.map(i -> (int) (refSpec.getPatchInputSize()[i] * spec.getShape().getScale()[i] + 2 * spec.getShape().getOffset()[i]))
.toArray();
shapeLong = LongStream.range(0, spec.getAxesOrder().length())
.map(i -> (int) (refSpec.getPatchInputSize()[(int) i] * spec.getShape().getScale()[(int) i]
+ 2 * spec.getShape().getOffset()[(int) i])).toArray();
}
long[] shapeLong = rai.dimensionsAsLongArray();
int[] shapeInt = new int[shapeLong.length];
for (int i = 0; i < shapeInt.length; i ++) {shapeInt[i] = (int) shapeLong[i];}
int[] inputSequenceSize = arrayToWantedAxesOrderAddOnes(shapeInt,
spec.getAxesOrder(),
processingAxesOrder);
int[] patchGridSize = new int[] {1, 1, 1, 1, 1};
if (descriptor.isTilingAllowed()) {
patchGridSize = IntStream.range(0, inputPatchSize.length)
.map(i -> (int) Math.ceil((double) inputSequenceSize[i] / ((double) inputPatchSize[i] - halo[i] * 2)))
.toArray();
}
// For the cases when the patch is bigger than the image size, share the
// padding between both sides of the image
paddingSize[0] = IntStream.range(0, inputPatchSize.length)
.map(i ->
(int) Math.max(paddingSize[0][i],
Math.ceil( (double) (inputPatchSize[i] - inputSequenceSize[i]) / 2))
).toArray();
paddingSize[1] = IntStream.range(0, inputPatchSize.length)
.map(i -> (int) Math.max( paddingSize[1][i],
inputPatchSize[i] - inputSequenceSize[i] - paddingSize[0][i])).toArray();

return PatchSpec.create(spec.getName(), inputPatchSize, patchGridSize, paddingSize, rai.dimensionsAsLongArray());
return PatchSpec.create(spec.getName(), tileSize, inputTileGrid, paddingSize, shapeLong);
}

/**
Expand Down

0 comments on commit 0dde1d1

Please sign in to comment.