From f79281a567a115a7108c98b46f2db5288a4a2ce4 Mon Sep 17 00:00:00 2001 From: GaoSQ <13833117950@163.com> Date: Thu, 27 Apr 2023 23:47:44 +0800 Subject: [PATCH 1/2] add transpose specify axes --- tenseal/binding.cpp | 16 ++++++++++++---- tenseal/cpp/tensors/ckkstensor.cpp | 8 ++++++++ tenseal/cpp/tensors/ckkstensor.h | 2 ++ tenseal/cpp/tensors/plain_tensor.h | 7 +++++++ tenseal/cpp/tensors/tensor_storage.h | 10 ++++++++++ tenseal/tensors/ckkstensor.py | 19 +++++++++++++++---- tenseal/tensors/plaintensor.py | 17 +++++++++++++---- 7 files changed, 67 insertions(+), 12 deletions(-) diff --git a/tenseal/binding.cpp b/tenseal/binding.cpp index 0254ad31..0867ee7a 100644 --- a/tenseal/binding.cpp +++ b/tenseal/binding.cpp @@ -143,8 +143,12 @@ void bind_plain_tensor(py::module &m, const std::string &name) { .def("replicate", &type::replicate) .def("broadcast", &type::broadcast) .def("broadcast_", &type::broadcast_inplace) - .def("transpose", &type::transpose) - .def("transpose_", &type::transpose_inplace) + .def("transpose", py::overload_cast<>(&type::transpose, py::const_)) + .def("transpose_", py::overload_cast<>(&type::transpose_inplace)) + .def("transpose", py::overload_cast &>( + &type::transpose, py::const_)) + .def("transpose_", py::overload_cast &>( + &type::transpose_inplace)) .def("serialize", [](type &obj) { return py::bytes(obj.save()); }); } @@ -709,8 +713,12 @@ void bind_ckks_tensor(py::module &m) { .def("reshape_", &CKKSTensor::reshape_inplace) .def("broadcast", &CKKSTensor::broadcast) .def("broadcast_", &CKKSTensor::broadcast_inplace) - .def("transpose", &CKKSTensor::transpose) - .def("transpose_", &CKKSTensor::transpose_inplace) + .def("transpose", py::overload_cast<>(&CKKSTensor::transpose, py::const_)) + .def("transpose_", py::overload_cast<>(&CKKSTensor::transpose_inplace)) + .def("transpose", py::overload_cast &>( + &CKKSTensor::transpose, py::const_)) + .def("transpose_", py::overload_cast &>( + &CKKSTensor::transpose_inplace)) .def("scale", &CKKSTensor::scale); } diff --git a/tenseal/cpp/tensors/ckkstensor.cpp b/tenseal/cpp/tensors/ckkstensor.cpp index 8390822e..2135596b 100644 --- a/tenseal/cpp/tensors/ckkstensor.cpp +++ b/tenseal/cpp/tensors/ckkstensor.cpp @@ -764,6 +764,14 @@ shared_ptr CKKSTensor::transpose_inplace() { return shared_from_this(); } +shared_ptr CKKSTensor::transpose(const vector& permutation) const { + return this->copy()->transpose_inplace(permutation); +} +shared_ptr CKKSTensor::transpose_inplace(const vector& permutation) { + this->_data.transpose_inplace(permutation); + + return shared_from_this(); +} double CKKSTensor::scale() const { return _init_scale; } } // namespace tenseal diff --git a/tenseal/cpp/tensors/ckkstensor.h b/tenseal/cpp/tensors/ckkstensor.h index 80b0e2e7..6809b488 100644 --- a/tenseal/cpp/tensors/ckkstensor.h +++ b/tenseal/cpp/tensors/ckkstensor.h @@ -114,6 +114,8 @@ class CKKSTensor : public EncryptedTensor>, shared_ptr transpose() const; shared_ptr transpose_inplace(); + shared_ptr transpose(const vector& permutation) const; + shared_ptr transpose_inplace(const vector& permutation); vector shape_with_batch() const; double scale() const override; diff --git a/tenseal/cpp/tensors/plain_tensor.h b/tenseal/cpp/tensors/plain_tensor.h index 9d798e20..c49199de 100644 --- a/tenseal/cpp/tensors/plain_tensor.h +++ b/tenseal/cpp/tensors/plain_tensor.h @@ -85,6 +85,13 @@ class PlainTensor { this->_data.transpose_inplace(); return *this; } + PlainTensor transpose(const vector& permutation) const { + return this->copy().transpose_inplace(permutation); + } + PlainTensor& transpose_inplace(const vector& permutation) { + this->_data.transpose_inplace(permutation); + return *this; + } /** * Returns the element at position {idx1, idx2, ..., idxn} in the current * shape diff --git a/tenseal/cpp/tensors/tensor_storage.h b/tenseal/cpp/tensors/tensor_storage.h index 84fc2169..955e0197 100644 --- a/tenseal/cpp/tensors/tensor_storage.h +++ b/tenseal/cpp/tensors/tensor_storage.h @@ -162,6 +162,16 @@ class TensorStorage { this->_data = xt::transpose(this->_data); return *this; } + + TensorStorage transpose(const vector& axes) const { + return this->copy().transpose_inplace(axes); + } + + TensorStorage transpose_inplace(const vector& axes) { + this->_data = xt::transpose(this->_data, axes); + return *this; + } + /** * Returns the element at position {idx1, idx2, ..., idxn} in the current * shape diff --git a/tenseal/tensors/ckkstensor.py b/tenseal/tensors/ckkstensor.py index 4a918c47..6292a5c2 100644 --- a/tenseal/tensors/ckkstensor.py +++ b/tenseal/tensors/ckkstensor.py @@ -156,12 +156,23 @@ def broadcast_(self, shape: List[int]): self.data.broadcast_(shape) return self - def transpose(self): + def transpose(self, axes: List[int] = None) -> "CKKSTensor": "Copies the transpose to a new tensor" - result = self.data.transpose() + result = None + if axes is None: + result = self.data.transpose() + elif isinstance(axes, list) and all(isinstance(x, int) for x in axes): + result = self.data.transpose(axes) + else: + raise TypeError("axes must be a list of integers") return self._wrap(result) - def transpose_(self): + def transpose_(self, axes: List[int] = None) -> "CKKSTensor": "Tries to transpose the tensor" - self.data.transpose_() + if axes is None: + self.data.transpose_() + elif isinstance(axes, list) and all(isinstance(x, int) for x in axes): + self.data.transpose_(axes) + else: + raise TypeError("axes must be a list of integers") return self diff --git a/tenseal/tensors/plaintensor.py b/tenseal/tensors/plaintensor.py index fbf24622..239fb765 100644 --- a/tenseal/tensors/plaintensor.py +++ b/tenseal/tensors/plaintensor.py @@ -126,14 +126,23 @@ def broadcast_(self, shape: List[int]): self.data.broadcast_(shape) return self - def transpose(self): + def transpose(self, axes: List[int] = None): "Copies the transpose to a new tensor" - new_tensor = PlainTensor(tensor=self.data.data(), shape=self.shape, dtype=self._dtype) + new_tensor = None + if axes is None: + new_tensor = PlainTensor(tensor=self.data.data(), shape=self.shape, dtype=self._dtype) + elif isinstance(axes, list) and all(isinstance(x, int) for x in axes): + new_tensor = PlainTensor(tensor=self.data.data(), shape=self.shape, dtype=self._dtype) + else: + raise TypeError("axes must be a list of integers") return new_tensor.transpose_() - def transpose_(self): + def transpose_(self, axes: List[int] = None): "Tries to transpose the tensor" - self.data.transpose_() + if axes is None: + self.data.transpose_() + elif isinstance(axes, list) and all(isinstance(x, int) for x in axes): + self.data.transpose_(axes) return self @classmethod From bb43db65161651599e908c439c9ed5a218b35300 Mon Sep 17 00:00:00 2001 From: Shiqi Gao Date: Fri, 28 Apr 2023 04:09:26 +0000 Subject: [PATCH 2/2] add transpose test --- tenseal/tensors/plaintensor.py | 11 ++++---- tests/cpp/tensors/ckkstensor_test.cpp | 28 +++++++++++++++++++ .../tenseal/tensors/test_ckks_tensor.py | 24 ++++++++++++++++ .../tenseal/tensors/test_plain_tensor.py | 22 +++++++++++++++ 4 files changed, 80 insertions(+), 5 deletions(-) diff --git a/tenseal/tensors/plaintensor.py b/tenseal/tensors/plaintensor.py index 239fb765..083de9a3 100644 --- a/tenseal/tensors/plaintensor.py +++ b/tenseal/tensors/plaintensor.py @@ -128,14 +128,13 @@ def broadcast_(self, shape: List[int]): def transpose(self, axes: List[int] = None): "Copies the transpose to a new tensor" - new_tensor = None + new_tensor = PlainTensor(tensor=self.data.data(), shape=self.shape, dtype=self._dtype) if axes is None: - new_tensor = PlainTensor(tensor=self.data.data(), shape=self.shape, dtype=self._dtype) + return new_tensor.transpose_() elif isinstance(axes, list) and all(isinstance(x, int) for x in axes): - new_tensor = PlainTensor(tensor=self.data.data(), shape=self.shape, dtype=self._dtype) + return new_tensor.transpose_(axes) else: - raise TypeError("axes must be a list of integers") - return new_tensor.transpose_() + raise TypeError("transpose axes must be a list of integers") def transpose_(self, axes: List[int] = None): "Tries to transpose the tensor" @@ -143,6 +142,8 @@ def transpose_(self, axes: List[int] = None): self.data.transpose_() elif isinstance(axes, list) and all(isinstance(x, int) for x in axes): self.data.transpose_(axes) + else: + raise TypeError("transpose axes must be a list of integers") return self @classmethod diff --git a/tests/cpp/tensors/ckkstensor_test.cpp b/tests/cpp/tensors/ckkstensor_test.cpp index 0dd7efbb..340e184f 100644 --- a/tests/cpp/tensors/ckkstensor_test.cpp +++ b/tests/cpp/tensors/ckkstensor_test.cpp @@ -300,6 +300,34 @@ TEST_P(CKKSTensorTest, TestTranspose) { ASSERT_TRUE(are_close(decr.data(), {1, 3, 5, 2, 4, 6})); } +TEST_P(CKKSTensorTest, TestTransposeWithAxes) { + auto enc_type = get<1>(GetParam()); + + auto ctx = TenSEALContext::Create(scheme_type::ckks, 8192, -1, + {60, 40, 40, 60}, enc_type); + ASSERT_TRUE(ctx != nullptr); + ctx->generate_galois_keys(); + + auto ldata = + PlainTensor(vector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), + vector({2, 3, 2})); + + auto l = CKKSTensor::Create(ctx, ldata, std::pow(2, 40)); + + // Transpose with specified axes + auto res = l->transpose({0, 2, 1}); + ASSERT_THAT(res->shape(), ElementsAreArray({2, 2, 3})); + ASSERT_THAT(l->shape(), ElementsAreArray({2, 3, 2})); + auto decr = res->decrypt(); + ASSERT_TRUE(are_close(decr.data(), {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12})); + + // Transpose inplace with specified axes + l->transpose_inplace({0, 2, 1}); + ASSERT_THAT(l->shape(), ElementsAreArray({2, 2, 3})); + decr = l->decrypt(); + ASSERT_TRUE(are_close(decr.data(), {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12})); +} + TEST_P(CKKSTensorTest, TestSubscript) { auto enc_type = get<1>(GetParam()); diff --git a/tests/python/tenseal/tensors/test_ckks_tensor.py b/tests/python/tenseal/tensors/test_ckks_tensor.py index 74984098..68eeeba6 100644 --- a/tests/python/tenseal/tensors/test_ckks_tensor.py +++ b/tests/python/tenseal/tensors/test_ckks_tensor.py @@ -826,3 +826,27 @@ def test_transpose(context, data, shape): assert tensor.shape == list(expected.shape) result = np.array(tensor.decrypt().tolist()) assert np.allclose(result, expected, rtol=0, atol=0.01) + +@pytest.mark.parametrize( + "data, shape, axes", + [ + ([i for i in range(6)], [1, 2, 3], [0, 2, 1]), + ([i for i in range(12)], [2, 2, 3], [0, 2, 1]), + ([i for i in range(2 * 3 * 4 * 5)], [2, 3, 4, 5], [0, 3, 2, 1]), + ], +) +def test_transpose_with_axes(context, data, shape, axes): + tensor = ts.ckks_tensor(context, ts.plain_tensor(data, shape)) + + expected = np.transpose(np.array(data).reshape(shape), axes) + + newt = tensor.transpose(axes) + assert tensor.shape == shape + assert newt.shape == list(expected.shape) + result = np.array(newt.decrypt().tolist()) + assert np.allclose(result, expected, rtol=0, atol=0.01) + + tensor.transpose_(axes) + assert tensor.shape == list(expected.shape) + result = np.array(tensor.decrypt().tolist()) + assert np.allclose(result, expected, rtol=0, atol=0.01) \ No newline at end of file diff --git a/tests/python/tenseal/tensors/test_plain_tensor.py b/tests/python/tenseal/tensors/test_plain_tensor.py index af7a036e..863bd38e 100644 --- a/tests/python/tenseal/tensors/test_plain_tensor.py +++ b/tests/python/tenseal/tensors/test_plain_tensor.py @@ -128,3 +128,25 @@ def test_transpose(data, shape): tensor.transpose_() assert tensor.shape == list(expected.shape) assert np.array(tensor.tolist()).any() == expected.any() + +@pytest.mark.parametrize( + "data, shape, axes", + [ + ([i for i in range(6)], [1, 2, 3], [0, 2, 1]), + ([i for i in range(12)], [2, 2, 3], [0, 2, 1]), + ([i for i in range(2 * 3 * 4 * 5)], [2, 3, 4, 5], [0, 3, 2, 1]), + ], +) +def test_transpose(data, shape, axes): + tensor = ts.plain_tensor(data, shape) + + expected = np.transpose(np.array(data).reshape(shape), axes) + + newt = tensor.transpose(axes) + assert tensor.shape == shape + assert newt.shape == list(expected.shape) + assert np.array(newt.tolist()).any() == expected.any() + + tensor.transpose_(axes) + assert tensor.shape == list(expected.shape) + assert np.array(tensor.tolist()).any() == expected.any() \ No newline at end of file