Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add PolarsAllocator #88

Merged
merged 6 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions example/derive_expression/expression_lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,3 @@ pyo3 = { version = "0.21", features = ["abi3-py38"] }
pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive"] }
rayon = "1.7.0"
serde = { version = "1", features = ["derive"] }

[target.'cfg(target_os = "linux")'.dependencies]
jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] }
8 changes: 3 additions & 5 deletions example/derive_expression/expression_lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use pyo3_polars::PolarsAllocator;

mod distances;
mod expressions;

#[cfg(target_os = "linux")]
use jemallocator::Jemalloc;

#[global_allocator]
#[cfg(target_os = "linux")]
static ALLOC: Jemalloc = Jemalloc;
static ALLOC: PolarsAllocator = PolarsAllocator::new();
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use polars_lazy::frame::IntoLazy;
use polars_lazy::prelude::LazyFrame;
use pyo3::prelude::*;
use pyo3_polars::error::PyPolarsErr;
use pyo3_polars::{PyDataFrame, PyLazyFrame};
use pyo3_polars::{PolarsAllocator, PyDataFrame, PyLazyFrame};

#[global_allocator]
static ALLOC: PolarsAllocator = PolarsAllocator::new();

#[pyfunction]
fn parallel_jaccard(pydf: PyDataFrame, col_a: &str, col_b: &str) -> PyResult<PyDataFrame> {
Expand Down
2 changes: 2 additions & 0 deletions pyo3-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ description = "Expression plugins and PyO3 types for polars"

[dependencies]
ciborium = { version = "0.2.1", optional = true }
libc = "0.2" # pyo3 depends on libc already, so this does not introduce an extra dependence.
once_cell = "1"
polars = { workspace = true, default-features = false }
polars-core = { workspace = true, default-features = false }
polars-ffi = { workspace = true, optional = true }
Expand Down
115 changes: 115 additions & 0 deletions pyo3-polars/src/alloc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use std::alloc::{GlobalAlloc, Layout, System};
use std::ffi::c_char;

use once_cell::race::OnceRef;
use pyo3::ffi::{PyCapsule_Import, Py_IsInitialized};
use pyo3::Python;

unsafe extern "C" fn fallback_alloc(size: usize, align: usize) -> *mut u8 {
System.alloc(Layout::from_size_align_unchecked(size, align))
}

unsafe extern "C" fn fallback_dealloc(ptr: *mut u8, size: usize, align: usize) {
System.dealloc(ptr, Layout::from_size_align_unchecked(size, align))
}

unsafe extern "C" fn fallback_alloc_zeroed(size: usize, align: usize) -> *mut u8 {
System.alloc_zeroed(Layout::from_size_align_unchecked(size, align))
}

unsafe extern "C" fn fallback_realloc(
ptr: *mut u8,
size: usize,
align: usize,
new_size: usize,
) -> *mut u8 {
System.realloc(
ptr,
Layout::from_size_align_unchecked(size, align),
new_size,
)
}

#[repr(C)]
struct AllocatorCapsule {
alloc: unsafe extern "C" fn(usize, usize) -> *mut u8,
dealloc: unsafe extern "C" fn(*mut u8, usize, usize),
alloc_zeroed: unsafe extern "C" fn(usize, usize) -> *mut u8,
realloc: unsafe extern "C" fn(*mut u8, usize, usize, usize) -> *mut u8,
}

static FALLBACK_ALLOCATOR_CAPSULE: AllocatorCapsule = AllocatorCapsule {
alloc: fallback_alloc,
alloc_zeroed: fallback_alloc_zeroed,
dealloc: fallback_dealloc,
realloc: fallback_realloc,
};

static ALLOCATOR_CAPSULE_NAME: &[u8] = b"polars.polars._allocator\0";

/// A memory allocator that relays allocations to the allocator used by Polars.
///
/// You can use it as the global memory allocator:
///
/// ```rust
/// use pyo3_polars::PolarsAllocator;
///
/// #[global_allocator]
/// static ALLOC: PolarsAllocator = PolarsAllocator::new();
/// ```
///
/// If the allocator capsule (`polars.polars._allocator`) is not available,
/// this allocator fallbacks to [`std::alloc::System`].
pub struct PolarsAllocator(OnceRef<'static, AllocatorCapsule>);

impl PolarsAllocator {
fn get_allocator(&self) -> &'static AllocatorCapsule {
// Do not allocate in this function,
// otherwise it will cause infinite recursion.
self.0.get_or_init(|| {
let r = (unsafe { Py_IsInitialized() } != 0)
.then(|| {
Python::with_gil(|_| unsafe {
(PyCapsule_Import(ALLOCATOR_CAPSULE_NAME.as_ptr() as *const c_char, 0)
as *const AllocatorCapsule)
.as_ref()
})
})
.flatten();
#[cfg(debug_assertions)]
if r.is_none() {
// Do not use eprintln; it may alloc.
let msg = b"failed to get allocator capsule\n";
unsafe { libc::write(2, msg.as_ptr() as *const libc::c_void, msg.len()) };
}
r.unwrap_or(&FALLBACK_ALLOCATOR_CAPSULE)
})
}

/// Create a `PolarsAllocator`.
pub const fn new() -> Self {
PolarsAllocator(OnceRef::new())
}
}

unsafe impl GlobalAlloc for PolarsAllocator {
#[inline]
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
(self.get_allocator().alloc)(layout.size(), layout.align())
}

#[inline]
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
(self.get_allocator().dealloc)(ptr, layout.size(), layout.align());
}

#[inline]
unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
(self.get_allocator().alloc_zeroed)(layout.size(), layout.align())
}

#[inline]
unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
(self.get_allocator().realloc)(ptr, layout.size(), layout.align(), new_size)
}
}
2 changes: 2 additions & 0 deletions pyo3-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@
//! })
//! out_df = my_cool_function(df)
//! ```
mod alloc;
#[cfg(feature = "derive")]
pub mod derive;
pub mod error;
#[cfg(feature = "derive")]
pub mod export;
mod ffi;

pub use crate::alloc::PolarsAllocator;
use crate::error::PyPolarsErr;
use crate::ffi::to_py::to_py_array;
use polars::export::arrow;
Expand Down
Loading