Skip to content

Commit

Permalink
start finding the path specs for the output images
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 25, 2023
1 parent ba55595 commit 192f121
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ private PatchGridCalculator(ModelDescriptor descriptor, List<Tensor<T>> tensorLi
throws IllegalArgumentException
{
for (TensorSpec tt : descriptor.getInputTensors()) {
if (tt.isImage() && Tensor.getTensorByNameFromList(tensorList, tt.getName()) != null)
if (tt.isImage() && Tensor.getTensorByNameFromList(tensorList, tt.getName()) == null)
throw new IllegalArgumentException("Model input tensor '" + tt.getName() + "' is specified in the rdf.yaml specs file "
+ "but cannot be found in the model inputs map provided.");
// TODO change isImage() by isTensor()
Expand Down Expand Up @@ -136,13 +136,15 @@ public LinkedHashMap<String, PatchSpec> get() throws IllegalArgumentException
if (psMap != null)
return psMap;
List<TensorSpec> inputTensors = findInputImageTensorSpec();
List<TensorSpec> outputTensors = findOutputImageTensorSpec();
List<Tensor<T>> inputImages = inputTensors.stream()
.filter(k -> this.inputValuesMap.get(k.getName()) != null)
.map(k -> this.inputValuesMap.get(k.getName())).collect(Collectors.toList());
if (inputImages.size() == 0)
throw new IllegalArgumentException("No inputs have been provided that match the "
+ "specified input tensors specified in the rdf.yaml file.");
LinkedHashMap<String, PatchSpec> specsMap = computePatchSpecsForEveryTensor(inputTensors, inputImages);
LinkedHashMap<String, PatchSpec> outSpecsMap = computePatchSpecsForEveryOutputTensor(outputTensors, specsMap);
// Check that the obtained patch specs are not going to cause errors
checkPatchSpecs(specsMap);
psMap = specsMap;
Expand Down Expand Up @@ -201,6 +203,16 @@ private List<TensorSpec> findInputImageTensorSpec()
return this.descriptor.getInputTensors().stream().filter(tr -> tr.isImage())
.collect(Collectors.toList());
}

/**
* Get the output tensors that correspond to images
* @return list of tensor specs corresponding to each of the output image tensors
*/
private List<TensorSpec> findOutputImageTensorSpec()
{
return this.descriptor.getOutputTensors().stream().filter(tr -> tr.isImage())
.collect(Collectors.toList());
}

/**
* Create list of patch specifications for every tensor aking into account the
Expand All @@ -218,6 +230,27 @@ private LinkedHashMap<String, PatchSpec> computePatchSpecsForEveryTensor(List<Te
patchInfoList.put(tensors.get(i).getName(), computePatchSpecs(tensors.get(i), images.get(i)));
return patchInfoList;
}

/**
* Create list of patch specifications for every tensor aking into account the
* corresponding image
* @param tensors
* the tensor information
* @param images
* the images corresponding to each tensor
* @return the LinkedHashMap where the key corresponds to the name of the tensor and the value is its
* patch specifications
*/
private LinkedHashMap<String, PatchSpec> computePatchSpecsForEveryOutputTensor(List<TensorSpec> outTensors,
Map<String, PatchSpec> inSpecs){
LinkedHashMap<String, PatchSpec> patchInfoList = new LinkedHashMap<String, PatchSpec>();
for (int i = 0; i < outTensors.size(); i ++) {
String refTensor = outTensors.get(i).getShape().getReferenceInput();
PatchSpec refSpec = refTensor == null ? inSpecs.values().stream().findFirst().get() : inSpecs.get(refTensor);
patchInfoList.put(outTensors.get(i).getName(), computePatchSpecsForOutputTensor(outTensors.get(i), refSpec));
}
return patchInfoList;
}

/**
* Compute the patch details needed to perform the tiling strategy. The calculations
Expand Down Expand Up @@ -309,7 +342,51 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval<T>
.map(i -> (int) Math.max( paddingSize[1][i],
inputPatchSize[i] - inputSequenceSize[i] - paddingSize[0][i])).toArray();

return PatchSpec.create(spec.getName(), inputPatchSize, patchGridSize, paddingSize);
return PatchSpec.create(spec.getName(), inputPatchSize, patchGridSize, paddingSize, rai.dimensionsAsLongArray());
}

private PatchSpec computePatchSpecsForOutputTensor(TensorSpec spec, PatchSpec refSpec)
{
String processingAxesOrder = spec.getAxesOrder();
int[] inputPatchSize = arrayToWantedAxesOrderAddOnes(tileSize, spec.getAxesOrder(),
processingAxesOrder);
int[][] paddingSize = new int[2][5];
// REgard that the input halo represents the output halo + offset
// and must be divisible by 0.5.
float[] halo = arrayToWantedAxesOrderAddZeros(spec.getHalo(),
spec.getAxesOrder(),
processingAxesOrder);
if (!descriptor.isPyramidal() && spec.getTiling()) {
// In the case that padding is asymmetrical, the left upper padding has the extra pixel
for (int i = 0; i < halo.length; i ++) {paddingSize[0][i] = (int) Math.ceil(halo[i]);}
// In the case that padding is asymmetrical, the right bottom padding has one pixel less
for (int i = 0; i < halo.length; i ++) {paddingSize[1][i] = (int) Math.floor(halo[i]);}

}
long[] shapeLong = rai.dimensionsAsLongArray();
int[] shapeInt = new int[shapeLong.length];
for (int i = 0; i < shapeInt.length; i ++) {shapeInt[i] = (int) shapeLong[i];}
int[] inputSequenceSize = arrayToWantedAxesOrderAddOnes(shapeInt,
spec.getAxesOrder(),
processingAxesOrder);
int[] patchGridSize = new int[] {1, 1, 1, 1, 1};
if (descriptor.isTilingAllowed()) {
patchGridSize = IntStream.range(0, inputPatchSize.length)
.map(i -> (int) Math.ceil((double) inputSequenceSize[i] / ((double) inputPatchSize[i] - halo[i] * 2)))
.toArray();
}
// For the cases when the patch is bigger than the image size, share the
// padding between both sides of the image
paddingSize[0] = IntStream.range(0, inputPatchSize.length)
.map(i ->
(int) Math.max(paddingSize[0][i],
Math.ceil( (double) (inputPatchSize[i] - inputSequenceSize[i]) / 2))
).toArray();
paddingSize[1] = IntStream.range(0, inputPatchSize.length)
.map(i -> (int) Math.max( paddingSize[1][i],
inputPatchSize[i] - inputSequenceSize[i] - paddingSize[0][i])).toArray();

return PatchSpec.create(spec.getName(), inputPatchSize, patchGridSize, paddingSize, rai.dimensionsAsLongArray());
}

/**
Expand Down
48 changes: 47 additions & 1 deletion src/main/java/io/bioimage/modelrunner/tiling/PatchSpec.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
* Patch specification providing information about the patch size and patch grid size.
Expand All @@ -29,6 +30,10 @@
*/
public class PatchSpec
{
/**
* Size of the tensor that is going to be tiled
*/
private long[] tensorDims;
/**
* Size of the input patch. Following "xyczb" axes order
*/
Expand Down Expand Up @@ -63,13 +68,14 @@ public class PatchSpec
* @return The create patch specification.
*/
public static PatchSpec create(String tensorName, int[] patchInputSize, int[] patchGridSize,
int[][] patchPaddingSize)
int[][] patchPaddingSize, long[] tensorDims)
{
PatchSpec ps = new PatchSpec();
ps.patchInputSize = patchInputSize;
ps.patchGridSize = patchGridSize;
ps.patchPaddingSize = patchPaddingSize;
ps.tensorName = tensorName;
ps.tensorDims = tensorDims;
return ps;
}

Expand All @@ -78,6 +84,10 @@ private PatchSpec()
}

/**
* TODO this method should be per image, not in total??
* TODO this method should be per image, not in total??
* TODO this method should be per image, not in total??
* TODO this method should be per image, not in total??
* Obtain the number of patches in each axes for a list of input patch specs.
* When tiling is allowed, only one patch grid is permitted. If among the tensors
* there are one or more that do not allow tiling, then two patch sizes are allowed,
Expand All @@ -101,6 +111,34 @@ public static int[] getGridSize(List<PatchSpec> patches) {
return grid;
}

/**
* TODO this method should be per image, not in total??
* TODO this method should be per image, not in total??
* TODO this method should be per image, not in total??
* TODO this method should be per image, not in total??
* Obtain the number of patches in each axes for a list of input patch specs.
* When tiling is allowed, only one patch grid is permitted. If among the tensors
* there are one or more that do not allow tiling, then two patch sizes are allowed,
* the one for the tensors that allow tiling and the one for the ones that not (that will
* just be 1s in every axes).
* In the case there exist tensors that allow tiling, the grid size for those will be the
* one returned
* @param patches
* map containing tiling specs per tensor
* @return the number of patches in each axes
*/
public static int[] getGridSize(Map<String, PatchSpec> patches) {
// The minimum possible grid is just one patch in every direction. This is the
// grid if no tiling is allowed
int[] grid = new int[]{1, 1, 1, 1, 1};
// If there is any different grid, that will be the absolute one
for (PatchSpec pp : patches.values()) {
if (!PatchGridCalculator.compareTwoArrays(grid, pp.getPatchGridSize()))
return pp.getPatchGridSize();
}
return grid;
}

/**
* Return the PatchSpec corresponding to the tensor called by the name defined
* @param specs
Expand All @@ -120,6 +158,14 @@ public static PatchSpec getPatchSpecFromListByName(List<PatchSpec> specs, String
public String getTensorName() {
return tensorName;
}

/**
* The dimensions of the tensor
* @return the dimensions of the tensor that is going to be tiled
*/
public long[] getTensorDims() {
return tensorDims;
}

/**
* @return Input patch size. The patch taken from the input sequence including the halo.
Expand Down

0 comments on commit 192f121

Please sign in to comment.