From 91ab746da4822b7059b23349e70d87aa80dc6c9e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 6 Oct 2024 10:09:38 +0200 Subject: [PATCH] pyo3 update. (#2545) * pyo3 update. * Stub fix. --- candle-examples/Cargo.toml | 4 ++-- candle-pyo3/Cargo.toml | 4 ++-- candle-pyo3/py_src/candle/utils/__init__.pyi | 10 +++------- candle-pyo3/src/lib.rs | 19 +++++++++---------- candle-pyo3/src/shape.rs | 12 ++++++------ 5 files changed, 22 insertions(+), 27 deletions(-) diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 4edde7a966..0c1219d760 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -27,7 +27,7 @@ intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } palette = { version = "0.7.6", optional = true } enterpolation = { version = "0.2.1", optional = true} -pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true } +pyo3 = { version = "0.22.0", features = ["auto-initialize"], optional = true } rayon = { workspace = true } rubato = { version = "0.15.0", optional = true } safetensors = { workspace = true } @@ -121,4 +121,4 @@ required-features = ["onnx"] [[example]] name = "colpali" -required-features = ["pdf2image"] \ No newline at end of file +required-features = ["pdf2image"] diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 8800133429..2776a3f77c 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -20,10 +20,10 @@ candle-nn = { workspace = true } candle-onnx = { workspace = true, optional = true } half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } -pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] } +pyo3 = { version = "0.22.0", features = ["extension-module", "abi3-py38"] } [build-dependencies] -pyo3-build-config = "0.21" +pyo3-build-config = "0.22" [features] default = [] diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi index c9a9f9f3c1..94c3228398 100644 --- a/candle-pyo3/py_src/candle/utils/__init__.pyi +++ b/candle-pyo3/py_src/candle/utils/__init__.pyi @@ -33,9 +33,7 @@ def has_mkl() -> bool: pass @staticmethod -def load_ggml( - path: Union[str, PathLike], device: Optional[Device] = None -) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]: +def load_ggml(path, device=None) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]: """ Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. @@ -43,9 +41,7 @@ def load_ggml( pass @staticmethod -def load_gguf( - path: Union[str, PathLike], device: Optional[Device] = None -) -> Tuple[Dict[str, QTensor], Dict[str, Any]]: +def load_gguf(path, device=None) -> Tuple[Dict[str, QTensor], Dict[str, Any]]: """ Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, and the second maps metadata keys to metadata values. @@ -60,7 +56,7 @@ def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]: pass @staticmethod -def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]): +def save_gguf(path, tensors, metadata): """ Save quanitzed tensors and metadata to a GGUF file. """ diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 0da2c70028..722b5e3ace 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -6,7 +6,6 @@ use pyo3::types::{IntoPyDict, PyDict, PyTuple}; use pyo3::ToPyObject; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; -use std::os::raw::c_long; use std::sync::Arc; use half::{bf16, f16}; @@ -115,7 +114,7 @@ impl PyDevice { } impl<'source> FromPyObject<'source> for PyDevice { - fn extract(ob: &'source PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult { let device: String = ob.extract()?; let device = match device.as_str() { "cpu" => PyDevice::Cpu, @@ -217,11 +216,11 @@ enum Indexer { IndexSelect(Tensor), } -#[derive(Clone, Debug)] +#[derive(Debug)] struct TorchTensor(PyObject); impl<'source> pyo3::FromPyObject<'source> for TorchTensor { - fn extract(ob: &'source PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult { let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?; Ok(TorchTensor(numpy_value)) } @@ -540,7 +539,7 @@ impl PyTensor { )) } else if let Ok(slice) = py_indexer.downcast::() { // Handle a single slice e.g. tensor[0:1] or tensor[0:-1] - let index = slice.indices(dims[current_dim] as c_long)?; + let index = slice.indices(dims[current_dim] as isize)?; Ok(( Indexer::Slice(index.start as usize, index.stop as usize), current_dim + 1, @@ -1284,7 +1283,7 @@ fn save_safetensors( } #[pyfunction] -#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")] +#[pyo3(signature = (path, device = None))] /// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, /// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]] @@ -1325,7 +1324,7 @@ fn load_ggml( } #[pyfunction] -#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")] +#[pyo3(signature = (path, device = None))] /// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, /// and the second maps metadata keys to metadata values. /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]] @@ -1384,7 +1383,7 @@ fn load_gguf( #[pyfunction] #[pyo3( - text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])" + signature = (path, tensors, metadata) )] /// Save quanitzed tensors and metadata to a GGUF file. fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> { @@ -1430,7 +1429,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) Ok(v) } let tensors = tensors - .extract::<&PyDict>(py) + .downcast_bound::(py) .map_err(|_| PyErr::new::("expected a dict"))? .iter() .map(|(key, value)| { @@ -1443,7 +1442,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) .collect::>>()?; let metadata = metadata - .extract::<&PyDict>(py) + .downcast_bound::(py) .map_err(|_| PyErr::new::("expected a dict"))? .iter() .map(|(key, value)| { diff --git a/candle-pyo3/src/shape.rs b/candle-pyo3/src/shape.rs index 2668b7331b..b9bc67899d 100644 --- a/candle-pyo3/src/shape.rs +++ b/candle-pyo3/src/shape.rs @@ -6,7 +6,7 @@ use pyo3::prelude::*; pub struct PyShape(Vec); impl<'source> pyo3::FromPyObject<'source> for PyShape { - fn extract(ob: &'source PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult { if ob.is_none() { return Err(PyErr::new::( "Shape cannot be None", @@ -16,10 +16,10 @@ impl<'source> pyo3::FromPyObject<'source> for PyShape { let tuple = ob.downcast::()?; if tuple.len() == 1 { let first_element = tuple.get_item(0)?; - let dims: Vec = pyo3::FromPyObject::extract(first_element)?; + let dims: Vec = pyo3::FromPyObject::extract_bound(&first_element)?; Ok(PyShape(dims)) } else { - let dims: Vec = pyo3::FromPyObject::extract(tuple)?; + let dims: Vec = pyo3::FromPyObject::extract_bound(tuple)?; Ok(PyShape(dims)) } } @@ -36,7 +36,7 @@ impl From for ::candle::Shape { pub struct PyShapeWithHole(Vec); impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole { - fn extract(ob: &'source PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult { if ob.is_none() { return Err(PyErr::new::( "Shape cannot be None", @@ -46,9 +46,9 @@ impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole { let tuple = ob.downcast::()?; let dims: Vec = if tuple.len() == 1 { let first_element = tuple.get_item(0)?; - pyo3::FromPyObject::extract(first_element)? + pyo3::FromPyObject::extract_bound(&first_element)? } else { - pyo3::FromPyObject::extract(tuple)? + pyo3::FromPyObject::extract_bound(tuple)? }; // Ensure we have only positive numbers and at most one "hole" (-1)