From 4c84784cda3edc321561c06b5bfa6e05fcc2a380 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sat, 16 Nov 2024 20:51:46 -0500 Subject: [PATCH] Use mtl resource shared to avoid copy --- candle-core/src/metal_backend/device.rs | 10 +++++----- candle-core/src/metal_backend/mod.rs | 26 ++++++------------------- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index b100db63d4..9c5ba5bb5e 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -223,7 +223,7 @@ impl MetalDevice { name: &str, ) -> Result> { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name) + self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, name) } /// Creates a new buffer (not necessarily zeroed). @@ -232,7 +232,7 @@ impl MetalDevice { /// synchronization when the CPU memory is modified /// Used as a bridge to gather data back from the GPU pub fn new_buffer_managed(&self, size: NSUInteger) -> Result> { - self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") + self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, "managed") } /// Creates a new buffer from data. @@ -245,12 +245,12 @@ impl MetalDevice { let new_buffer = self.device.new_buffer_with_data( data.as_ptr() as *const c_void, size, - MTLResourceOptions::StorageModeManaged, + MTLResourceOptions::StorageModeShared, ); let mut buffers = self.buffers.write().map_err(MetalError::from)?; let subbuffers = buffers - .entry((size, MTLResourceOptions::StorageModeManaged)) + .entry((size, MTLResourceOptions::StorageModeShared)) .or_insert(vec![]); let new_buffer = Arc::new(new_buffer); @@ -261,7 +261,7 @@ impl MetalDevice { pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result> { let buffer = self.allocate_buffer( size_in_bytes as NSUInteger, - MTLResourceOptions::StorageModePrivate, + MTLResourceOptions::StorageModeShared, "allocate_zeros", )?; let command_buffer = self.command_buffer()?; diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 339ed1e895..928645867f 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -2051,19 +2051,12 @@ impl MetalStorage { } pub(crate) fn to_cpu(&self) -> Result> { - let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger; - - let buffer = self.device.new_buffer_managed(size)?; - { - let command_buffer = self.device.command_buffer()?; - command_buffer.set_label("to_cpu"); - let blit = command_buffer.new_blit_command_encoder(); - blit.set_label("blit_to_cpu"); - blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size); - blit.end_encoding(); - } self.device.wait_until_completed()?; - Ok(read_to_vec(&buffer, self.count)) + + let ptr = self.buffer.contents() as *mut T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, self.count) }; + Ok(slice.to_vec()) } } @@ -2081,7 +2074,7 @@ impl BackendDevice for MetalDevice { let seed = Arc::new(Mutex::new(device.new_buffer_with_data( [299792458].as_ptr() as *const c_void, 4, - MTLResourceOptions::StorageModeManaged, + MTLResourceOptions::StorageModeShared, ))); let commands = device::Commands::new(command_queue)?; Ok(Self { @@ -2302,10 +2295,3 @@ impl BackendDevice for MetalDevice { self.wait_until_completed() } } - -fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { - let ptr = buffer.contents() as *const T; - assert!(!ptr.is_null()); - let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; - slice.to_vec() -}