Skip to content

Commit

Permalink
Rustier ffi impls
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Jul 17, 2024
1 parent c87dd38 commit 2a2a349
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 85 deletions.
12 changes: 6 additions & 6 deletions candle-metal-kernels/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ pub type IOOptionBits = u32;
pub type CFNumberType = u32;

pub const kIOMainPortDefault: mach_port_t = 0;
pub const kIOServicePlane: *const c_char =
b"IOService\x00" as *const [u8; 10usize] as *const c_char;
pub const MACH_PORT_NULL: i32 = 0;
pub const kIOServicePlane: &str = "IOService\0";
pub const kCFNumberSInt64Type: CFNumberType = 4;

pub const MACH_PORT_NULL: i32 = 0;

#[link(name = "IOKit", kind = "framework")]
extern "C" {
pub fn IOServiceGetMatchingServices(
Expand All @@ -40,7 +40,7 @@ extern "C" {

pub fn IORegistryEntrySearchCFProperty(
entry: io_registry_entry_t,
plane: *mut c_char,
plane: *const c_char,
key: CFStringRef,
allocator: CFAllocatorRef,
options: IOOptionBits,
Expand All @@ -60,6 +60,6 @@ extern "C" {
fn __CFStringMakeConstantString(c_str: *const c_char) -> CFStringRef;
}

pub fn cfstr(c_str: *const c_char) -> CFStringRef {
unsafe { __CFStringMakeConstantString(c_str) }
pub fn cfstr(val: &str) -> CFStringRef {
unsafe { __CFStringMakeConstantString(val.as_ptr().cast()) }
}
122 changes: 122 additions & 0 deletions candle-metal-kernels/src/gpu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
use core::ffi::c_void;
use metal::Device;

use crate::ffi::*;

const GPU_CORE_COUNT_KEY: &str = "gpu-core-count\0";
const AGXACCELERATOR_KEY: &str = "AGXAccelerator\0";

struct IOIterator(io_iterator_t);

impl IOIterator {
fn new(it: io_iterator_t) -> Self {
IOIterator(it)
}

fn next(&self) -> Option<io_object_t> {
let result = unsafe { IOIteratorNext(self.0) };
if result == MACH_PORT_NULL as u32 {
return None;
}
Some(result)
}
}

impl Drop for IOIterator {
fn drop(&mut self) {
unsafe { IOObjectRelease(self.0 as _) };
}
}

unsafe fn get_io_service_matching(val: &str) -> Result<CFMutableDictionaryRef, String> {
let matching = IOServiceMatching(val.as_ptr().cast());
if matching.is_null() {
return Err(format!("IOServiceMatching call failed, `{val}` not found"));
}
Ok(matching)
}

unsafe fn get_matching_services(
main_port: mach_port_t,
matching: CFMutableDictionaryRef,
) -> Result<IOIterator, String> {
let mut iterator: io_iterator_t = 0;
let result = IOServiceGetMatchingServices(main_port, matching, &mut iterator);
if result != 0 {
return Err("Error getting matching services".to_string());
}
Ok(IOIterator::new(iterator))
}

unsafe fn get_gpu_io_service() -> Result<io_object_t, String> {
let matching = get_io_service_matching(AGXACCELERATOR_KEY)?;
let iterator = get_matching_services(kIOMainPortDefault, matching)?;
iterator
.next()
.ok_or("Error getting GPU IO Service".to_string())
}

unsafe fn get_property_by_key(
entry: io_registry_entry_t,
plane: &str,
key: &str,
allocator: CFAllocatorRef,
options: IOOptionBits,
) -> Result<CFTypeRef, String> {
let result = IORegistryEntrySearchCFProperty(
entry,
plane.as_ptr().cast(),
cfstr(key),
allocator,
options,
);
if result.is_null() {
return Err(format!("Error getting {key} property"));
}
Ok(result)
}

unsafe fn get_int_value(number: CFNumberRef) -> Result<i64, String> {
let mut value: i64 = 0;
let result = CFNumberGetValue(
number,
kCFNumberSInt64Type,
&mut value as *mut i64 as *mut c_void,
);
if !result {
return Err("Error getting int value".to_string());
}
Ok(value)
}

unsafe fn find_core_count() -> Result<usize, String> {
let gpu_io_service = get_gpu_io_service()?;
let gpu_core_count = get_property_by_key(
gpu_io_service,
kIOServicePlane,
GPU_CORE_COUNT_KEY,
core::ptr::null(),
0,
)?;
let value = get_int_value(gpu_core_count as CFNumberRef)?;
Ok(value as usize)
}

pub(crate) fn get_device_core_count(device: &Device) -> usize {
#[cfg(target_os = "macos")]
{
unsafe { find_core_count().expect("Retrieving gpu core count failed") }
}
#[cfg(target_os = "ios")]
{
if device.name().starts_with("A") {
if device.supports_family(MTLGPUFamily::Apple9) {
6
} else {
5
}
} else {
10
}
}
}
77 changes: 0 additions & 77 deletions candle-metal-kernels/src/gpu_info.rs

This file was deleted.

5 changes: 3 additions & 2 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;

#[cfg(target_os = "macos")]
mod ffi;
mod gpu_info;
use gpu_info::get_device_core_count;
mod gpu;
use gpu::get_device_core_count;

mod utils;
pub use utils::BufferOffset;
Expand Down

0 comments on commit 2a2a349

Please sign in to comment.