Skip to content

Commit

Permalink
allow creatig empty shm segments rjust by reserving memory
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 13, 2023
1 parent 7296af1 commit d9e17b0
Showing 1 changed file with 69 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ public class SharedMemoryArrayWin implements SharedMemoryArray
/**
* Pointer referencing the shared memory byte array
*/
private Pointer pSharedMemory;
private Pointer mappedPointer;
/**
* Pointer referencing the shared memory byte array
*/
private Pointer writePointer;
/**
* Name of the file containing the shared memory segment. In Unix based systems consits of "/" + file_name.
* In Linux the shared memory segments can be inspected at /dev/shm.
Expand Down Expand Up @@ -97,6 +101,8 @@ public class SharedMemoryArrayWin implements SharedMemoryArray
* of bytes corresponding to the values of the array, no header
*/
private boolean isNumpyFormat = false;
private static final int SEC_RESERVE = 0x4000000;
private static final int DEFAULT_RESERVED_MEMORY = 1024 * 1024 * 1024 * 2;

/**
* Create a shared memory segment with the wanted size, where an object of a certain datatype and
Expand Down Expand Up @@ -136,10 +142,17 @@ private SharedMemoryArrayWin(String name, int size, String dtype, long[] shape)
this.originalDataType = dtype;
this.originalDims = shape;
this.size = size;
int flag = WinNT.PAGE_READWRITE;
boolean write = true;
if (size < 1) {
flag = WinNT.PAGE_READWRITE | SEC_RESERVE;
size = DEFAULT_RESERVED_MEMORY;
write = false;
}
hMapFile = Kernel32.INSTANCE.CreateFileMapping(
WinBase.INVALID_HANDLE_VALUE,
null,
WinNT.PAGE_READWRITE,
flag,
0,
size,
memoryName
Expand All @@ -151,19 +164,29 @@ private SharedMemoryArrayWin(String name, int size, String dtype, long[] shape)
}

// Map the shared memory
pSharedMemory = Kernel32.INSTANCE.MapViewOfFile(
mappedPointer = Kernel32.INSTANCE.MapViewOfFile(
hMapFile,
WinNT.FILE_MAP_WRITE,
0,
0,
size
);

if (pSharedMemory == null) {
if (mappedPointer == null) {
Kernel32.INSTANCE.CloseHandle(hMapFile);
throw new RuntimeException("Error creating shared memory array. MapViewOfFile failed: "
+ "" + Kernel32.INSTANCE.GetLastError());
}
if (write) {
writePointer = Kernel32.INSTANCE.VirtualAllocEx(Kernel32.INSTANCE.GetCurrentProcess(),
mappedPointer,
new BaseTSD.SIZE_T(size), WinNT.MEM_COMMIT, WinNT.PAGE_READWRITE);
if (writePointer == null) {
close();
throw new RuntimeException("Error committing to the shared memory pages. Errno: "
+ "" + Kernel32.INSTANCE.GetLastError());
}
}
}

/**
Expand All @@ -184,7 +207,7 @@ public String getNameForPython() {
* {@inheritDoc}
*/
public Pointer getPointer() {
return this.pSharedMemory;
return this.writePointer;
}

/**
Expand Down Expand Up @@ -238,6 +261,9 @@ protected static <T extends RealType<T> & NativeType<T>> SharedMemoryArrayWin bu
if (!name.startsWith("Local" + File.separator) && !name.startsWith("Global" + File.separator))
name = "Local" + File.separator+ name;
SharedMemoryArrayWin shma = null;
if (rai == null) {
shma = new SharedMemoryArrayWin(name, -1, null, null);
}
if (Util.getTypeFromInterval(rai) instanceof ByteType) {
int size = 1;
for (long i : rai.dimensionsAsLongArray()) {size *= i;}
Expand Down Expand Up @@ -351,7 +377,7 @@ protected static <T extends RealType<T> & NativeType<T>> SharedMemoryArrayWin bu
*/
private void addByteArray(byte[] arr) {
for (int i = 0; i < arr.length; i ++) {
this.pSharedMemory.setByte(i, arr[i]);
this.writePointer.setByte(i, arr[i]);
}
}

Expand All @@ -362,7 +388,7 @@ private void buildInt8(RandomAccessibleInterval<ByteType> tensor)
long i = 0;
while (cursor.hasNext()) {
cursor.fwd();
this.pSharedMemory.setByte(i ++, cursor.get().get());
this.writePointer.setByte(i ++, cursor.get().get());
}
}

Expand All @@ -373,7 +399,7 @@ private void buildUint8(RandomAccessibleInterval<UnsignedByteType> tensor)
long i = 0;
while (cursor.hasNext()) {
cursor.fwd();
this.pSharedMemory.setByte(i ++, cursor.get().getByte());
this.writePointer.setByte(i ++, cursor.get().getByte());
}
}

