Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyO3: Add equal and __richcmp__ to candle.Tensor #1099

Merged
Merged
2 changes: 1 addition & 1 deletion candle-core/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ impl Shape {

/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
/// broadcasted shape. This is to be used for binary pointwise ops.
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
Expand Down
36 changes: 36 additions & 0 deletions candle-pyo3/_additional_typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,39 @@ def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Te
Return a slice of a tensor.
"""
pass

def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
41 changes: 41 additions & 0 deletions candle-pyo3/py_src/candle/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,46 @@ class Tensor:
Add a scalar to a tensor or two tensors together.
"""
pass
def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass
def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
Expand All @@ -159,6 +189,11 @@ class Tensor:
Divide a tensor by a scalar or one tensor by another.
"""
pass
def abs(self) -> Tensor:
"""
Performs the `abs` operation on the tensor.
"""
pass
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
Expand Down Expand Up @@ -308,6 +343,12 @@ class Tensor:
ranges from `start` to `start + len`.
"""
pass
@property
def nelement(self) -> int:
"""
Gets the tensor's element count.
"""
pass
def powf(self, p: float) -> Tensor:
"""
Performs the `pow` operation on the tensor with the given exponent.
Expand Down
70 changes: 70 additions & 0 deletions candle-pyo3/py_src/candle/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import candle
from candle import Tensor


_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])


def _assert_tensor_metadata(
actual: Tensor,
expected: Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
if check_device:
assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"

if check_dtype:
assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"

if check_layout:
assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"

if check_stride:
assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"


def assert_equal(
actual: Tensor,
expected: Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
"""
Asserts that two tensors are exact equals.
"""
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"


def assert_almost_equal(
actual: Tensor,
expected: Tensor,
rtol=1e-05,
atol=1e-08,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
"""
Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance.

Computes: |actual - expected| ≤ atol + rtol x |expected|
"""
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)

# Secure against overflow of u32 and u8 tensors
if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
actual = actual.to(candle.i64)
expected = expected.to(candle.i64)

diff = (actual - expected).abs()

threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)

assert (diff <= threshold).sum_all().values() == actual.nelement, f"Difference between tensors was to great"
73 changes: 71 additions & 2 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#![allow(clippy::redundant_closure_call)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::os::raw::c_long;
use std::sync::Arc;

Expand Down Expand Up @@ -132,9 +135,10 @@ macro_rules! pydtype {
}
};
}

pydtype!(i64, |v| v);
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
pydtype!(i64, |v| v);
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
Expand Down Expand Up @@ -317,6 +321,13 @@ impl PyTensor {
PyTuple::new(py, self.0.dims()).to_object(py)
}

#[getter]
/// Gets the tensor's element count.
/// &RETURNS&: int
fn nelement(&self) -> usize {
self.0.elem_count()
}

#[getter]
/// Gets the tensor's strides.
/// &RETURNS&: Tuple[int]
Expand Down Expand Up @@ -353,6 +364,12 @@ impl PyTensor {
self.__repr__()
}

/// Performs the `abs` operation on the tensor.
/// &RETURNS&: Tensor
fn abs(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
}

/// Performs the `sin` operation on the tensor.
/// &RETURNS&: Tensor
fn sin(&self) -> PyResult<Self> {
Expand Down Expand Up @@ -670,6 +687,58 @@ impl PyTensor {
};
Ok(Self(tensor))
}
/// Rich-compare two tensors.
/// &RETURNS&: Tensor
fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
let compare = |lhs: &Tensor, rhs: &Tensor| {
let t = match op {
CompareOp::Eq => lhs.eq(rhs),
CompareOp::Ne => lhs.ne(rhs),
CompareOp::Lt => lhs.lt(rhs),
CompareOp::Le => lhs.le(rhs),
CompareOp::Gt => lhs.gt(rhs),
CompareOp::Ge => lhs.ge(rhs),
};
Ok(PyTensor(t.map_err(wrap_err)?))
};
if let Ok(rhs) = rhs.extract::<PyTensor>() {
if self.0.shape() == rhs.0.shape() {
compare(&self.0, &rhs.0)
} else {
// We broadcast manually here because `candle.cmp` does not support automatic broadcasting
let broadcast_shape = self
.0
.shape()
.broadcast_shape_binary_op(rhs.0.shape(), "cmp")
.map_err(wrap_err)?;
let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;

compare(&broadcasted_lhs, &broadcasted_rhs)
}
} else if let Ok(rhs) = rhs.extract::<f64>() {
let scalar_tensor = Tensor::new(rhs, self.0.device())
.map_err(wrap_err)?
.to_dtype(self.0.dtype())
.map_err(wrap_err)?
.broadcast_as(self.0.shape())
.map_err(wrap_err)?;

compare(&self.0, &scalar_tensor)
} else {
return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
}
}

fn __hash__(&self) -> u64 {
// we have overridden __richcmp__ => py03 wants us to also override __hash__
// we simply hash the address of the tensor
let mut hasher = DefaultHasher::new();
let pointer = &self.0 as *const Tensor;
let address = pointer as usize;
address.hash(&mut hasher);
hasher.finish()
}

#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Reshapes the tensor to the given shape.
Expand Down Expand Up @@ -1503,7 +1572,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDType>()?;
m.add("u8", PyDType(DType::U8))?;
m.add("u32", PyDType(DType::U32))?;
m.add("i16", PyDType(DType::I64))?;
m.add("i64", PyDType(DType::I64))?;
m.add("bf16", PyDType(DType::BF16))?;
m.add("f16", PyDType(DType::F16))?;
m.add("f32", PyDType(DType::F32))?;
Expand Down
33 changes: 33 additions & 0 deletions candle-pyo3/tests/bindings/test_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import candle
from candle import Tensor
from candle.testing import assert_equal, assert_almost_equal
import pytest


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_equal_asserts_correctly(dtype: candle.DType):
a = Tensor([1, 2, 3]).to(dtype)
b = Tensor([1, 2, 3]).to(dtype)
assert_equal(a, b)

with pytest.raises(AssertionError):
assert_equal(a, b + 1)


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
a = Tensor([1, 2, 3]).to(dtype)
b = Tensor([1, 2, 3]).to(dtype)
assert_almost_equal(a, b)

with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1)

assert_almost_equal(a, b + 1, atol=20)
assert_almost_equal(a, b + 1, rtol=20)

with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1, atol=0.9)

with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1, rtol=0.1)
Loading
Loading