Skip to content

Commit

Permalink
Metal: Use mtl resource shared to avoid copy (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Nov 17, 2024
1 parent 6be03dd commit cb8082b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 25 deletions.
10 changes: 5 additions & 5 deletions candle-core/src/metal_backend/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ impl MetalDevice {
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, name)
}

/// Creates a new buffer (not necessarily zeroed).
Expand All @@ -232,7 +232,7 @@ impl MetalDevice {
/// synchronization when the CPU memory is modified
/// Used as a bridge to gather data back from the GPU
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, "managed")
}

/// Creates a new buffer from data.
Expand All @@ -245,12 +245,12 @@ impl MetalDevice {
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const c_void,
size,
MTLResourceOptions::StorageModeManaged,
MTLResourceOptions::StorageModeShared,
);
let mut buffers = self.buffers.write().map_err(MetalError::from)?;

let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.entry((size, MTLResourceOptions::StorageModeShared))
.or_insert(vec![]);

let new_buffer = Arc::new(new_buffer);
Expand All @@ -261,7 +261,7 @@ impl MetalDevice {
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
MTLResourceOptions::StorageModeShared,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
Expand Down
26 changes: 6 additions & 20 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2051,19 +2051,12 @@ impl MetalStorage {
}

pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;

let buffer = self.device.new_buffer_managed(size)?;
{
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size);
blit.end_encoding();
}
self.device.wait_until_completed()?;
Ok(read_to_vec(&buffer, self.count))

let ptr = self.buffer.contents() as *mut T;
assert!(!ptr.is_null());
let slice = unsafe { std::slice::from_raw_parts(ptr, self.count) };
Ok(slice.to_vec())
}
}

Expand All @@ -2081,7 +2074,7 @@ impl BackendDevice for MetalDevice {
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
MTLResourceOptions::StorageModeShared,
)));
let commands = device::Commands::new(command_queue)?;
Ok(Self {
Expand Down Expand Up @@ -2302,10 +2295,3 @@ impl BackendDevice for MetalDevice {
self.wait_until_completed()
}
}

/// Copies the first `n` elements of type `T` out of `buffer` into a new `Vec<T>`.
///
/// The data is read through the pointer returned by `buffer.contents()`,
/// reinterpreted as `*const T`.
///
/// # Panics
///
/// Panics if `buffer.contents()` returns a null pointer.
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
assert!(!ptr.is_null());
// SAFETY: NOTE(review): only null-ness is checked here. Nothing in this
// function verifies that the buffer holds at least `n` initialized,
// properly aligned values of `T`; the caller is presumably responsible
// for that — confirm at call sites.
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
// Clone the elements so the returned Vec owns its data and does not
// borrow from the buffer's memory.
slice.to_vec()
}

0 comments on commit cb8082b

Please sign in to comment.