Expand All @@ -384,7 +410,7 @@ private void buildInt16(RandomAccessibleInterval<ShortType> tensor)
long i = 0;
while (cursor.hasNext()) {
cursor.fwd();
this.pSharedMemory.setShort((i * Short.BYTES), cursor.get().get());
this.writePointer.setShort((i * Short.BYTES), cursor.get().get());
i ++;
}
}
Expand All @@ -396,7 +422,7 @@ private void buildUint16(RandomAccessibleInterval<UnsignedShortType> tensor)
long i = 0;
while (cursor.hasNext()) {
cursor.fwd();
this.pSharedMemory.setShort((i * Short.BYTES), cursor.get().getShort());
this.writePointer.setShort((i * Short.BYTES), cursor.get().getShort());
i ++;
}
}
Expand All @@ -408,7 +434,7 @@ private void buildInt32(RandomAccessibleInterval<IntType> tensor)
long i = 0;
while (cursor.hasNext()) {
cursor.fwd();
this.pSharedMemory.setInt((i * Integer.BYTES), cursor.get().get());
this.writePointer.setInt((i * Integer.BYTES), cursor.get().get());
i ++;
}
}
Expand All @@ -420,7 +446,7 @@ private void buildUint32(RandomAccessibleInterval<UnsignedIntType> tensor)
long i = 0;
while (cursor.hasNext()) {
cursor.fwd();
this.pSharedMemory.setInt((i * Integer.BYTES), cursor.get().getInt());
this.writePointer.setInt((i * Integer.BYTES), cursor.get().getInt());
i ++;
}
}
Expand All @@ -432,7 +458,7 @@ private void buildInt64(RandomAccessibleInterval<LongType> tensor)
long i = 0;
while (cursor.hasNext()) {
cursor.fwd();
this.pSharedMemory.setLong((i * Long.BYTES), cursor.get().get());
this.writePointer.setLong((i * Long.BYTES), cursor.get().get());
i ++;
}
}
Expand All @@ -444,7 +470,7 @@ private void buildFloat32(RandomAccessibleInterval<FloatType> tensor)
long i = 0;
while (cursor.hasNext()) {
cursor.fwd();
this.pSharedMemory.setFloat((i * Float.BYTES), cursor.get().get());
this.writePointer.setFloat((i * Float.BYTES), cursor.get().get());
i ++;
}
}
Expand All @@ -456,7 +482,7 @@ private void buildFloat64(RandomAccessibleInterval<DoubleType> tensor)
long i = 0;
while (cursor.hasNext()) {
cursor.fwd();
this.pSharedMemory.setDouble((i * Double.BYTES), cursor.get().get());
this.writePointer.setDouble((i * Double.BYTES), cursor.get().get());
i ++;
}
}
Expand All @@ -469,19 +495,42 @@ private void buildFloat64(RandomAccessibleInterval<DoubleType> tensor)
*/
public void close() {
if (unlinked) return;
Kernel32.INSTANCE.UnmapViewOfFile(pSharedMemory);
Kernel32.INSTANCE.UnmapViewOfFile(mappedPointer);
Kernel32.INSTANCE.CloseHandle(hMapFile);
unlinked = true;
}

public static void main(String[] args) {
String memoryName = "Local" + File.separator + "wnsm_52f561c9";
WinNT.HANDLE hMapFile = Kernel32.INSTANCE.OpenFileMapping(
WinNT.FILE_MAP_READ, false, memoryName
);

WinNT.HANDLE hMapFile = Kernel32.INSTANCE.CreateFileMapping(
WinBase.INVALID_HANDLE_VALUE,
null,
WinNT.PAGE_READWRITE | SEC_RESERVE,
0,
1024 * 1024 * 1024 * 3,
memoryName
);
if (hMapFile == null) {
throw new RuntimeException("OpenFileMapping failed with error: " + Kernel32.INSTANCE.GetLastError());
}

// Map the shared memory
Pointer dpSharedMemory = Kernel32.INSTANCE.MapViewOfFile(
hMapFile,
WinNT.FILE_MAP_WRITE,
0,
0,
1024 * 1024 * 1024 * 3
);
Kernel32.INSTANCE.UnmapViewOfFile(dpSharedMemory);
Kernel32.INSTANCE.CloseHandle(hMapFile);
Pointer aa = Kernel32.INSTANCE.VirtualAllocEx(Kernel32.INSTANCE.GetCurrentProcess(),
dpSharedMemory,
new BaseTSD.SIZE_T(1024 * 1024 * 2000), WinNT.MEM_COMMIT, WinNT.PAGE_READWRITE);
for (int i = 0; i < 1024*1024*2000; i ++)
aa.setByte(i, (byte) i);
if (true) return;

// Map the shared memory object into the current process's address space
Pointer pSharedMemory = Kernel32.INSTANCE.MapViewOfFile(
Expand Down Expand Up @@ -805,7 +854,7 @@ RandomAccessibleInterval<T> buildFromSharedMemoryBlock(Pointer pSharedMemory, lo
* {@inheritDoc}
*/
public <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> getSharedRAI() {
return buildFromSharedMemoryBlock(pSharedMemory, this.originalDims, false, this.originalDataType);
return buildFromSharedMemoryBlock(writePointer, this.originalDims, false, this.originalDataType);
}

@Override
Expand Down

0 comments on commit d9e17b0

Please sign in to comment.