From caf42e99872cc34e1dac2f549c4bfa94093c7484 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Sun, 15 Oct 2023 13:33:37 +0200 Subject: [PATCH 1/2] Add proper `None` and `tensor` indexing --- candle-pyo3/src/lib.rs | 47 ++++++++++++++++++++++--- candle-pyo3/tests/native/test_tensor.py | 19 ++++++++++ 2 files changed, 62 insertions(+), 4 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 02db05e568..f988b22fe5 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -201,6 +201,8 @@ enum Indexer { Index(usize), Slice(usize, usize), Elipsis, + Expand, + IndexSelect(Tensor), } #[pymethods] @@ -475,27 +477,48 @@ impl PyTensor { // Handle a single slice e.g. tensor[0:1] or tensor[0:-1] let index = slice.indices(dims[0] as c_long)?; indexers.push(Indexer::Slice(index.start as usize, index.stop as usize)); + } else if let Ok(tensor) = idx.extract::(py) { + // Handle a tensor as indices e.g. tensor[tensor([0,1])] + let t = tensor.0; + if t.rank() != 1 { + return Err(PyTypeError::new_err( + "multi-dimensional tensor indexing is not supported", + )); + } + indexers.push(Indexer::IndexSelect(t)); } else if let Ok(tuple) = idx.downcast::(py) { // Handle multiple indices e.g. tensor[0,0] or tensor[0:1,0:1] - if tuple.len() > dims.len() { + let mut not_none_count = 0; + for item in tuple.iter() { + if !item.is_none() { + not_none_count += 1; + } + } + if not_none_count > dims.len() { return Err(PyTypeError::new_err("provided too many indices")); } + let mut current_dim = 0; for (i, item) in tuple.iter().enumerate() { if item.is_ellipsis() { // Handle '...' e.g. tensor[..., 0] - if i > 0 { return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation")); } indexers.push(Indexer::Elipsis); + current_dim = dims.len() - (tuple.len() - 1); + } else if item.is_none() { + // Handle None e.g. 
tensor[None, 0] + indexers.push(Indexer::Expand); } else if let Ok(slice) = item.downcast::() { // Handle slice - let index = slice.indices(dims[i] as c_long)?; + let index = slice.indices(dims[current_dim] as c_long)?; indexers.push(Indexer::Slice(index.start as usize, index.stop as usize)); + current_dim += 1; } else if let Ok(index) = item.extract::() { - indexers.push(Indexer::Index(to_absolute_index(index, i)?)); + indexers.push(Indexer::Index(to_absolute_index(index, current_dim)?)); + current_dim += 1; } else { return Err(PyTypeError::new_err("unsupported index")); } @@ -526,6 +549,22 @@ impl PyTensor { current_dim += dims.len() - (indexers.len() - 1); x } + Indexer::Expand => { + // Expand is a special case, it means that a new dimension should be added => unsqueeze and advance the current_dim + let out = x.unsqueeze(current_dim).map_err(wrap_err)?; + current_dim += 1; + out + } + Indexer::IndexSelect(indexes) => { + let out = x + .index_select( + &indexes.to_device(x.device()).map_err(wrap_err)?, + current_dim, + ) + .map_err(wrap_err)?; + current_dim += 1; + out + } } } diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index 1f5b74f677..ab147417a4 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -72,3 +72,22 @@ def test_tensor_can_be_scliced_3d(): assert t[:, 0, 0].values() == [1, 9] assert t[..., 0].values() == [[1, 5], [9, 13]] assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]] + + +def test_tensor_can_be_expanded_with_none(): + t = candle.rand((12, 12)) + c = t[:, None, None, :] + assert c.shape == (12, 1, 1, 12) + d = t[None, :, None, :] + assert d.shape == (1, 12, 1, 12) + e = t[None, None, :, :] + assert e.shape == (1, 1, 12, 12) + f = t[:, :, None] + assert f.shape == (12, 12, 1) + + +def test_tensor_can_be_index_via_tensor(): + t = candle.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) + c = t[candle.Tensor([0, 2])] + assert c.shape == (2, 4) + assert c.values() == [[1, 2, 3, 4], [9, 10, 11, 12]] From 50fba7f98bc8225dfee9010b0466bb0338d47184 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Fri, 20 Oct 2023 10:34:45 +0200 Subject: [PATCH 2/2] Allow indexing via lists + allow tensor/list indexing outside of first dimension --- candle-pyo3/src/lib.rs | 121 ++++++++++++++---------- candle-pyo3/tests/native/test_tensor.py | 31 ++++-- 2 files changed, 97 insertions(+), 55 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 86c283d317..f16d8c1b18 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -452,7 +452,7 @@ impl PyTensor { let mut indexers: Vec = vec![]; let dims = self.0.shape().dims(); - let to_absolute_index = |index: isize, current_dim: usize| { + fn to_absolute_index(index: isize, current_dim: usize, dims: &[usize]) -> PyResult { // Convert a relative index to an absolute index e.g. tensor[-1] -> tensor[0] let actual_index = if index < 0 { dims[current_dim] as isize + index @@ -462,69 +462,92 @@ impl PyTensor { // Check that the index is in range if actual_index < 0 || actual_index >= dims[current_dim] as isize { - return Err(PyTypeError::new_err(format!( + return Err(PyValueError::new_err(format!( "index out of range for dimension '{i}' with indexer '{value}'", i = current_dim, value = index ))); } Ok(actual_index as usize) - }; - if let Ok(index) = idx.extract(py) { - // Handle a single index e.g. 
tensor[0] or tensor[-1] - indexers.push(Indexer::Index(to_absolute_index(index, 0)?)); - } else if let Ok(slice) = idx.downcast::(py) { - // Handle a single slice e.g. tensor[0:1] or tensor[0:-1] - let index = slice.indices(dims[0] as c_long)?; - indexers.push(Indexer::Slice(index.start as usize, index.stop as usize)); - } else if let Ok(tensor) = idx.extract::(py) { - // Handle a tensor as indices e.g. tensor[tensor([0,1])] - let t = tensor.0; - if t.rank() != 1 { - return Err(PyTypeError::new_err( - "multi-dimensional tensor indexing is not supported", - )); - } - indexers.push(Indexer::IndexSelect(t)); - } else if let Ok(tuple) = idx.downcast::(py) { - // Handle multiple indices e.g. tensor[0,0] or tensor[0:1,0:1] + } - let mut not_none_count = 0; - for item in tuple.iter() { - if !item.is_none() { - not_none_count += 1; + fn extract_indexer( + py_indexer: &PyAny, + current_dim: usize, + dims: &[usize], + index_argument_count: usize, + ) -> PyResult<(Indexer, usize)> { + if let Ok(index) = py_indexer.extract() { + // Handle a single index e.g. tensor[0] or tensor[-1] + Ok(( + Indexer::Index(to_absolute_index(index, current_dim, dims)?), + current_dim + 1, + )) + } else if let Ok(slice) = py_indexer.downcast::() { + // Handle a single slice e.g. tensor[0:1] or tensor[0:-1] + let index = slice.indices(dims[current_dim] as c_long)?; + Ok(( + Indexer::Slice(index.start as usize, index.stop as usize), + current_dim + 1, + )) + } else if let Ok(tensor) = py_indexer.extract::() { + // Handle a tensor as indices e.g. tensor[tensor([0,1])] + let t = tensor.0; + if t.rank() != 1 { + return Err(PyTypeError::new_err( + "multi-dimensional tensor indexing is not supported", + )); + } + Ok((Indexer::IndexSelect(t), current_dim + 1)) + } else if let Ok(list) = py_indexer.downcast::() { + // Handle a list of indices e.g. tensor[[0,1]] + let mut indexes = vec![]; + for item in list.iter() { + let index = item.extract::()?; + indexes.push(index); } + Ok(( + Indexer::IndexSelect( + Tensor::from_vec(indexes, list.len(), &Device::Cpu).map_err(wrap_err)?, + ), + current_dim + 1, + )) + } else if py_indexer.is_ellipsis() { + // Handle '...' e.g. tensor[..., 0] + if current_dim > 0 { + return Err(PyTypeError::new_err( + "Ellipsis ('...') can only be used at the start of an indexing operation", + )); + } + Ok((Indexer::Elipsis, dims.len() - (index_argument_count - 1))) + } else if py_indexer.is_none() { + // Handle None e.g. tensor[None, 0] + Ok((Indexer::Expand, current_dim)) + } else { + Err(PyTypeError::new_err(format!( + "unsupported indexer {}", + py_indexer + ))) } + } + + if let Ok(tuple) = idx.downcast::(py) { + let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count(); + if not_none_count > dims.len() { - return Err(PyTypeError::new_err("provided too many indices")); + return Err(PyValueError::new_err("provided too many indices")); } let mut current_dim = 0; - for (i, item) in tuple.iter().enumerate() { - if item.is_ellipsis() { - // Handle '...' e.g. tensor[..., 0] - if i > 0 { - return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation")); - } - indexers.push(Indexer::Elipsis); - current_dim = dims.len() - (tuple.len() - 1); - } else if item.is_none() { - // Handle None e.g. 
tensor[None, 0] - indexers.push(Indexer::Expand); - } else if let Ok(slice) = item.downcast::() { - // Handle slice - let index = slice.indices(dims[current_dim] as c_long)?; - indexers.push(Indexer::Slice(index.start as usize, index.stop as usize)); - current_dim += 1; - } else if let Ok(index) = item.extract::() { - indexers.push(Indexer::Index(to_absolute_index(index, current_dim)?)); - current_dim += 1; - } else { - return Err(PyTypeError::new_err("unsupported index")); - } + for item in tuple.iter() { + let (indexer, new_current_dim) = + extract_indexer(item, current_dim, dims, not_none_count)?; + current_dim = new_current_dim; + indexers.push(indexer); } } else { - return Err(PyTypeError::new_err("unsupported index")); + let (indexer, _) = extract_indexer(idx.downcast::(py)?, 0, dims, 1)?; + indexers.push(indexer); } let mut x = self.0.clone(); diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index cc8fdcb2ff..e4cf19f1a9 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -55,6 +55,7 @@ def test_tensor_can_be_sliced(): assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0] assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0] assert t[-4:-2].values() == [5.0, 9.0] + assert t[...].values() == t.values() def test_tensor_can_be_sliced_2d(): @@ -78,6 +79,9 @@ def test_tensor_can_be_scliced_3d(): def test_tensor_can_be_expanded_with_none(): t = candle.rand((12, 12)) + + b = t[None] + assert b.shape == (1, 12, 12) c = t[:, None, None, :] assert c.shape == (12, 1, 1, 12) d = t[None, :, None, :] @@ -89,12 +93,27 @@ def test_tensor_can_be_expanded_with_none(): def test_tensor_can_be_index_via_tensor(): - t = candle.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) - c = t[candle.Tensor([0, 2])] - assert c.shape == (2, 4) - assert c.values() == [[1, 2, 3, 4], [9, 10, 11, 12]] - - + t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]]) + indexed = t[candle.Tensor([0, 2])] + assert indexed.shape == (2, 4) + assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]] + + indexed = t[:, candle.Tensor([0, 2])] + assert indexed.shape == (3, 2) + assert indexed.values() == [[1, 1], [3, 3], [5, 5]] + + +def test_tensor_can_be_index_via_list(): + t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]]) + indexed = t[[0, 2]] + assert indexed.shape == (2, 4) + assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]] + + indexed = t[:, [0, 2]] + assert indexed.shape == (3, 2) + assert indexed.values() == [[1, 1], [3, 3], [5, 5]] + + def test_tensor_can_be_cast_via_to(): t = Tensor(42.0) assert str(t.dtype) == str(candle.f32)
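
Note on the combined behaviour of the two patches: integers and slices consume one dimension each, None inserts a size-1 axis (Indexer::Expand, applied via unsqueeze), a leading '...' stands in for the skipped leading dimensions, and a 1-D candle.Tensor or a plain Python list selects indices along the current dimension via index_select. The sketch below summarises this from the Python side; it assumes the bindings built from this branch are importable as `candle`, and the tensor values and the mixed-indexer example are illustrative rather than copied from the test suite (the asserted results follow from the indexing logic in the patch, but are not themselves tests in it).

    import candle

    t = candle.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])  # shape (3, 4)

    # None adds a new axis without consuming a dimension.
    assert t[None].shape == (1, 3, 4)
    assert t[:, None, :].shape == (3, 1, 4)

    # A 1-D tensor or a plain list selects indices along the current dimension.
    assert t[candle.Tensor([0, 2])].values() == [[1, 2, 3, 4], [9, 10, 11, 12]]
    assert t[:, [0, 3]].values() == [[1, 4], [5, 8], [9, 12]]

    # '...' is only accepted as the first indexer and skips the leading dimensions.
    assert t[..., 0].values() == [1, 5, 9]

    # Indexers can be mixed in one expression; None does not advance the dimension counter.
    assert t[None, [0, 2], 1:3].shape == (1, 2, 2)

Because the second patch routes every indexer (integer, slice, tensor, list, None, '...') through the shared extract_indexer helper, the same rules apply whether an indexer appears on its own or inside a tuple.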