From cb8082bf28eadb4140c1774f002983a6d182bf3e Mon Sep 17 00:00:00 2001
From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com>
Date: Sat, 16 Nov 2024 20:58:24 -0500
Subject: [PATCH] Metal: Use mtl resource shared to avoid copy (#40)

---
 candle-core/src/metal_backend/device.rs | 10 +++++-----
 candle-core/src/metal_backend/mod.rs    | 26 ++++++-------------------
 2 files changed, 11 insertions(+), 25 deletions(-)

diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs
index b100db63d4..9c5ba5bb5e 100644
--- a/candle-core/src/metal_backend/device.rs
+++ b/candle-core/src/metal_backend/device.rs
@@ -223,7 +223,7 @@ impl MetalDevice {
         name: &str,
     ) -> Result<Arc<Buffer>> {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
+        self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, name)
     }
 
     /// Creates a new buffer (not necessarily zeroed).
@@ -232,7 +232,7 @@ impl MetalDevice {
     /// synchronization when the CPU memory is modified
     /// Used as a bridge to gather data back from the GPU
     pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
-        self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
+        self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, "managed")
     }
 
     /// Creates a new buffer from data.
@@ -245,12 +245,12 @@ impl MetalDevice {
         let new_buffer = self.device.new_buffer_with_data(
             data.as_ptr() as *const c_void,
             size,
-            MTLResourceOptions::StorageModeManaged,
+            MTLResourceOptions::StorageModeShared,
         );
         let mut buffers = self.buffers.write().map_err(MetalError::from)?;
 
         let subbuffers = buffers
-            .entry((size, MTLResourceOptions::StorageModeManaged))
+            .entry((size, MTLResourceOptions::StorageModeShared))
             .or_insert(vec![]);
 
         let new_buffer = Arc::new(new_buffer);
@@ -261,7 +261,7 @@ impl MetalDevice {
     pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
         let buffer = self.allocate_buffer(
             size_in_bytes as NSUInteger,
-            MTLResourceOptions::StorageModePrivate,
+            MTLResourceOptions::StorageModeShared,
             "allocate_zeros",
         )?;
         let command_buffer = self.command_buffer()?;
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 339ed1e895..928645867f 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -2051,19 +2051,12 @@ impl MetalStorage {
     }
 
     pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
-        let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;
-
-        let buffer = self.device.new_buffer_managed(size)?;
-        {
-            let command_buffer = self.device.command_buffer()?;
-            command_buffer.set_label("to_cpu");
-            let blit = command_buffer.new_blit_command_encoder();
-            blit.set_label("blit_to_cpu");
-            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size);
-            blit.end_encoding();
-        }
         self.device.wait_until_completed()?;
-        Ok(read_to_vec(&buffer, self.count))
+
+        let ptr = self.buffer.contents() as *mut T;
+        assert!(!ptr.is_null());
+        let slice = unsafe { std::slice::from_raw_parts(ptr, self.count) };
+        Ok(slice.to_vec())
     }
 }
 
@@ -2081,7 +2074,7 @@ impl BackendDevice for MetalDevice {
         let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
             [299792458].as_ptr() as *const c_void,
             4,
-            MTLResourceOptions::StorageModeManaged,
+            MTLResourceOptions::StorageModeShared,
         )));
         let commands = device::Commands::new(command_queue)?;
         Ok(Self {
@@ -2302,10 +2295,3 @@ impl BackendDevice for MetalDevice {
         self.wait_until_completed()
     }
 }
-
-fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
-    let ptr = buffer.contents() as *const T;
-    assert!(!ptr.is_null());
-    let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
-    slice.to_vec()
-}