Skip to content

Commit

Permalink
CKKSTensor: complete dot operation (#220)
Browse files Browse the repository at this point in the history
* add 1D-2D dot

* fix

* reshape before broadcast

* fix: return

* ignore protobuf changes

* 2D-1D dot

* separate dot ops again

* lint

* fix: pass the copy

* typos and updates

* lint

* need to copy plain first

Co-authored-by: Bogdan Cebere <[email protected]>
  • Loading branch information
youben11 and bcebere authored Jan 25, 2021
1 parent dc3eb33 commit 7e0895c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 22 deletions.
1 change: 1 addition & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
[submodule "third_party/protobuf"]
path = third_party/protobuf
url = https://github.com/protocolbuffers/protobuf
ignore = dirty
[submodule "third_party/json"]
path = third_party/json
url = https://github.com/nlohmann/json
80 changes: 62 additions & 18 deletions tenseal/cpp/tensors/ckkstensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,29 +504,40 @@ shared_ptr<CKKSTensor> CKKSTensor::polyval_inplace(
return shared_from_this();
}

template <typename T>
shared_ptr<CKKSTensor> CKKSTensor::_dot_inplace(
T other, const vector<size_t>& other_shape) {
shared_ptr<CKKSTensor> CKKSTensor::dot_inplace(
const shared_ptr<CKKSTensor>& other) {
auto this_shape = this->shape();
auto other_shape = other->shape();
if (this_shape.size() == 1) {
if (other_shape.size() == 1) { // 1D-1D
// inner product
this->_mul_inplace(other);
this->sum_inplace();
return shared_from_this();
} else if (other_shape.size() == 2) { // 1D-2D
// TODO: better implement broadcasting for mul first then would be
// implemented similar to 1D-1D
throw invalid_argument("1D-2D dot isn't implemented yet");
if (this_shape[0] != other_shape[0])
throw invalid_argument("can't perform dot: dimension mismatch");
this->reshape_inplace(vector<size_t>({this_shape[0], 1}));
// TODO: remove broadcast when implemented in _mul
this->_data.broadcast_inplace(other_shape);
this->_mul_inplace(other);
this->sum_inplace();
return shared_from_this();
} else {
throw invalid_argument(
"don't support dot operations of more than 2 dimensions");
}
} else if (this_shape.size() == 2) {
if (other_shape.size() == 1) { // 2D-1D
// TODO: better implement broadcasting for mul first then would be
// implemented similar to 1D-1D
throw invalid_argument("2D-1D dot isn't implemented yet");
if (this_shape[1] != other_shape[0])
throw invalid_argument("can't perform dot: dimension mismatch");
auto other_copy =
other->reshape(vector<size_t>({1, other_shape[0]}));
// TODO: remove broadcast when implemented in _mul
other_copy->_data.broadcast_inplace(this_shape);
this->_mul_inplace(other_copy);
this->sum_inplace(1);
return shared_from_this();
} else if (other_shape.size() == 2) { // 2D-2D
this->_matmul_inplace(other);
return shared_from_this();
Expand All @@ -540,18 +551,51 @@ shared_ptr<CKKSTensor> CKKSTensor::_dot_inplace(
}
}

shared_ptr<CKKSTensor> CKKSTensor::dot_inplace(
const shared_ptr<CKKSTensor>& other) {
auto other_shape = other->shape();

return this->_dot_inplace(other, other_shape);
}

/// @brief In-place dot product between this encrypted tensor and a plain tensor.
///
/// Dispatches on the ranks of the two operands (NumPy-style `dot` semantics):
///   - 1D-1D: element-wise multiply then full sum (inner product).
///   - 1D-2D: reshape self to a column, broadcast over `other`'s shape,
///            multiply and sum (row-vector x matrix).
///   - 2D-1D: broadcast a row-shaped copy of `other` over self, multiply
///            and sum along axis 1 (matrix x column-vector).
///   - 2D-2D: delegates to the encrypted-plain matrix multiplication.
///
/// @param other plain operand; at most 2-dimensional.
/// @return shared pointer to this tensor (modified in place).
/// @throws invalid_argument on a dimension mismatch or if either operand
///         has more than 2 dimensions.
shared_ptr<CKKSTensor> CKKSTensor::dot_plain_inplace(
    const PlainTensor<double>& other) {
    auto this_shape = this->shape();
    auto other_shape = other.shape();

    if (this_shape.size() == 1) {
        if (other_shape.size() == 1) {  // 1D-1D
            // inner product
            this->_mul_inplace(other);
            this->sum_inplace();
            return shared_from_this();
        } else if (other_shape.size() == 2) {  // 1D-2D
            if (this_shape[0] != other_shape[0])
                throw invalid_argument("can't perform dot: dimension mismatch");
            // View self as a column so it can be broadcast against `other`.
            this->reshape_inplace(vector<size_t>({this_shape[0], 1}));
            // TODO: remove broadcast when implemented in _mul
            this->_data.broadcast_inplace(other_shape);
            this->_mul_inplace(other);
            this->sum_inplace();
            return shared_from_this();
        } else {
            throw invalid_argument(
                "don't support dot operations of more than 2 dimensions");
        }
    } else if (this_shape.size() == 2) {
        if (other_shape.size() == 1) {  // 2D-1D
            if (this_shape[1] != other_shape[0])
                throw invalid_argument("can't perform dot: dimension mismatch");
            // Work on a copy: `other` is const and must not be mutated.
            auto other_copy = other;
            other_copy.reshape_inplace(vector<size_t>({1, other_shape[0]}));
            // TODO: remove broadcast when implemented in _mul
            other_copy.broadcast_inplace(this_shape);
            this->_mul_inplace(other_copy);
            this->sum_inplace(1);
            return shared_from_this();
        } else if (other_shape.size() == 2) {  // 2D-2D
            this->_matmul_inplace(other);
            return shared_from_this();
        } else {
            throw invalid_argument(
                "don't support dot operations of more than 2 dimensions");
        }
    } else {
        throw invalid_argument(
            "don't support dot operations of more than 2 dimensions");
    }
}

shared_ptr<CKKSTensor> CKKSTensor::matmul_inplace(
Expand Down
4 changes: 0 additions & 4 deletions tenseal/cpp/tensors/ckkstensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,6 @@ class CKKSTensor : public EncryptedTensor<double, shared_ptr<CKKSTensor>>,
return this->matmul_plain_inplace(other);
}

template <typename T>
shared_ptr<CKKSTensor> _dot_inplace(T other,
const vector<size_t>& other_shape);

void load_proto(const CKKSTensorProto& buffer);
CKKSTensorProto save_proto() const;
void clear();
Expand Down
16 changes: 16 additions & 0 deletions tests/python/tenseal/tensors/test_ckks_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,22 @@ def test_polynomial_rescale_off(context, data, polynom):
((1,), (1,)),
((3,), (3,)),
((8,), (8,)),
# 1D-2D
((1,), (1, 2)),
((2,), (2, 2)),
((2,), (2, 4)),
((2,), (2, 5)),
((2,), (2, 3)),
((5,), (5, 1)),
((7,), (7, 5)),
# 2D-1D
((2, 1), (1,)),
((2, 2), (2,)),
((4, 2), (2,)),
((3, 2), (2,)),
((5, 2), (2,)),
((3, 5), (5,)),
((3, 7), (7,)),
# 2D-2D
((2, 1), (1, 2)),
((2, 2), (2, 2)),
Expand Down

0 comments on commit 7e0895c

Please sign in to comment.