From b355ab4e2e52b077e71aac46c286fbce033f36d6 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Tue, 17 Oct 2023 11:57:12 +0200 Subject: [PATCH] Always broadcast magic methods (#1101) --- candle-pyo3/src/lib.rs | 8 +-- candle-pyo3/tests/native/test_tensor.py | 73 +++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 4 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 02db05e568..55b20308a7 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -536,7 +536,7 @@ impl PyTensor { /// &RETURNS&: Tensor fn __add__(&self, rhs: &PyAny) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { - (&self.0 + &rhs.0).map_err(wrap_err)? + self.0.broadcast_add(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { (&self.0 + rhs).map_err(wrap_err)? } else { @@ -553,7 +553,7 @@ impl PyTensor { /// &RETURNS&: Tensor fn __mul__(&self, rhs: &PyAny) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { - (&self.0 * &rhs.0).map_err(wrap_err)? + self.0.broadcast_mul(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { (&self.0 * rhs).map_err(wrap_err)? } else { @@ -570,7 +570,7 @@ impl PyTensor { /// &RETURNS&: Tensor fn __sub__(&self, rhs: &PyAny) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { - (&self.0 - &rhs.0).map_err(wrap_err)? + self.0.broadcast_sub(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { (&self.0 - rhs).map_err(wrap_err)? } else { @@ -583,7 +583,7 @@ impl PyTensor { /// &RETURNS&: Tensor fn __truediv__(&self, rhs: &PyAny) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { - (&self.0 / &rhs.0).map_err(wrap_err)? + self.0.broadcast_div(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { (&self.0 / rhs).map_err(wrap_err)? } else { diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index 1f5b74f677..225a746996 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -1,5 +1,6 @@ import candle from candle import Tensor +import pytest def test_tensor_can_be_constructed(): @@ -72,3 +73,75 @@ def test_tensor_can_be_scliced_3d(): assert t[:, 0, 0].values() == [1, 9] assert t[..., 0].values() == [[1, 5], [9, 13]] assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]] + + +def test_tensor_can_be_added(): + t = Tensor(42.0) + result = t + t + assert result.values() == 84.0 + result = t + 2.0 + assert result.values() == 44.0 + a = candle.rand((3, 1, 4)) + b = candle.rand((2, 1)) + c_native = a.broadcast_add(b) + c = a + b + assert c.shape == (3, 2, 4) + assert c.values() == c_native.values() + with pytest.raises(ValueError): + d = candle.rand((3, 4, 5)) + e = candle.rand((4, 6)) + f = d + e + + +def test_tensor_can_be_subtracted(): + t = Tensor(42.0) + result = t - t + assert result.values() == 0 + result = t - 2.0 + assert result.values() == 40.0 + a = candle.rand((3, 1, 4)) + b = candle.rand((2, 1)) + c_native = a.broadcast_sub(b) + c = a - b + assert c.shape == (3, 2, 4) + assert c.values() == c_native.values() + with pytest.raises(ValueError): + d = candle.rand((3, 4, 5)) + e = candle.rand((4, 6)) + f = d - e + + +def test_tensor_can_be_multiplied(): + t = Tensor(42.0) + result = t * t + assert result.values() == 1764.0 + result = t * 2.0 + assert result.values() == 84.0 + a = candle.rand((3, 1, 4)) + b = candle.rand((2, 1)) + c_native = a.broadcast_mul(b) + c = a * b + assert c.shape == (3, 2, 4) + assert c.values() == c_native.values() + with pytest.raises(ValueError): + d = candle.rand((3, 4, 5)) + e = candle.rand((4, 6)) + f = d * e + + +def test_tensor_can_be_divided(): + t = Tensor(42.0) + result = t / t + assert result.values() == 1.0 + result = t / 2.0 + assert result.values() == 21.0 + a = candle.rand((3, 1, 4)) + b = candle.rand((2, 1)) + c_native = a.broadcast_div(b) + c = a / b + assert c.shape == (3, 2, 4) + assert c.values() == c_native.values() + with pytest.raises(ValueError): + d = candle.rand((3, 4, 5)) + e = candle.rand((4, 6)) + f = d / e