Skip to content

Commit

Permalink
corerct ultramegabug that was avoiding copying from tensor to shm
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 23, 2024
1 parent f33dc43 commit a937f90
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -253,15 +253,18 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
IValue output = model.forward(inputsVector);
TensorVector outputTensorVector = null;
if (output.isTensorList()) {
System.out.println("entered 1");
outputTensorVector = output.toTensorVector();
} else {
System.out.println("entered 2");
outputTensorVector = new TensorVector();
outputTensorVector.put(output.toTensor());
}

// Fill the agnostic output tensors list with data from the inference result
int c = 0;
for (String ee : outputs) {
System.out.println(ee);
Map<String, Object> decoded = Types.decode(ee);
ShmBuilder.build(outputTensorVector.get(c ++), (String) decoded.get(MEM_NAME_KEY));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.bioimage.modelrunner.utils.CommonUtils;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;

import org.bytedeco.pytorch.Tensor;
Expand Down Expand Up @@ -106,7 +107,11 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
long flatSize = 1;
for (long l : arrayShape) {flatSize *= l;}
ByteBuffer byteBuffer = ByteBuffer.allocate((int) (flatSize * Float.BYTES));
tensor.data_ptr_float().get(byteBuffer.asFloatBuffer().array());
shma.getDataBufferNoHeader().put(byteBuffer);
if (PlatformDetection.isWindows()) shma.close();
}

Expand Down

0 comments on commit a937f90

Please sign in to comment.