From 4b43ad7a2ea3fdd68e8780bfbcb61022fbaa99f2 Mon Sep 17 00:00:00 2001 From: yjkao Date: Wed, 11 Dec 2024 15:48:24 +0800 Subject: [PATCH 01/11] Refactor SearchTree and PseudoUniTensor classes for improved memory management and functionality - Renamed PsudoUniTensor to PseudoUniTensor for consistency. - Replaced raw pointers with smart pointers (std::unique_ptr) in PseudoUniTensor to enhance memory safety. - Updated constructors and assignment operators to support move semantics. - Introduced helper functions for optimal tree solving and connected components detection. - Enhanced SearchTree class to utilize the new PseudoUniTensor structure and improved error handling. - Added unit tests for SearchTree functionality, covering various scenarios including basic contraction orders and error cases. This refactor aims to streamline tensor contraction processes and improve overall code maintainability. --- include/search_tree.hpp | 128 +++++++-------- src/RegularNetwork.cpp | 38 ++--- src/search_tree.cpp | 321 +++++++++++++++++++++++++++++-------- tests/CMakeLists.txt | 5 + tests/search_tree_test.cpp | 156 ++++++++++++++++++ 5 files changed, 490 insertions(+), 158 deletions(-) create mode 100644 tests/search_tree_test.cpp diff --git a/include/search_tree.hpp b/include/search_tree.hpp index a042706b..fa8c7767 100644 --- a/include/search_tree.hpp +++ b/include/search_tree.hpp @@ -7,98 +7,82 @@ #include #include #include +#include +#include +#include #include "UniTensor.hpp" -#ifdef BACKEND_TORCH -#else - namespace cytnx { - /// @cond - class PsudoUniTensor { + + using IndexSet = std::bitset<128>; + + class PseudoUniTensor { public: - // UniTensor utensor; //don't worry about copy, because everything are references in cytnx! + bool isLeaf; + + // Leaf node data std::vector labels; std::vector shape; bool is_assigned; - PsudoUniTensor *left; - PsudoUniTensor *right; - PsudoUniTensor *root; + cytnx_uint64 tensorIndex; + + // Internal node data + std::unique_ptr left; + std::unique_ptr right; + cytnx_float cost; cytnx_uint64 ID; - std::string accu_str; - PsudoUniTensor() - : is_assigned(false), left(nullptr), right(nullptr), root(nullptr), cost(0), ID(0){}; - PsudoUniTensor(const PsudoUniTensor &rhs) { - this->left = rhs.left; - this->right = rhs.right; - this->root = rhs.root; - this->labels = rhs.labels; - this->shape = rhs.shape; - this->is_assigned = rhs.is_assigned; - this->cost = rhs.cost; - this->accu_str = rhs.accu_str; - this->ID = rhs.ID; - } - PsudoUniTensor &operator==(const PsudoUniTensor &rhs) { - this->left = rhs.left; - this->right = rhs.right; - this->root = rhs.root; - this->labels = rhs.labels; - this->shape = rhs.shape; - this->is_assigned = rhs.is_assigned; - this->cost = rhs.cost; - this->accu_str = rhs.accu_str; - this->ID = rhs.ID; - return *this; - } - void from_utensor(const UniTensor &in_uten) { - this->labels = in_uten.labels(); - this->shape = in_uten.shape(); - this->is_assigned = true; - } - void clear_utensor() { - this->is_assigned = false; - this->labels.clear(); - this->shape.clear(); - this->ID = 0; - this->cost = 0; - this->accu_str = ""; - } - void set_ID(const cytnx_int64 &ID) { this->ID = ID; } + // Constructors + explicit PseudoUniTensor(cytnx_uint64 index = 0) + : isLeaf(true), tensorIndex(index), is_assigned(false), cost(0), ID(1ULL << index), + accu_str(std::to_string(index)) {} + + PseudoUniTensor(std::unique_ptr l, std::unique_ptr r) + : isLeaf(false), left(std::move(l)), right(std::move(r)), cost(0), ID(0) {} + + // Copy and move constructors and assignment operators + PseudoUniTensor(const PseudoUniTensor& rhs); + PseudoUniTensor(PseudoUniTensor&& rhs) noexcept; + PseudoUniTensor& operator=(const PseudoUniTensor& rhs); + PseudoUniTensor& operator=(PseudoUniTensor&& rhs) noexcept; + ~PseudoUniTensor() = default; + + void from_utensor(const UniTensor& in_uten); + void clear_utensor(); }; + namespace OptimalTreeSolver { + std::unique_ptr solve(const std::vector& tensors, + bool verbose = false); + } + class SearchTree { public: - std::vector> nodes_container; - // std::vector nodes_container; // this contains intermediate layer. - std::vector base_nodes; // this is the button layer. - - SearchTree(){}; - SearchTree(const SearchTree &rhs) { - this->nodes_container = rhs.nodes_container; - this->base_nodes = rhs.base_nodes; - } - SearchTree &operator==(const SearchTree &rhs) { - this->nodes_container = rhs.nodes_container; - this->base_nodes = rhs.base_nodes; - return *this; - } + + std::vector base_nodes; - // clear all the elements in the whole tree. + SearchTree() = default; void clear() { - nodes_container.clear(); - base_nodes.clear(); - // nodes_container.reserve(1024); + root_ptr.reset(); + base_nodes.clear(); } - // clear all the intermediate layer, leave all the base_nodes intact. - // and reset the root pointer on the base ondes - void reset_search_order() { nodes_container.clear(); } + void reset_search_order() { root_ptr.reset(); } void search_order(); + + std::vector> get_root() const { + return {{root_ptr.get()}}; + } + + private: + std::unique_ptr root_ptr; }; - /// @endcond + + // Helper functions declarations + cytnx_float get_cost(const PseudoUniTensor& t1, const PseudoUniTensor& t2); + PseudoUniTensor pContract(PseudoUniTensor& t1, PseudoUniTensor& t2); + } // namespace cytnx -#endif #endif // CYTNX_SEARCH_TREE_H_ diff --git a/src/RegularNetwork.cpp b/src/RegularNetwork.cpp index e8c8e1fb..bddce1fe 100644 --- a/src/RegularNetwork.cpp +++ b/src/RegularNetwork.cpp @@ -230,24 +230,24 @@ namespace cytnx { vector> CtTree_to_eisumpath(ContractionTree CtTree, vector tns) { vector> path; - stack stk; - Node *root = &(CtTree.nodes_container.back()); + stack> stk; + shared_ptr root = make_shared(CtTree.nodes_container.back()); int ly = 0; bool ict; do { while ((root != nullptr)) { - if (root->right != nullptr) stk.push(root->right); + if (root->right != nullptr) stk.push(make_shared(*(root->right))); stk.push(root); - root = root->left; + root = make_shared(*(root->left)); } root = stk.top(); stk.pop(); ict = true; if ((root->right != nullptr) && !stk.empty()) { - if (stk.top() == root->right) { + if (stk.top()->name == root->right->name) { stk.pop(); stk.push(root); - root = root->right; + root = make_shared(*(root->right)); ict = false; } } @@ -943,7 +943,7 @@ namespace cytnx { Stree.base_nodes[t].accu_str = this->names[t]; } Stree.search_order(); - return Stree.nodes_container.back()[0].accu_str; + return Stree.get_root().back()[0]->accu_str; } UniTensor RegularNetwork::Launch() { @@ -969,17 +969,17 @@ namespace cytnx { // 2. contract using postorder traversal: // cout << this->CtTree.nodes_container.size() << endl; - stack stk; - Node *root = &(this->CtTree.nodes_container.back()); + stack> stk; + shared_ptr root = make_shared(this->CtTree.nodes_container.back()); int ly = 0; bool ict; do { // move the lmost while ((root != nullptr)) { - if (root->right != nullptr) stk.push(root->right); + if (root->right != nullptr) stk.push(make_shared(*(root->right))); stk.push(root); - root = root->left; + root = make_shared(*(root->left)); } root = stk.top(); @@ -987,10 +987,10 @@ namespace cytnx { // cytnx_error_msg(stk.size()==0,"[eRROR]","\n"); ict = true; if ((root->right != nullptr) && !stk.empty()) { - if (stk.top() == root->right) { + if (stk.top()->name == root->right->name) { stk.pop(); stk.push(root); - root = root->right; + root = make_shared(*(root->right)); ict = false; } } @@ -1100,17 +1100,17 @@ namespace cytnx { } // 2. contract using postorder traversal: // cout << this->CtTree.nodes_container.size() << endl; - stack stk; - Node *root = &(this->CtTree.nodes_container.back()); + stack> stk; + shared_ptr root = make_shared(this->CtTree.nodes_container.back()); int ly = 0; bool ict; do { // move the lmost while ((root != nullptr)) { - if (root->right != nullptr) stk.push(root->right); + if (root->right != nullptr) stk.push(make_shared(*(root->right))); stk.push(root); - root = root->left; + root = make_shared(*(root->left)); } root = stk.top(); @@ -1118,10 +1118,10 @@ namespace cytnx { // cytnx_error_msg(stk.size()==0,"[eRROR]","\n"); ict = true; if ((root->right != nullptr) && !stk.empty()) { - if (stk.top() == root->right) { + if (stk.top()->name == root->right->name) { stk.pop(); stk.push(root); - root = root->right; + root = make_shared(*(root->right)); ict = false; } } diff --git a/src/search_tree.cpp b/src/search_tree.cpp index a3cc9817..385b8baa 100644 --- a/src/search_tree.cpp +++ b/src/search_tree.cpp @@ -7,8 +7,8 @@ using namespace std; #else namespace cytnx { - - cytnx_float get_cost(const PsudoUniTensor &t1, const PsudoUniTensor &t2) { + // helper functions + cytnx_float get_cost(const PseudoUniTensor& t1, const PseudoUniTensor& t2) { cytnx_float cost = 1; vector shape1 = t1.shape; vector shape2 = t2.shape; @@ -30,89 +30,276 @@ namespace cytnx { return cost + t1.cost + t2.cost; } - PsudoUniTensor pContract(PsudoUniTensor &t1, PsudoUniTensor &t2) { - PsudoUniTensor t3; - t3.ID = t1.ID ^ t2.ID; - t3.cost = get_cost(t1, t2); + PseudoUniTensor pContract(PseudoUniTensor& t1, PseudoUniTensor& t2) { + PseudoUniTensor t3(0); // Initialize with index 0 + + t3.ID = t1.ID ^ t2.ID; // XOR of IDs to track contracted tensors + t3.cost = get_cost(t1, t2); // Calculate contraction cost + + // Find common labels between t1 and t2 vector loc1, loc2; vector comm_lbl; vec_intersect_(comm_lbl, t1.labels, t2.labels, loc1, loc2); - t3.shape = vec_concatenate(vec_erase(t1.shape, loc1), vec_erase(t2.shape, loc2)); - t3.labels = vec_concatenate(vec_erase(t1.labels, loc1), vec_erase(t2.labels, loc2)); + + // New shape is concatenation of non-contracted dimensions + t3.shape = vec_concatenate(vec_erase(t1.shape, loc1), + vec_erase(t2.shape, loc2)); + + // New labels are concatenation of non-contracted labels + t3.labels = vec_concatenate(vec_erase(t1.labels, loc1), + vec_erase(t2.labels, loc2)); + + // Set accumulation string using the original accu_str if available + if (t1.accu_str.empty()) t1.accu_str = std::to_string(t1.tensorIndex); + if (t2.accu_str.empty()) t2.accu_str = std::to_string(t2.tensorIndex); t3.accu_str = "(" + t1.accu_str + "," + t2.accu_str + ")"; + + // Set as internal node + t3.isLeaf = false; + t3.left = std::make_unique(t1); + t3.right = std::make_unique(t2); + return t3; } + namespace OptimalTreeSolver { + // Helper function to find connected components using DFS + void dfs(size_t node, const std::vector& adjacencyMatrix, + IndexSet& visited, std::vector& component) { + visited.set(node); + component.push_back(node); + + for (size_t i = 0; i < adjacencyMatrix.size(); ++i) { + if (adjacencyMatrix[node].test(i) && !visited.test(i)) { + dfs(i, adjacencyMatrix, visited, component); + } + } + } + + // Find connected components in the tensor network + std::vector> findConnectedComponents( + const std::vector& adjacencyMatrix) { + std::vector> components; + IndexSet visited; + + for (size_t i = 0; i < adjacencyMatrix.size(); ++i) { + if (!visited.test(i)) { + std::vector component; + dfs(i, adjacencyMatrix, visited, component); + components.push_back(component); + } + } + + return components; + } + + std::unique_ptr solve(const std::vector& tensors, + bool verbose) { + if (tensors.empty()) { + return nullptr; + } + + // Initialize nodes with copies of input tensors + std::vector> nodes; + nodes.reserve(tensors.size()); + for (size_t i = 0; i < tensors.size(); ++i) { + auto node = std::make_unique(i); + *node = tensors[i]; + node->ID = 1ULL << i; + nodes.push_back(std::move(node)); + } + + const size_t n = nodes.size(); + // Build adjacency matrix with proper size + std::vector adjacencyMatrix(n); + + // Fill adjacency matrix + for (size_t i = 0; i < n; ++i) { + for (size_t j = i + 1; j < n; ++j) { + // Find common labels + vector common_lbl; + vector comm_idx1, comm_idx2; + vec_intersect_(common_lbl, nodes[i]->labels, nodes[j]->labels, comm_idx1, comm_idx2); + + if (!common_lbl.empty()) { + adjacencyMatrix[i].set(j); + adjacencyMatrix[j].set(i); + } + } + } + + // Find connected components + auto components = findConnectedComponents(adjacencyMatrix); + if (verbose && components.size() > 1) { + std::cout << "Found " << components.size() << " disconnected components" << std::endl; + } + + // Process each component separately + std::vector> component_results; + for (const auto& component : components) { + // Extract nodes for this component + std::vector> component_nodes; + std::vector remaining_indices = component; + + while (remaining_indices.size() > 1) { + // Find best contraction pair within component + size_t best_i = 0, best_j = 1; + cytnx_float min_cost = std::numeric_limits::max(); + + for (size_t ii = 0; ii < remaining_indices.size(); ++ii) { + size_t i = remaining_indices.at(ii); + for (size_t jj = ii + 1; jj < remaining_indices.size(); ++jj) { + size_t j = remaining_indices.at(jj); + if (adjacencyMatrix[i].test(j)) { + cytnx_float cost = get_cost(*nodes[i], *nodes[j]); + if (cost < min_cost) { + min_cost = cost; + best_i = ii; + best_j = jj; + } + } + } + } + + if (verbose) { + std::cout << "Contracting nodes " << remaining_indices[best_i] << " and " + << remaining_indices[best_j] << " with cost " << min_cost << std::endl; + } + + // Contract best pair + auto left = std::move(nodes[remaining_indices[best_i]]); + auto right = std::move(nodes[remaining_indices[best_j]]); + auto result = pContract(*left, *right); + auto result_ptr = std::make_unique(std::move(result)); + + // Update remaining indices + remaining_indices.erase(remaining_indices.begin() + best_j); + remaining_indices.erase(remaining_indices.begin() + best_i); + + // Store result in original nodes vector + size_t new_idx = nodes.size(); + nodes.push_back(std::move(result_ptr)); + remaining_indices.push_back(new_idx); + } + + // Store the component result + component_results.push_back(std::move(nodes[remaining_indices[0]])); + } + + // If there were multiple components, combine them + while (component_results.size() > 1) { + // Create new node for combining components + auto new_node = std::make_unique(); + new_node->isLeaf = false; + + // Move the first two components as children + new_node->left = std::move(component_results[0]); + new_node->right = std::move(component_results[1]); + + // Calculate cost and set properties + new_node->cost = get_cost(*new_node->left, *new_node->right); + new_node->accu_str = "(" + new_node->left->accu_str + "," + new_node->right->accu_str + ")"; + new_node->ID = new_node->left->ID ^ new_node->right->ID; + + // Update component list + component_results.erase(component_results.begin(), component_results.begin() + 2); + component_results.insert(component_results.begin(), std::move(new_node)); + } + + return std::move(component_results[0]); + } + } // namespace OptimalTreeSolver + void SearchTree::search_order() { this->reset_search_order(); - if (this->base_nodes.size() == 0) { - cytnx_error_msg(true, "[ERROR][SearchTree] no base node exist.%s", "\n"); + if (this->base_nodes.size() == 1 || this->base_nodes.size() == 0 ) { + cytnx_error_msg(true, "[ERROR][SearchTree] need at least 2 nodes.%s", "\n"); } - cytnx_int64 pid = 0; - this->nodes_container.resize(this->base_nodes.size()); - //[Regiving each base nodes it's own ID]: - for (cytnx_uint64 i = 0; i < this->base_nodes.size(); i++) { - this->base_nodes[i].set_ID(pow(2, i)); - this->nodes_container[i].reserve(this->base_nodes.size() * 2); // try - } + // Run optimal tree solver directly with base_nodes + root_ptr = OptimalTreeSolver::solve(base_nodes, true); + } + + PseudoUniTensor& PseudoUniTensor::operator=(const PseudoUniTensor& rhs) { + if (this != &rhs) { + isLeaf = rhs.isLeaf; + labels = rhs.labels; + shape = rhs.shape; + is_assigned = rhs.is_assigned; + tensorIndex = rhs.tensorIndex; + cost = rhs.cost; + ID = rhs.ID; + accu_str = rhs.accu_str; + + if (!isLeaf) { + if (rhs.left) + left = std::make_unique(*rhs.left); + else + left = nullptr; - // init first layer - for (cytnx_uint64 t = 0; t < this->base_nodes.size(); t++) { - this->nodes_container[0].push_back(this->base_nodes[t]); + if (rhs.right) + right = std::make_unique(*rhs.right); + else + right = nullptr; + } } + return *this; + } - bool secondtimescan = 0; - while (this->nodes_container.back().size() == - 0) { // I can't see the need of this while loop before using secondtimescan - for (int c = 1; c < this->base_nodes.size(); c++) { - for (int d1 = 0; d1 < (c + 1) / 2; d1++) { - int d2 = c - d1 - 1; - int n1 = this->nodes_container[d1].size(); - int n2 = this->nodes_container[d2].size(); - for (int i1 = 0; i1 < n1; i1++) { - int i2_start = (d1 == d2) ? i1 + 1 : 0; - for (int i2 = i2_start; i2 < n2; i2++) { - PsudoUniTensor &t1 = this->nodes_container[d1][i1]; - PsudoUniTensor &t2 = this->nodes_container[d2][i2]; - - // No common labels - // If it's the secondtimescan, that's probably because there're need of Kron - // operations. - if (!secondtimescan and cytnx::vec_intersect(t1.labels, t2.labels).size() == 0) - continue; - // overlap - if ((t1.ID & t2.ID) > 0) continue; - - PsudoUniTensor t3 = pContract(t1, t2); - bool exist = false; - for (int i = 0; i < nodes_container[c].size(); i++) { - if (t3.ID == nodes_container[c][i].ID) { - exist = true; - if (t3.cost < nodes_container[c][i].cost) { - nodes_container[c][i] = t3; - t1.root = &nodes_container[c][i]; - t2.root = &nodes_container[c][i]; - } - break; - } - } // i + PseudoUniTensor::PseudoUniTensor(const PseudoUniTensor& rhs) + : isLeaf(rhs.isLeaf), + labels(rhs.labels), + shape(rhs.shape), + is_assigned(rhs.is_assigned), + tensorIndex(rhs.tensorIndex), + cost(rhs.cost), + ID(rhs.ID), + accu_str(rhs.accu_str) { + if (!isLeaf) { + if (rhs.left) left = std::make_unique(*rhs.left); + if (rhs.right) right = std::make_unique(*rhs.right); + } + } - if (!exist) { - nodes_container[c].push_back(t3); - t1.root = &nodes_container[c].back(); - t2.root = &nodes_container[c].back(); - } + cytnx::PseudoUniTensor::PseudoUniTensor(PseudoUniTensor&& rhs) noexcept + : isLeaf(rhs.isLeaf), + labels(std::move(rhs.labels)), + shape(std::move(rhs.shape)), + is_assigned(rhs.is_assigned), + tensorIndex(rhs.tensorIndex), + left(std::move(rhs.left)), + right(std::move(rhs.right)), + cost(rhs.cost), + ID(rhs.ID), + accu_str(std::move(rhs.accu_str)) {} - } // for i2 - } // for i1 - } // for d1 - } // for c - secondtimescan = 1; - } // while + PseudoUniTensor& PseudoUniTensor::operator=(PseudoUniTensor&& rhs) noexcept { + if (this != &rhs) { + isLeaf = rhs.isLeaf; + labels = std::move(rhs.labels); + shape = std::move(rhs.shape); + is_assigned = rhs.is_assigned; + tensorIndex = rhs.tensorIndex; + left = std::move(rhs.left); + right = std::move(rhs.right); + cost = rhs.cost; + ID = rhs.ID; + accu_str = std::move(rhs.accu_str); + } + return *this; + } - // cout << nodes_container.back()[0].accu_str << endl; + void PseudoUniTensor::from_utensor(const UniTensor& in_uten) { + isLeaf = true; + labels = in_uten.labels(); + shape = in_uten.shape(); + is_assigned = true; + // Other members keep their default/current values + left = nullptr; + right = nullptr; } + + } // namespace cytnx #endif diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 86b6f31c..643508d9 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -23,6 +23,7 @@ add_executable( Accessor_test.cpp Tensor_test.cpp Storage_test.cpp + search_tree_test.cpp utils_test/vec_concatenate.cpp utils_test/vec_unique.cpp utils/getNconParameter.cpp @@ -55,6 +56,10 @@ target_link_libraries( gtest ) target_link_libraries(test_main cytnx) + +add_compile_options(-fsanitize=address) +add_link_options(-fsanitize=address) + #target_link_libraries(test_main PUBLIC "-lgcov --coverage") include(GoogleTest) gtest_discover_tests(test_main diff --git a/tests/search_tree_test.cpp b/tests/search_tree_test.cpp new file mode 100644 index 00000000..8ab0c772 --- /dev/null +++ b/tests/search_tree_test.cpp @@ -0,0 +1,156 @@ +#include "cytnx.hpp" +#include "search_tree.hpp" +#include + +using namespace cytnx; +using namespace std; + +class SearchTreeTest : public ::testing::Test { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(SearchTreeTest, BasicSearchOrder) { + // Create a simple network of 3 tensors + vector tensors; + + // Create tensor 1 with shape [2,3] and labels ["i","j"] + PseudoUniTensor t1(0); + t1.shape = {2, 3}; + t1.labels = {"i", "j"}; + t1.cost = 0; + tensors.push_back(t1); + + // Create tensor 2 with shape [3,4] and labels ["j","k"] + PseudoUniTensor t2(1); + t2.shape = {3, 4}; + t2.labels = {"j", "k"}; + t2.cost = 0; + tensors.push_back(t2); + + // Create tensor 3 with shape [4,2] and labels ["k","i"] + PseudoUniTensor t3(2); + t3.shape = {4, 2}; + t3.labels = {"k", "i"}; + t3.cost = 0; + tensors.push_back(t3); + + // Create search tree + SearchTree tree; + tree.base_nodes = tensors; + + // Find optimal contraction order + tree.search_order(); + + // Get result node + auto result = tree.get_root().back()[0]; + + // Verify final cost is optimal + EXPECT_EQ(result->cost, 32); // 2*3*4 + 2*2*4 = 24 + 16 = 40 flops, cost = 32 + + // Verify contraction string format + EXPECT_EQ(result->accu_str, "(2,(0,1))"); +} + +TEST_F(SearchTreeTest, BasicSearchOrder2) { + // Create a network of 4 tensors to test more complex contraction ordering + vector tensors; + + // Create tensor 1 with shape [2,10] and labels ["i","j"] + // This will connect with tensor 2 through index j + PseudoUniTensor t1(0); + t1.shape = {2, 10}; + t1.labels = {"i", "j"}; + t1.cost = 0; + tensors.push_back(t1); + + // Create tensor 2 with shape [10,4,8] and labels ["j","k","m"] + // This connects with t1 through j, t3 through k, and t4 through m + PseudoUniTensor t2(1); + t2.shape = {10, 4, 8}; + t2.labels = {"j", "k", "m"}; + t2.cost = 0; + tensors.push_back(t2); + + // Create tensor 3 with shape [4,2] and labels ["k","l"] + // This connects with t2 through k and t4 through l + PseudoUniTensor t3(2); + t3.shape = {4, 2}; + t3.labels = {"k", "l"}; + t3.cost = 0; + tensors.push_back(t3); + + // Create tensor 4 with shape [2,8,7] and labels ["l","m","n"] + // This connects with t3 through l and t2 through m + // The n index remains uncontracted as an external index + PseudoUniTensor t4(3); + t4.shape = {2, 8, 7}; + t4.labels = {"l", "m", "n"}; + t4.cost = 0; + tensors.push_back(t4); + + // Create search tree and set base nodes + SearchTree tree; + tree.base_nodes = tensors; + + // Find optimal contraction order using search algorithm + tree.search_order(); + + // Get the final contracted result node + auto result = tree.get_root().back()[0]; + + // Verify the final contraction cost is optimal + // The optimal sequence contracts (t1,t2) first, then t3, then t4 + EXPECT_EQ(result->cost, 1536); + + // Verify the contraction sequence string matches expected optimal order + // Format is (tensor_id,(tensor_id,tensor_id)) showing order of pairwise contractions + cout << result->accu_str << endl; + EXPECT_EQ(result->accu_str, "((2,3),(0,1))"); +} + +TEST_F(SearchTreeTest, EmptyTree) { + SearchTree tree; + + // Should throw error when searching empty tree + EXPECT_THROW(tree.search_order(), std::logic_error); +} + +TEST_F(SearchTreeTest, SingleNode) { + SearchTree tree; + + // Add single node + PseudoUniTensor t1(0); + t1.shape = {2, 2}; + t1.labels = {"i", "i"}; + t1.cost = 0; + tree.base_nodes.push_back(t1); + + // Should throw error - need at least 2 nodes + EXPECT_THROW(tree.search_order(), std::logic_error); +} + +TEST_F(SearchTreeTest, DisconnectedNetwork) { + SearchTree tree; + + // Create two tensors with no common indices + PseudoUniTensor t1(0); + t1.shape = {2, 3}; + t1.labels = {"i", "j"}; + t1.cost = 0; + tree.base_nodes.push_back(t1); + + PseudoUniTensor t2(1); + t2.shape = {4, 5}; + t2.labels = {"k", "l"}; + t2.cost = 0; + tree.base_nodes.push_back(t2); + + // Should still work but with higher cost due to direct product + tree.search_order(); + auto result = tree.get_root().back()[0]; + + EXPECT_EQ(result->cost, 120); // 2*3*4*5 = 120 + EXPECT_EQ(result->accu_str, "(0,1)"); +} From 8c0542d24fd7a3f22858fd5123ccda035e16a77a Mon Sep 17 00:00:00 2001 From: yjkao Date: Wed, 11 Dec 2024 23:05:15 +0800 Subject: [PATCH 02/11] Refactor ContractionTree and Node classes to use smart pointers for improved memory management - Updated Node class to inherit from std::enable_shared_from_this for better shared pointer management. - Replaced raw pointers with std::shared_ptr and std::weak_ptr in Node to prevent memory leaks and circular references. - Modified constructors and assignment operators in Node and ContractionTree to support smart pointer usage. - Adjusted methods in ContractionTree to work with shared pointers, enhancing safety and clarity. - Updated related code in RegularGncon and RegularNetwork to accommodate changes in Node handling. This refactor aims to enhance memory safety and maintainability in tensor contraction processes. --- include/contraction_tree.hpp | 120 +++++++++++++++++++---------------- src/RegularGncon.cpp | 30 ++++++--- src/RegularNetwork.cpp | 105 ++++++++++++++---------------- src/contraction_tree.cpp | 38 ++++++----- 4 files changed, 152 insertions(+), 141 deletions(-) diff --git a/include/contraction_tree.hpp b/include/contraction_tree.hpp index 2c4f8522..b629a7a0 100644 --- a/include/contraction_tree.hpp +++ b/include/contraction_tree.hpp @@ -8,97 +8,107 @@ #include #include #include +#include #ifdef BACKEND_TORCH #else namespace cytnx { /// @cond - class Node { + class Node : public std::enable_shared_from_this { public: UniTensor utensor; // don't worry about copy, because everything are references in cytnx! bool is_assigned; - Node *left; - Node *right; + std::shared_ptr left; + std::shared_ptr right; + std::weak_ptr root; // Use weak_ptr to avoid circular references std::string name; - Node *root; - Node() : is_assigned(false), left(nullptr), right(nullptr), root(nullptr){}; - Node(const Node &rhs) { - this->left = rhs.left; - this->right = rhs.right; - this->root = rhs.root; - this->utensor = rhs.utensor; - this->is_assigned = rhs.is_assigned; + Node() : is_assigned(false) {} + + Node(const Node& rhs) + : utensor(rhs.utensor), + is_assigned(rhs.is_assigned), + left(rhs.left), + right(rhs.right), + name(rhs.name) { + if (auto r = rhs.root.lock()) { + root = r; + } } - Node &operator==(const Node &rhs) { - this->left = rhs.left; - this->right = rhs.right; - this->root = rhs.root; - this->utensor = rhs.utensor; - this->is_assigned = rhs.is_assigned; + + Node& operator=(const Node& rhs) { + if (this != &rhs) { + utensor = rhs.utensor; + is_assigned = rhs.is_assigned; + left = rhs.left; + right = rhs.right; + name = rhs.name; + if (auto r = rhs.root.lock()) { + root = r; + } + } return *this; } - Node(Node *in_left, Node *in_right, const UniTensor &in_uten = UniTensor()) - : is_assigned(false), left(nullptr), right(nullptr), root(nullptr) { - this->left = in_left; - this->right = in_right; - in_left->root = this; - in_right->root = this; - if (in_uten.uten_type() != UTenType.Void) this->utensor = in_uten; + + Node(std::shared_ptr in_left, std::shared_ptr in_right, + const UniTensor& in_uten = UniTensor()) + : is_assigned(false) { + left = in_left; + right = in_right; + if (in_uten.uten_type() != UTenType.Void) { + utensor = in_uten; + } + + // Set root pointers using shared_from_this() + if (left) left->root = shared_from_this(); + if (right) right->root = shared_from_this(); } - void assign_utensor(const UniTensor &in_uten) { - this->utensor = in_uten; - this->is_assigned = true; + + void assign_utensor(const UniTensor& in_uten) { + utensor = in_uten; + is_assigned = true; } + void clear_utensor() { - this->is_assigned = false; - this->utensor = UniTensor(); + is_assigned = false; + utensor = UniTensor(); } }; class ContractionTree { public: - std::vector nodes_container; // this contains intermediate layer. - std::vector base_nodes; // this is the button layer. + std::vector> nodes_container; // intermediate layer + std::vector> base_nodes; // bottom layer - ContractionTree(){}; - ContractionTree(const ContractionTree &rhs) { - this->nodes_container = rhs.nodes_container; - this->base_nodes = rhs.base_nodes; - } - ContractionTree &operator==(const ContractionTree &rhs) { - this->nodes_container = rhs.nodes_container; - this->base_nodes = rhs.base_nodes; - return *this; - } + ContractionTree() = default; + ContractionTree(const ContractionTree&) = default; + ContractionTree& operator=(const ContractionTree&) = default; - // clear all the elements in the whole tree. void clear() { nodes_container.clear(); base_nodes.clear(); - // nodes_container.reserve(1024); } - // clear all the intermediate layer, leave all the base_nodes intact. - // and reset the root pointer on the base ondes + void reset_contraction_order() { nodes_container.clear(); - for (cytnx_uint64 i = 0; i < base_nodes.size(); i++) { - base_nodes[i].root = nullptr; + for (auto& node : base_nodes) { + node->root.reset(); // Dereference shared_ptr with -> } - // nodes_container.reserve(1024); } + void reset_nodes() { - // reset all nodes but keep the skeleton - for (cytnx_uint64 i = 0; i < this->nodes_container.size(); i++) { - this->nodes_container[i].clear_utensor(); + for (auto& node : nodes_container) { + node->clear_utensor(); } - for (cytnx_uint64 i = 0; i < this->base_nodes.size(); i++) { - this->base_nodes[i].clear_utensor(); + for (auto& node : base_nodes) { + node->clear_utensor(); } } + void build_default_contraction_tree(); - void build_contraction_tree_by_tokens(const std::map &name2pos, - const std::vector &tokens); + void build_contraction_tree_by_tokens( + const std::map& name2pos, + const std::vector& tokens); }; /// @endcond } // namespace cytnx diff --git a/src/RegularGncon.cpp b/src/RegularGncon.cpp index 9a95ba8b..bb59109c 100644 --- a/src/RegularGncon.cpp +++ b/src/RegularGncon.cpp @@ -299,6 +299,17 @@ namespace cytnx { // // put tensor: // for (int i = 0; i < utensors.size(); i++) this->tensors[i] = utensors[i]; + + // Update node creation + this->tensors.resize(this->names.size()); + this->CtTree.base_nodes.clear(); + + // Create nodes using make_shared + for(size_t i = 0; i < this->names.size(); i++) { + auto node = std::make_shared(); + node->name = this->names[i]; + this->CtTree.base_nodes.push_back(node); + } } void RegularGncon::FromString(const std::vector &contents) { @@ -597,14 +608,15 @@ namespace cytnx { string RegularGncon::getOptimalOrder() { // Creat a SearchTree to search for optim contraction order. SearchTree Stree; + Stree.base_nodes.clear(); Stree.base_nodes.resize(this->tensors.size()); + for (cytnx_uint64 t = 0; t < this->tensors.size(); t++) { - // Stree.base_nodes[t].from_utensor(this->tensors[t]); //create psudotensors from base tensors - Stree.base_nodes[t].from_utensor(CtTree.base_nodes[t].utensor); + Stree.base_nodes[t].from_utensor(this->tensors[t]); Stree.base_nodes[t].accu_str = this->names[t]; } Stree.search_order(); - return Stree.nodes_container.back()[0].accu_str; + return Stree.get_root().back()[0]->accu_str; } UniTensor RegularGncon::Launch(const bool &optimal, const string &contract_order /*default ""*/) { @@ -636,10 +648,10 @@ namespace cytnx { // modify the label of unitensor (shared): // this->tensors[idx].set_labels(this->label_arr[idx]);//this conflict - this->CtTree.base_nodes[idx].utensor = + this->CtTree.base_nodes[idx]->utensor = this->tensors[idx].relabels(this->label_arr[idx]); // this conflict // this->CtTree.base_nodes[idx].name = this->tensors[idx].name(); - this->CtTree.base_nodes[idx].is_assigned = true; + this->CtTree.base_nodes[idx]->is_assigned = true; // cout << this->tensors[idx].name() << " " << idx << "from dict:" << // this->name2pos[this->tensors[idx].name()] << endl; @@ -667,8 +679,8 @@ namespace cytnx { // 2. contract using postorder traversal: // cout << this->CtTree.nodes_container.size() << endl; - stack stk; - Node *root = &(this->CtTree.nodes_container.back()); + stack> stk; + std::shared_ptr root = this->CtTree.nodes_container.back(); int ly = 0; bool ict; @@ -720,7 +732,7 @@ namespace cytnx { } while (!stk.empty()); // 3. get result: - UniTensor out = this->CtTree.nodes_container.back().utensor; + UniTensor out = this->CtTree.nodes_container.back()->utensor; // std::cout << out << std::endl; // out.print_diagram(); @@ -734,7 +746,7 @@ namespace cytnx { // 6. permute accroding to pre-set labels: if (TOUT_labels.size()) { - out.permute_(TOUT_labels, TOUT_iBondNum, true); + out.permute_(TOUT_labels, TOUT_iBondNum); } // UniTensor out; diff --git a/src/RegularNetwork.cpp b/src/RegularNetwork.cpp index bddce1fe..620d063d 100644 --- a/src/RegularNetwork.cpp +++ b/src/RegularNetwork.cpp @@ -230,24 +230,24 @@ namespace cytnx { vector> CtTree_to_eisumpath(ContractionTree CtTree, vector tns) { vector> path; - stack> stk; - shared_ptr root = make_shared(CtTree.nodes_container.back()); + stack> stk; + std::shared_ptr root = CtTree.nodes_container.back(); int ly = 0; bool ict; do { while ((root != nullptr)) { - if (root->right != nullptr) stk.push(make_shared(*(root->right))); + if (root->right != nullptr) stk.push(root->right); stk.push(root); - root = make_shared(*(root->left)); + root = root->left; } root = stk.top(); stk.pop(); ict = true; if ((root->right != nullptr) && !stk.empty()) { - if (stk.top()->name == root->right->name) { + if (stk.top() == root->right) { stk.pop(); stk.push(root); - root = make_shared(*(root->right)); + root = root->right; ict = false; } } @@ -632,7 +632,7 @@ namespace cytnx { vector names; for (int i = 0; i < this->names.size(); i++) { names.push_back(this->names[i]); - CtTree.base_nodes[i].name = this->names[i]; + CtTree.base_nodes[i]->name = this->names[i]; } if (ORDER_tokens.size() != 0) { CtTree.build_contraction_tree_by_tokens(this->name2pos, ORDER_tokens); @@ -955,9 +955,9 @@ namespace cytnx { // cpu workflow for (cytnx_uint64 idx = 0; idx < this->tensors.size(); idx++) { - this->CtTree.base_nodes[idx].utensor = + this->CtTree.base_nodes[idx]->utensor = this->tensors[idx].relabels(this->label_arr[idx]); // this conflict - this->CtTree.base_nodes[idx].is_assigned = true; + this->CtTree.base_nodes[idx]->is_assigned = true; } // 1.5 contraction order: if (ORDER_tokens.size() != 0) { @@ -969,59 +969,45 @@ namespace cytnx { // 2. contract using postorder traversal: // cout << this->CtTree.nodes_container.size() << endl; - stack> stk; - shared_ptr root = make_shared(this->CtTree.nodes_container.back()); + stack> stk; + std::shared_ptr root = this->CtTree.nodes_container.back(); int ly = 0; bool ict; do { - // move the lmost - while ((root != nullptr)) { - if (root->right != nullptr) stk.push(make_shared(*(root->right))); + // move the leftmost + while (root != nullptr) { + if (root->right) stk.push(root->right); stk.push(root); - root = make_shared(*(root->left)); + root = root->left; } root = stk.top(); stk.pop(); - // cytnx_error_msg(stk.size()==0,"[eRROR]","\n"); + ict = true; - if ((root->right != nullptr) && !stk.empty()) { - if (stk.top()->name == root->right->name) { + if (root->right && !stk.empty()) { + if (stk.top() == root->right) { // This comparison now works with shared_ptr stk.pop(); stk.push(root); - root = make_shared(*(root->right)); + root = root->right; ict = false; } } - if (ict) { - // process! - // cout << "OK" << endl; - if ((root->right != nullptr) && (root->left != nullptr)) { - // cout << "L,R::\n"; - // root->left->utensor.print_diagram(1); - // root->right->utensor.print_diagram(1); + if (ict) { + if (root->right && root->left) { root->utensor = Contract(root->left->utensor, root->right->utensor); - // root->left->utensor.print_diagram(); root->right->utensor.print_diagram(); - // root->utensor.print_diagram(); root->utensor.set_name(root->left->utensor.name() + - // root->right->utensor.name()); - root->left->clear_utensor(); // remove intermediate unitensor to save heap space - root->right->clear_utensor(); // remove intermediate unitensor to save heap space + root->left->clear_utensor(); + root->right->clear_utensor(); root->is_assigned = true; - // cout << "contract!" << endl; } - root = nullptr; } - - // cout.flush(); - // break; - } while (!stk.empty()); // 3. get result: - UniTensor out = this->CtTree.nodes_container.back().utensor; + UniTensor out = this->CtTree.nodes_container.back()->utensor; // cout << out << endl; // out.print_diagram(); @@ -1087,9 +1073,9 @@ namespace cytnx { } #else for (cytnx_uint64 idx = 0; idx < this->tensors.size(); idx++) { - this->CtTree.base_nodes[idx].utensor = + this->CtTree.base_nodes[idx]->utensor = this->tensors[idx].relabels(this->label_arr[idx]); // this conflict - this->CtTree.base_nodes[idx].is_assigned = true; + this->CtTree.base_nodes[idx]->is_assigned = true; } // 1.5 contraction order: if (ORDER_tokens.size() != 0) { @@ -1100,36 +1086,34 @@ namespace cytnx { } // 2. contract using postorder traversal: // cout << this->CtTree.nodes_container.size() << endl; - stack> stk; - shared_ptr root = make_shared(this->CtTree.nodes_container.back()); + stack> stk; + std::shared_ptr root = std::make_shared(this->CtTree.nodes_container.back()); int ly = 0; bool ict; do { - // move the lmost - while ((root != nullptr)) { - if (root->right != nullptr) stk.push(make_shared(*(root->right))); + // move the leftmost + while (root != nullptr) { + if (root->right) stk.push(root->right); stk.push(root); - root = make_shared(*(root->left)); + root = root->left; } root = stk.top(); stk.pop(); - // cytnx_error_msg(stk.size()==0,"[eRROR]","\n"); + ict = true; - if ((root->right != nullptr) && !stk.empty()) { - if (stk.top()->name == root->right->name) { + if (root->right && !stk.empty()) { + if (stk.top() == root->right) { // This comparison now works with shared_ptr stk.pop(); stk.push(root); - root = make_shared(*(root->right)); + root = root->right; ict = false; } } - if (ict) { - // process! - // cout << "OK" << endl; - if ((root->right != nullptr) && (root->left != nullptr)) { + if (ict) { + if (root->right && root->left) { root->utensor = Contract(root->left->utensor, root->right->utensor); root->left->clear_utensor(); // remove intermediate unitensor to save heap space root->right->clear_utensor(); // remove intermediate unitensor to save heap space @@ -1139,7 +1123,7 @@ namespace cytnx { } } while (!stk.empty()); // 3. get result: - UniTensor out = this->CtTree.nodes_container.back().utensor; + UniTensor out = this->CtTree.nodes_container.back()->utensor; // 4. reset nodes: this->CtTree.reset_nodes(); // 6. permute accroding to pre-set labels: @@ -1201,7 +1185,14 @@ namespace cytnx { "\n"); this->tensors.resize(this->names.size()); - this->CtTree.base_nodes.resize(this->names.size()); + this->CtTree.base_nodes.clear(); + + // Create nodes using make_shared + for(size_t i = 0; i < this->names.size(); i++) { + auto node = std::make_shared(); + node->name = this->names[i]; + this->CtTree.base_nodes.push_back(node); + } // checking label matching: map labelcnt; @@ -1282,7 +1273,7 @@ namespace cytnx { vector names; for (int i = 0; i < this->names.size(); i++) { names.push_back(this->names[i]); - CtTree.base_nodes[i].name = this->names[i]; + CtTree.base_nodes[i]->name = this->names[i]; } if (ORDER_tokens.size() != 0) { CtTree.build_contraction_tree_by_tokens(this->name2pos, ORDER_tokens); diff --git a/src/contraction_tree.cpp b/src/contraction_tree.cpp index fc565d1d..7ee88244 100644 --- a/src/contraction_tree.cpp +++ b/src/contraction_tree.cpp @@ -15,15 +15,16 @@ namespace cytnx { "should contain >=2 tensors in order to build contraction order.%s", "\n"); - Node *left = &(this->base_nodes[0]); - Node *right; + std::shared_ptr left = this->base_nodes[0]; + std::shared_ptr right; this->nodes_container.reserve( this->base_nodes.size()); // reserve a contiguous memeory address to prevent re-allocate that // change address. for (cytnx_uint64 i = 1; i < this->base_nodes.size(); i++) { - right = &(this->base_nodes[i]); - this->nodes_container.push_back(Node(left, right)); - left = &(this->nodes_container.back()); + right = this->base_nodes[i]; + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + left = this->nodes_container.back(); } } void ContractionTree::build_contraction_tree_by_tokens( @@ -38,9 +39,8 @@ namespace cytnx { "[ERROR][ContractionTree][build_contraction_order_by_tokens] cannot have empty tokens.%s", "\n"); - stack stk; - Node *left; - Node *right; + stack> stk; + std::shared_ptr left, right; stack operators; char topc; size_t pos = 0; @@ -68,10 +68,9 @@ namespace cytnx { stk.pop(); left = stk.top(); stk.pop(); - this->nodes_container.push_back(Node(left, right)); - // cout << right->name << " " << left->name <<">"; - // this->nodes_container.back().name = right->name + left->name; - stk.push(&this->nodes_container.back()); + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); if (!operators.empty()) topc = operators.top(); else @@ -90,10 +89,9 @@ namespace cytnx { stk.pop(); left = stk.top(); stk.pop(); - this->nodes_container.push_back(Node(left, right)); - // cout << right->name << " " << left->name << ">"; - // this->nodes_container.back().name = right->name + left->name; - stk.push(&this->nodes_container.back()); + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); if (!operators.empty()) topc = operators.top(); else @@ -112,8 +110,7 @@ namespace cytnx { "contain invalid TN name: %s ,which is not previously defined. \n", tok.c_str()); } - stk.push(&this->base_nodes[idx]); - // cout << "TN" << this->base_nodes[idx].name << endl; + stk.push(this->base_nodes[idx]); } } // for each token @@ -125,8 +122,9 @@ namespace cytnx { left = stk.top(); stk.pop(); // this->nodes_container.back().name = right->name + left->name; - this->nodes_container.push_back(Node(left, right)); - stk.push(&this->nodes_container.back()); + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); } /* cout << "============" << endl; From 2d5f7608efc19b4ce16e7eeec15557311a563c67 Mon Sep 17 00:00:00 2001 From: yjkao Date: Wed, 11 Dec 2024 23:43:19 +0800 Subject: [PATCH 03/11] Enhance ContractionTree and Node classes with improved root pointer management and debugging output - Added debug output for setting root pointers in Node class to aid in tracing. - Introduced a new method `set_root_ptrs()` to streamline root pointer assignment. - Updated constructors to set node names based on children for better identification. - Refactored `reset_nodes()` and `reset_contraction_order()` methods for clearer logic and safety. - Enhanced error handling in `set_root_ptrs()` to manage potential weak pointer issues. - Updated tests to include debug information during network creation. These changes aim to improve the clarity and maintainability of the contraction tree structure, facilitating easier debugging and understanding of tensor relationships. --- include/contraction_tree.hpp | 91 ++++++++++++++++++++++++++---------- src/RegularNetwork.cpp | 2 + src/contraction_tree.cpp | 30 +++++++----- tests/Network_test.cpp | 25 +++++++++- 4 files changed, 110 insertions(+), 38 deletions(-) diff --git a/include/contraction_tree.hpp b/include/contraction_tree.hpp index b629a7a0..eb2b3755 100644 --- a/include/contraction_tree.hpp +++ b/include/contraction_tree.hpp @@ -9,6 +9,7 @@ #include #include #include +#include // Add for debug output #ifdef BACKEND_TORCH #else @@ -16,11 +17,11 @@ namespace cytnx { /// @cond class Node : public std::enable_shared_from_this { public: - UniTensor utensor; // don't worry about copy, because everything are references in cytnx! + UniTensor utensor; bool is_assigned; std::shared_ptr left; std::shared_ptr right; - std::weak_ptr root; // Use weak_ptr to avoid circular references + std::weak_ptr root; std::string name; Node() : is_assigned(false) {} @@ -31,6 +32,7 @@ namespace cytnx { left(rhs.left), right(rhs.right), name(rhs.name) { + // Only copy root if it exists if (auto r = rhs.root.lock()) { root = r; } @@ -52,26 +54,56 @@ namespace cytnx { Node(std::shared_ptr in_left, std::shared_ptr in_right, const UniTensor& in_uten = UniTensor()) - : is_assigned(false) { - left = in_left; - right = in_right; - if (in_uten.uten_type() != UTenType.Void) { - utensor = in_uten; - } - - // Set root pointers using shared_from_this() - if (left) left->root = shared_from_this(); - if (right) right->root = shared_from_this(); + : is_assigned(false), left(in_left), right(in_right) { + + // Set name based on children + if (left && right) { + name = "(" + left->name + "," + right->name + ")"; + } + + if (in_uten.uten_type() != UTenType.Void) { + utensor = in_uten; + } } - void assign_utensor(const UniTensor& in_uten) { - utensor = in_uten; - is_assigned = true; + void set_root_ptrs() { + try { + auto self = shared_from_this(); + + if (left) { + std::cout << "Setting root for left child of " << name << std::endl; + left->root = self; + left->set_root_ptrs(); + } + + if (right) { + std::cout << "Setting root for right child of " << name << std::endl; + right->root = self; + right->set_root_ptrs(); + } + } catch (const std::bad_weak_ptr& e) { + std::cerr << "Failed to set root ptrs for node " << name + << ": " << e.what() << std::endl; + throw; + } } void clear_utensor() { - is_assigned = false; - utensor = UniTensor(); + if (left) { + left->clear_utensor(); + left->root.reset(); + } + if (right) { + right->clear_utensor(); + right->root.reset(); + } + is_assigned = false; + utensor = UniTensor(); + } + + void assign_utensor(const UniTensor& in_uten) { + utensor = in_uten; + is_assigned = true; } }; @@ -90,19 +122,28 @@ namespace cytnx { } void reset_contraction_order() { - nodes_container.clear(); + // First clear all root pointers for (auto& node : base_nodes) { - node->root.reset(); // Dereference shared_ptr with -> + if (node) node->root.reset(); } + // Then clear the container + nodes_container.clear(); } void reset_nodes() { - for (auto& node : nodes_container) { - node->clear_utensor(); - } - for (auto& node : base_nodes) { - node->clear_utensor(); - } + // Clear from root down if we have nodes + if (!nodes_container.empty() && nodes_container.back()) { + nodes_container.back()->clear_utensor(); + } + nodes_container.clear(); + + // Reset base nodes + for (auto& node : base_nodes) { + if (node) { + node->is_assigned = false; + node->utensor = UniTensor(); + } + } } void build_default_contraction_tree(); diff --git a/src/RegularNetwork.cpp b/src/RegularNetwork.cpp index 620d063d..d76806d1 100644 --- a/src/RegularNetwork.cpp +++ b/src/RegularNetwork.cpp @@ -971,6 +971,7 @@ namespace cytnx { // cout << this->CtTree.nodes_container.size() << endl; stack> stk; std::shared_ptr root = this->CtTree.nodes_container.back(); + root->set_root_ptrs(); // Add this line int ly = 0; bool ict; @@ -1088,6 +1089,7 @@ namespace cytnx { // cout << this->CtTree.nodes_container.size() << endl; stack> stk; std::shared_ptr root = std::make_shared(this->CtTree.nodes_container.back()); + root->set_root_ptrs(); // Add this line int ly = 0; bool ict; diff --git a/src/contraction_tree.cpp b/src/contraction_tree.cpp index 7ee88244..edf05bfa 100644 --- a/src/contraction_tree.cpp +++ b/src/contraction_tree.cpp @@ -10,21 +10,29 @@ using namespace std; namespace cytnx { void ContractionTree::build_default_contraction_tree() { this->reset_contraction_order(); - cytnx_error_msg(this->base_nodes.size() < 2, - "[ERROR][ContractionTree][build_default_contraction_order] contraction tree " - "should contain >=2 tensors in order to build contraction order.%s", - "\n"); + + + cytnx_error_msg(this->base_nodes.size() < 2, "[ERROR] Need at least 2 tensors for contraction","\n"); + std::shared_ptr left = this->base_nodes[0]; std::shared_ptr right; - this->nodes_container.reserve( - this->base_nodes.size()); // reserve a contiguous memeory address to prevent re-allocate that - // change address. + + this->nodes_container.reserve(this->base_nodes.size()); + for (cytnx_uint64 i = 1; i < this->base_nodes.size(); i++) { - right = this->base_nodes[i]; - auto new_node = std::make_shared(left, right); - this->nodes_container.push_back(new_node); - left = this->nodes_container.back(); + right = this->base_nodes[i]; + + auto new_node = std::make_shared(left, right); + + this->nodes_container.push_back(new_node); + left = new_node; + } + + if (!nodes_container.empty()) { + auto root = nodes_container.back(); + std::cout << "Setting root pointers from " << root->name << std::endl; + root->set_root_ptrs(); } } void ContractionTree::build_contraction_tree_by_tokens( diff --git a/tests/Network_test.cpp b/tests/Network_test.cpp index a13763c9..79f54987 100644 --- a/tests/Network_test.cpp +++ b/tests/Network_test.cpp @@ -19,8 +19,29 @@ // } TEST_F(NetworkTest, Network_dense_FromString) { - auto net = Network(); - net.FromString({"A: a,b,c", "B: c,d", "C: d,e", "ORDER:(A,(B,C))", "TOUT: a,b;e"}); + try { + std::cout << "Creating network..." << std::endl; + auto net = Network(); + + std::vector network_def = { + "A: a,b,c", + "B: c,d", + "C: d,e", + "ORDER:(A,(B,C))", + "TOUT: a,b;e" + }; + + for(const auto& def : network_def) { + std::cout << "Processing: " << def << std::endl; + } + + net.FromString(network_def); + std::cout << "Network construction successful" << std::endl; + + } catch (const std::exception& e) { + std::cerr << "Exception: " << e.what() << std::endl; + throw; + } } TEST_F(NetworkTest, Network_dense_no_order) { From f39a1373bcae24d4524f95318e331e5ff9bc8e9b Mon Sep 17 00:00:00 2001 From: yjkao Date: Thu, 12 Dec 2024 20:05:25 +0800 Subject: [PATCH 04/11] Refactor base node creation in RegularNetwork class for improved clarity and functionality - Clear previous base nodes before creating new ones to ensure a fresh state. - Implemented a loop to properly create and assign names to base nodes using smart pointers. - Enhanced the structure of the FromString method for better readability. These changes aim to streamline the initialization of base nodes in the contraction tree, improving maintainability and clarity in the network setup process. --- src/RegularNetwork.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/RegularNetwork.cpp b/src/RegularNetwork.cpp index d76806d1..4a922217 100644 --- a/src/RegularNetwork.cpp +++ b/src/RegularNetwork.cpp @@ -536,6 +536,14 @@ namespace cytnx { this->tensors.resize(this->names.size()); this->CtTree.base_nodes.resize(this->names.size()); + this->CtTree.base_nodes.clear(); + + // Create base nodes properly + for(size_t i = 0; i < this->names.size(); i++) { + auto node = std::make_shared(); + node->name = this->names[i]; + this->CtTree.base_nodes.push_back(node); + } // checking if all TN are set in ORDER. if (isORDER_exist) { @@ -640,7 +648,7 @@ namespace cytnx { CtTree.build_default_contraction_tree(); } this->einsum_path = CtTree_to_eisumpath(CtTree, names); - } + } // end of FromString void RegularNetwork::Fromfile(const string &fname) { const cytnx_uint64 MAXLINES = 1024; From 0cec911417469ef448c07de8edf565445ef8b5a5 Mon Sep 17 00:00:00 2001 From: yjkao Date: Thu, 12 Dec 2024 20:32:29 +0800 Subject: [PATCH 05/11] Fix Clang-format error --- src/contraction_tree.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/contraction_tree.cpp b/src/contraction_tree.cpp index edf05bfa..edef8c20 100644 --- a/src/contraction_tree.cpp +++ b/src/contraction_tree.cpp @@ -12,7 +12,8 @@ namespace cytnx { this->reset_contraction_order(); - cytnx_error_msg(this->base_nodes.size() < 2, "[ERROR] Need at least 2 tensors for contraction","\n"); + cytnx_error_msg(this->base_nodes.size() < 2, + "[ERROR] Need at least 2 tensors for contraction","\n"); std::shared_ptr left = this->base_nodes[0]; From e783b55b641adf3bc71242049a7b857973fcb890 Mon Sep 17 00:00:00 2001 From: yjkao Date: Thu, 12 Dec 2024 20:35:55 +0800 Subject: [PATCH 06/11] Fix Clang-Format error --- src/contraction_tree.cpp | 255 +++++++++++++++++++-------------------- 1 file changed, 127 insertions(+), 128 deletions(-) diff --git a/src/contraction_tree.cpp b/src/contraction_tree.cpp index edef8c20..6282a819 100644 --- a/src/contraction_tree.cpp +++ b/src/contraction_tree.cpp @@ -8,141 +8,140 @@ using namespace std; #else namespace cytnx { - void ContractionTree::build_default_contraction_tree() { - this->reset_contraction_order(); - - - cytnx_error_msg(this->base_nodes.size() < 2, - "[ERROR] Need at least 2 tensors for contraction","\n"); - - - std::shared_ptr left = this->base_nodes[0]; - std::shared_ptr right; - - this->nodes_container.reserve(this->base_nodes.size()); - - for (cytnx_uint64 i = 1; i < this->base_nodes.size(); i++) { - right = this->base_nodes[i]; - - auto new_node = std::make_shared(left, right); - - this->nodes_container.push_back(new_node); - left = new_node; - } +void ContractionTree::build_default_contraction_tree() { + this->reset_contraction_order(); - if (!nodes_container.empty()) { - auto root = nodes_container.back(); - std::cout << "Setting root pointers from " << root->name << std::endl; - root->set_root_ptrs(); - } + cytnx_error_msg(this->base_nodes.size() < 2, "[ERROR] Need at least 2 tensors for contraction", + "\n"); + + std::shared_ptr left = this->base_nodes[0]; + std::shared_ptr right; + + this->nodes_container.reserve(this->base_nodes.size()); + + for (cytnx_uint64 i = 1; i < this->base_nodes.size(); i++) { + right = this->base_nodes[i]; + + auto new_node = std::make_shared(left, right); + + this->nodes_container.push_back(new_node); + left = new_node; } - void ContractionTree::build_contraction_tree_by_tokens( - const std::map &name2pos, const std::vector &tokens) { - this->reset_contraction_order(); - cytnx_error_msg(this->base_nodes.size() < 2, - "[ERROR][ContractionTree][build_contraction_order_by_tokens] contraction tree " - "should contain >=2 tensors in order to build contraction order.%s", - "\n"); - cytnx_error_msg( - tokens.size() == 0, - "[ERROR][ContractionTree][build_contraction_order_by_tokens] cannot have empty tokens.%s", - "\n"); - - stack> stk; - std::shared_ptr left, right; - stack operators; - char topc; - size_t pos = 0; - std::string tok; - - // evaluate each token, and construct the Contraction Tree. - this->nodes_container.reserve( - this->base_nodes.size()); // reserve a contiguous memeory address to prevent re-allocate that - // change address. - for (cytnx_uint64 i = 0; i < tokens.size(); i++) { - tok = str_strip(tokens[i]); // remove space. - // cout << tokens[i] << "|" << tok << "|" << endl; - if (tok.length() == 0) continue; - // cout << tok << "|"; - if (tok == "(") { - operators.push(tok.c_str()[0]); - // cout << "put(" << endl; - } else if (tok == ")") { - // cout << "put)-->"; - if (!operators.empty()) { - topc = operators.top(); - while ((topc != '(')) { - operators.pop(); - right = stk.top(); - stk.pop(); - left = stk.top(); - stk.pop(); - auto new_node = std::make_shared(left, right); - this->nodes_container.push_back(new_node); - stk.push(this->nodes_container.back()); - if (!operators.empty()) - topc = operators.top(); - else - break; - } - } - // cout << endl; - operators.pop(); // discard the '(' - } else if (tok == ",") { - // cout << "put,-->"; - if (!operators.empty()) { - topc = operators.top(); - while ((topc != '(') && (topc != ')')) { - operators.pop(); - right = stk.top(); - stk.pop(); - left = stk.top(); - stk.pop(); - auto new_node = std::make_shared(left, right); - this->nodes_container.push_back(new_node); - stk.push(this->nodes_container.back()); - if (!operators.empty()) - topc = operators.top(); - else - break; - } + + if (!nodes_container.empty()) { + auto root = nodes_container.back(); + std::cout << "Setting root pointers from " << root->name << std::endl; + root->set_root_ptrs(); + } +} + +void ContractionTree::build_contraction_tree_by_tokens( + const std::map &name2pos, const std::vector &tokens) { + this->reset_contraction_order(); + cytnx_error_msg(this->base_nodes.size() < 2, + "[ERROR][ContractionTree][build_contraction_order_by_tokens] contraction tree " + "should contain >=2 tensors in order to build contraction order.%s", + "\n"); + cytnx_error_msg( + tokens.size() == 0, + "[ERROR][ContractionTree][build_contraction_order_by_tokens] cannot have empty tokens.%s", + "\n"); + + stack> stk; + std::shared_ptr left, right; + stack operators; + char topc; + size_t pos = 0; + std::string tok; + + // evaluate each token, and construct the Contraction Tree. + this->nodes_container.reserve( + this->base_nodes.size()); // reserve a contiguous memeory address to prevent re-allocate that + // change address. + for (cytnx_uint64 i = 0; i < tokens.size(); i++) { + tok = str_strip(tokens[i]); // remove space. + // cout << tokens[i] << "|" << tok << "|" << endl; + if (tok.length() == 0) continue; + // cout << tok << "|"; + if (tok == "(") { + operators.push(tok.c_str()[0]); + // cout << "put(" << endl; + } else if (tok == ")") { + // cout << "put)-->"; + if (!operators.empty()) { + topc = operators.top(); + while ((topc != '(')) { + operators.pop(); + right = stk.top(); + stk.pop(); + left = stk.top(); + stk.pop(); + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); + if (!operators.empty()) + topc = operators.top(); + else + break; } - // cout << endl; - operators.push(','); - } else { - cytnx_uint64 idx; - try { - idx = name2pos.at(tok); - } catch (std::out_of_range) { - cytnx_error_msg(true, - "[ERROR][ContractionTree][build_contraction_order_by_token] tokens " - "contain invalid TN name: %s ,which is not previously defined. \n", - tok.c_str()); + } + // cout << endl; + operators.pop(); // discard the '(' + } else if (tok == ",") { + // cout << "put,-->"; + if (!operators.empty()) { + topc = operators.top(); + while ((topc != '(') && (topc != ')')) { + operators.pop(); + right = stk.top(); + stk.pop(); + left = stk.top(); + stk.pop(); + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); + if (!operators.empty()) + topc = operators.top(); + else + break; } - stk.push(this->base_nodes[idx]); } - - } // for each token - - while (!operators.empty()) { - operators.pop(); - right = stk.top(); - stk.pop(); - left = stk.top(); - stk.pop(); - // this->nodes_container.back().name = right->name + left->name; - auto new_node = std::make_shared(left, right); - this->nodes_container.push_back(new_node); - stk.push(this->nodes_container.back()); - } - /* - cout << "============" << endl; - for(int i=0;inodes_container.size();i++){ - cout << this->nodes_container[i].name << endl; + // cout << endl; + operators.push(','); + } else { + cytnx_uint64 idx; + try { + idx = name2pos.at(tok); + } catch (std::out_of_range) { + cytnx_error_msg(true, + "[ERROR][ContractionTree][build_contraction_order_by_token] tokens " + "contain invalid TN name: %s ,which is not previously defined. \n", + tok.c_str()); + } + stk.push(this->base_nodes[idx]); } - cout << "============" << endl; - */ + + } // for each token + + while (!operators.empty()) { + operators.pop(); + right = stk.top(); + stk.pop(); + left = stk.top(); + stk.pop(); + // this->nodes_container.back().name = right->name + left->name; + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); + } + /* + cout << "============" << endl; + for(int i=0;inodes_container.size();i++){ + cout << this->nodes_container[i].name << endl; } + cout << "============" << endl; + */ +} } // namespace cytnx #endif From 31dcf7db0104047b332f511bc53138cbdcb072df Mon Sep 17 00:00:00 2001 From: yjkao Date: Thu, 12 Dec 2024 20:48:40 +0800 Subject: [PATCH 07/11] Reformat code --- include/Network.hpp | 34 ++--- include/contraction_tree.hpp | 109 ++++++++-------- src/RegularNetwork.cpp | 18 +-- src/contraction_tree.cpp | 242 +++++++++++++++++------------------ src/search_tree.cpp | 38 +++--- tests/Network_test.cpp | 16 +-- 6 files changed, 220 insertions(+), 237 deletions(-) diff --git a/include/Network.hpp b/include/Network.hpp index 2a240d9d..c5560aeb 100644 --- a/include/Network.hpp +++ b/include/Network.hpp @@ -86,7 +86,7 @@ namespace cytnx { friend class FermionNetwork; friend class RegularNetwork; friend class Network; - Network_base() : nwrktype_id(NtType.Void){}; + Network_base() : nwrktype_id(NtType.Void) {}; bool HasPutAllUniTensor() { for (cytnx_uint64 i = 0; i < this->tensors.size(); i++) { @@ -138,7 +138,7 @@ namespace cytnx { virtual void PrintNet(std::ostream &os); virtual boost::intrusive_ptr clone(); virtual void Savefile(const std::string &fname); - virtual ~Network_base(){}; + virtual ~Network_base() {}; }; // Network_base @@ -196,7 +196,7 @@ namespace cytnx { } void PrintNet(std::ostream &os); void Savefile(const std::string &fname); - ~RegularNetwork(){}; + ~RegularNetwork() {}; }; // Under dev!! @@ -206,19 +206,19 @@ namespace cytnx { public: FermionNetwork() { this->nwrktype_id = NtType.Fermion; }; - void Fromfile(const std::string &fname){}; - void FromString(const std::vector &contents){}; - void RmUniTensor(const cytnx_uint64 &idx){}; - void RmUniTensor(const std::string &name){}; - void RmUniTensors(const std::vector &name){}; - - void PutUniTensor(const std::string &name, const UniTensor &utensor){}; - void PutUniTensor(const cytnx_uint64 &idx, const UniTensor &utensor){}; + void Fromfile(const std::string &fname) {}; + void FromString(const std::vector &contents) {}; + void RmUniTensor(const cytnx_uint64 &idx) {}; + void RmUniTensor(const std::string &name) {}; + void RmUniTensors(const std::vector &name) {}; + + void PutUniTensor(const std::string &name, const UniTensor &utensor) {}; + void PutUniTensor(const cytnx_uint64 &idx, const UniTensor &utensor) {}; void PutUniTensors(const std::vector &name, - const std::vector &utensors){}; + const std::vector &utensors) {}; void Contract_plan(const std::vector &utensors, const std::string &Tout, const std::vector &alias = {}, - const std::string &contract_order = ""){}; + const std::string &contract_order = "") {}; void clear() { this->name2pos.clear(); this->CtTree.clear(); @@ -245,9 +245,9 @@ namespace cytnx { boost::intrusive_ptr out(tmp); return out; } - void PrintNet(std::ostream &os){}; - void Savefile(const std::string &fname){}; - ~FermionNetwork(){}; + void PrintNet(std::ostream &os) {}; + void Savefile(const std::string &fname) {}; + ~FermionNetwork() {}; }; ///@endcond @@ -262,7 +262,7 @@ namespace cytnx { public: ///@cond boost::intrusive_ptr _impl; - Network() : _impl(new Network_base()){}; + Network() : _impl(new Network_base()) {}; Network(const Network &rhs) { this->_impl = rhs._impl; } Network &operator=(const Network &rhs) { this->_impl = rhs._impl; diff --git a/include/contraction_tree.hpp b/include/contraction_tree.hpp index eb2b3755..a0140005 100644 --- a/include/contraction_tree.hpp +++ b/include/contraction_tree.hpp @@ -25,8 +25,8 @@ namespace cytnx { std::string name; Node() : is_assigned(false) {} - - Node(const Node& rhs) + + Node(const Node& rhs) : utensor(rhs.utensor), is_assigned(rhs.is_assigned), left(rhs.left), @@ -37,7 +37,7 @@ namespace cytnx { root = r; } } - + Node& operator=(const Node& rhs) { if (this != &rhs) { utensor = rhs.utensor; @@ -52,53 +52,51 @@ namespace cytnx { return *this; } - Node(std::shared_ptr in_left, std::shared_ptr in_right, + Node(std::shared_ptr in_left, std::shared_ptr in_right, const UniTensor& in_uten = UniTensor()) : is_assigned(false), left(in_left), right(in_right) { - - // Set name based on children - if (left && right) { - name = "(" + left->name + "," + right->name + ")"; - } - - if (in_uten.uten_type() != UTenType.Void) { - utensor = in_uten; - } + // Set name based on children + if (left && right) { + name = "(" + left->name + "," + right->name + ")"; + } + + if (in_uten.uten_type() != UTenType.Void) { + utensor = in_uten; + } } void set_root_ptrs() { - try { - auto self = shared_from_this(); - - if (left) { - std::cout << "Setting root for left child of " << name << std::endl; - left->root = self; - left->set_root_ptrs(); - } - - if (right) { - std::cout << "Setting root for right child of " << name << std::endl; - right->root = self; - right->set_root_ptrs(); - } - } catch (const std::bad_weak_ptr& e) { - std::cerr << "Failed to set root ptrs for node " << name - << ": " << e.what() << std::endl; - throw; - } - } + try { + auto self = shared_from_this(); - void clear_utensor() { if (left) { - left->clear_utensor(); - left->root.reset(); + std::cout << "Setting root for left child of " << name << std::endl; + left->root = self; + left->set_root_ptrs(); } + if (right) { - right->clear_utensor(); - right->root.reset(); + std::cout << "Setting root for right child of " << name << std::endl; + right->root = self; + right->set_root_ptrs(); } - is_assigned = false; - utensor = UniTensor(); + } catch (const std::bad_weak_ptr& e) { + std::cerr << "Failed to set root ptrs for node " << name << ": " << e.what() << std::endl; + throw; + } + } + + void clear_utensor() { + if (left) { + left->clear_utensor(); + left->root.reset(); + } + if (right) { + right->clear_utensor(); + right->root.reset(); + } + is_assigned = false; + utensor = UniTensor(); } void assign_utensor(const UniTensor& in_uten) { @@ -110,7 +108,7 @@ namespace cytnx { class ContractionTree { public: std::vector> nodes_container; // intermediate layer - std::vector> base_nodes; // bottom layer + std::vector> base_nodes; // bottom layer ContractionTree() = default; ContractionTree(const ContractionTree&) = default; @@ -131,25 +129,24 @@ namespace cytnx { } void reset_nodes() { - // Clear from root down if we have nodes - if (!nodes_container.empty() && nodes_container.back()) { - nodes_container.back()->clear_utensor(); - } - nodes_container.clear(); - - // Reset base nodes - for (auto& node : base_nodes) { - if (node) { - node->is_assigned = false; - node->utensor = UniTensor(); - } + // Clear from root down if we have nodes + if (!nodes_container.empty() && nodes_container.back()) { + nodes_container.back()->clear_utensor(); + } + nodes_container.clear(); + + // Reset base nodes + for (auto& node : base_nodes) { + if (node) { + node->is_assigned = false; + node->utensor = UniTensor(); } + } } void build_default_contraction_tree(); - void build_contraction_tree_by_tokens( - const std::map& name2pos, - const std::vector& tokens); + void build_contraction_tree_by_tokens(const std::map& name2pos, + const std::vector& tokens); }; /// @endcond } // namespace cytnx diff --git a/src/RegularNetwork.cpp b/src/RegularNetwork.cpp index 4a922217..049b6b8e 100644 --- a/src/RegularNetwork.cpp +++ b/src/RegularNetwork.cpp @@ -537,12 +537,12 @@ namespace cytnx { this->tensors.resize(this->names.size()); this->CtTree.base_nodes.resize(this->names.size()); this->CtTree.base_nodes.clear(); - + // Create base nodes properly - for(size_t i = 0; i < this->names.size(); i++) { - auto node = std::make_shared(); - node->name = this->names[i]; - this->CtTree.base_nodes.push_back(node); + for (size_t i = 0; i < this->names.size(); i++) { + auto node = std::make_shared(); + node->name = this->names[i]; + this->CtTree.base_nodes.push_back(node); } // checking if all TN are set in ORDER. @@ -648,7 +648,7 @@ namespace cytnx { CtTree.build_default_contraction_tree(); } this->einsum_path = CtTree_to_eisumpath(CtTree, names); - } // end of FromString + } // end of FromString void RegularNetwork::Fromfile(const string &fname) { const cytnx_uint64 MAXLINES = 1024; @@ -993,7 +993,7 @@ namespace cytnx { root = stk.top(); stk.pop(); - + ict = true; if (root->right && !stk.empty()) { if (stk.top() == root->right) { // This comparison now works with shared_ptr @@ -1111,7 +1111,7 @@ namespace cytnx { root = stk.top(); stk.pop(); - + ict = true; if (root->right && !stk.empty()) { if (stk.top() == root->right) { // This comparison now works with shared_ptr @@ -1198,7 +1198,7 @@ namespace cytnx { this->CtTree.base_nodes.clear(); // Create nodes using make_shared - for(size_t i = 0; i < this->names.size(); i++) { + for (size_t i = 0; i < this->names.size(); i++) { auto node = std::make_shared(); node->name = this->names[i]; this->CtTree.base_nodes.push_back(node); diff --git a/src/contraction_tree.cpp b/src/contraction_tree.cpp index 6282a819..e4c3667b 100644 --- a/src/contraction_tree.cpp +++ b/src/contraction_tree.cpp @@ -8,140 +8,140 @@ using namespace std; #else namespace cytnx { -void ContractionTree::build_default_contraction_tree() { - this->reset_contraction_order(); + void ContractionTree::build_default_contraction_tree() { + this->reset_contraction_order(); - cytnx_error_msg(this->base_nodes.size() < 2, "[ERROR] Need at least 2 tensors for contraction", - "\n"); + cytnx_error_msg(this->base_nodes.size() < 2, "[ERROR] Need at least 2 tensors for contraction", + "\n"); - std::shared_ptr left = this->base_nodes[0]; - std::shared_ptr right; + std::shared_ptr left = this->base_nodes[0]; + std::shared_ptr right; - this->nodes_container.reserve(this->base_nodes.size()); + this->nodes_container.reserve(this->base_nodes.size()); - for (cytnx_uint64 i = 1; i < this->base_nodes.size(); i++) { - right = this->base_nodes[i]; + for (cytnx_uint64 i = 1; i < this->base_nodes.size(); i++) { + right = this->base_nodes[i]; - auto new_node = std::make_shared(left, right); + auto new_node = std::make_shared(left, right); - this->nodes_container.push_back(new_node); - left = new_node; - } + this->nodes_container.push_back(new_node); + left = new_node; + } - if (!nodes_container.empty()) { - auto root = nodes_container.back(); - std::cout << "Setting root pointers from " << root->name << std::endl; - root->set_root_ptrs(); + if (!nodes_container.empty()) { + auto root = nodes_container.back(); + std::cout << "Setting root pointers from " << root->name << std::endl; + root->set_root_ptrs(); + } } -} - -void ContractionTree::build_contraction_tree_by_tokens( - const std::map &name2pos, const std::vector &tokens) { - this->reset_contraction_order(); - cytnx_error_msg(this->base_nodes.size() < 2, - "[ERROR][ContractionTree][build_contraction_order_by_tokens] contraction tree " - "should contain >=2 tensors in order to build contraction order.%s", - "\n"); - cytnx_error_msg( - tokens.size() == 0, - "[ERROR][ContractionTree][build_contraction_order_by_tokens] cannot have empty tokens.%s", - "\n"); - - stack> stk; - std::shared_ptr left, right; - stack operators; - char topc; - size_t pos = 0; - std::string tok; - - // evaluate each token, and construct the Contraction Tree. - this->nodes_container.reserve( - this->base_nodes.size()); // reserve a contiguous memeory address to prevent re-allocate that - // change address. - for (cytnx_uint64 i = 0; i < tokens.size(); i++) { - tok = str_strip(tokens[i]); // remove space. - // cout << tokens[i] << "|" << tok << "|" << endl; - if (tok.length() == 0) continue; - // cout << tok << "|"; - if (tok == "(") { - operators.push(tok.c_str()[0]); - // cout << "put(" << endl; - } else if (tok == ")") { - // cout << "put)-->"; - if (!operators.empty()) { - topc = operators.top(); - while ((topc != '(')) { - operators.pop(); - right = stk.top(); - stk.pop(); - left = stk.top(); - stk.pop(); - auto new_node = std::make_shared(left, right); - this->nodes_container.push_back(new_node); - stk.push(this->nodes_container.back()); - if (!operators.empty()) - topc = operators.top(); - else - break; + + void ContractionTree::build_contraction_tree_by_tokens( + const std::map &name2pos, const std::vector &tokens) { + this->reset_contraction_order(); + cytnx_error_msg(this->base_nodes.size() < 2, + "[ERROR][ContractionTree][build_contraction_order_by_tokens] contraction tree " + "should contain >=2 tensors in order to build contraction order.%s", + "\n"); + cytnx_error_msg( + tokens.size() == 0, + "[ERROR][ContractionTree][build_contraction_order_by_tokens] cannot have empty tokens.%s", + "\n"); + + stack> stk; + std::shared_ptr left, right; + stack operators; + char topc; + size_t pos = 0; + std::string tok; + + // evaluate each token, and construct the Contraction Tree. + this->nodes_container.reserve( + this->base_nodes.size()); // reserve a contiguous memeory address to prevent re-allocate that + // change address. + for (cytnx_uint64 i = 0; i < tokens.size(); i++) { + tok = str_strip(tokens[i]); // remove space. + // cout << tokens[i] << "|" << tok << "|" << endl; + if (tok.length() == 0) continue; + // cout << tok << "|"; + if (tok == "(") { + operators.push(tok.c_str()[0]); + // cout << "put(" << endl; + } else if (tok == ")") { + // cout << "put)-->"; + if (!operators.empty()) { + topc = operators.top(); + while ((topc != '(')) { + operators.pop(); + right = stk.top(); + stk.pop(); + left = stk.top(); + stk.pop(); + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); + if (!operators.empty()) + topc = operators.top(); + else + break; + } } - } - // cout << endl; - operators.pop(); // discard the '(' - } else if (tok == ",") { - // cout << "put,-->"; - if (!operators.empty()) { - topc = operators.top(); - while ((topc != '(') && (topc != ')')) { - operators.pop(); - right = stk.top(); - stk.pop(); - left = stk.top(); - stk.pop(); - auto new_node = std::make_shared(left, right); - this->nodes_container.push_back(new_node); - stk.push(this->nodes_container.back()); - if (!operators.empty()) - topc = operators.top(); - else - break; + // cout << endl; + operators.pop(); // discard the '(' + } else if (tok == ",") { + // cout << "put,-->"; + if (!operators.empty()) { + topc = operators.top(); + while ((topc != '(') && (topc != ')')) { + operators.pop(); + right = stk.top(); + stk.pop(); + left = stk.top(); + stk.pop(); + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); + if (!operators.empty()) + topc = operators.top(); + else + break; + } } + // cout << endl; + operators.push(','); + } else { + cytnx_uint64 idx; + try { + idx = name2pos.at(tok); + } catch (std::out_of_range) { + cytnx_error_msg(true, + "[ERROR][ContractionTree][build_contraction_order_by_token] tokens " + "contain invalid TN name: %s ,which is not previously defined. \n", + tok.c_str()); + } + stk.push(this->base_nodes[idx]); } - // cout << endl; - operators.push(','); - } else { - cytnx_uint64 idx; - try { - idx = name2pos.at(tok); - } catch (std::out_of_range) { - cytnx_error_msg(true, - "[ERROR][ContractionTree][build_contraction_order_by_token] tokens " - "contain invalid TN name: %s ,which is not previously defined. \n", - tok.c_str()); - } - stk.push(this->base_nodes[idx]); - } - } // for each token - - while (!operators.empty()) { - operators.pop(); - right = stk.top(); - stk.pop(); - left = stk.top(); - stk.pop(); - // this->nodes_container.back().name = right->name + left->name; - auto new_node = std::make_shared(left, right); - this->nodes_container.push_back(new_node); - stk.push(this->nodes_container.back()); - } - /* - cout << "============" << endl; - for(int i=0;inodes_container.size();i++){ - cout << this->nodes_container[i].name << endl; + } // for each token + + while (!operators.empty()) { + operators.pop(); + right = stk.top(); + stk.pop(); + left = stk.top(); + stk.pop(); + // this->nodes_container.back().name = right->name + left->name; + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); + } + /* + cout << "============" << endl; + for(int i=0;inodes_container.size();i++){ + cout << this->nodes_container[i].name << endl; + } + cout << "============" << endl; + */ } - cout << "============" << endl; - */ -} } // namespace cytnx #endif diff --git a/src/search_tree.cpp b/src/search_tree.cpp index 385b8baa..c4669e4c 100644 --- a/src/search_tree.cpp +++ b/src/search_tree.cpp @@ -32,40 +32,38 @@ namespace cytnx { PseudoUniTensor pContract(PseudoUniTensor& t1, PseudoUniTensor& t2) { PseudoUniTensor t3(0); // Initialize with index 0 - + t3.ID = t1.ID ^ t2.ID; // XOR of IDs to track contracted tensors t3.cost = get_cost(t1, t2); // Calculate contraction cost - + // Find common labels between t1 and t2 vector loc1, loc2; vector comm_lbl; vec_intersect_(comm_lbl, t1.labels, t2.labels, loc1, loc2); - + // New shape is concatenation of non-contracted dimensions - t3.shape = vec_concatenate(vec_erase(t1.shape, loc1), - vec_erase(t2.shape, loc2)); - - // New labels are concatenation of non-contracted labels - t3.labels = vec_concatenate(vec_erase(t1.labels, loc1), - vec_erase(t2.labels, loc2)); - + t3.shape = vec_concatenate(vec_erase(t1.shape, loc1), vec_erase(t2.shape, loc2)); + + // New labels are concatenation of non-contracted labels + t3.labels = vec_concatenate(vec_erase(t1.labels, loc1), vec_erase(t2.labels, loc2)); + // Set accumulation string using the original accu_str if available if (t1.accu_str.empty()) t1.accu_str = std::to_string(t1.tensorIndex); if (t2.accu_str.empty()) t2.accu_str = std::to_string(t2.tensorIndex); t3.accu_str = "(" + t1.accu_str + "," + t2.accu_str + ")"; - + // Set as internal node t3.isLeaf = false; t3.left = std::make_unique(t1); t3.right = std::make_unique(t2); - + return t3; } namespace OptimalTreeSolver { // Helper function to find connected components using DFS - void dfs(size_t node, const std::vector& adjacencyMatrix, - IndexSet& visited, std::vector& component) { + void dfs(size_t node, const std::vector& adjacencyMatrix, IndexSet& visited, + std::vector& component) { visited.set(node); component.push_back(node); @@ -112,7 +110,7 @@ namespace cytnx { const size_t n = nodes.size(); // Build adjacency matrix with proper size std::vector adjacencyMatrix(n); - + // Fill adjacency matrix for (size_t i = 0; i < n; ++i) { for (size_t j = i + 1; j < n; ++j) { @@ -191,16 +189,16 @@ namespace cytnx { // Create new node for combining components auto new_node = std::make_unique(); new_node->isLeaf = false; - + // Move the first two components as children new_node->left = std::move(component_results[0]); new_node->right = std::move(component_results[1]); - + // Calculate cost and set properties new_node->cost = get_cost(*new_node->left, *new_node->right); new_node->accu_str = "(" + new_node->left->accu_str + "," + new_node->right->accu_str + ")"; new_node->ID = new_node->left->ID ^ new_node->right->ID; - + // Update component list component_results.erase(component_results.begin(), component_results.begin() + 2); component_results.insert(component_results.begin(), std::move(new_node)); @@ -212,7 +210,7 @@ namespace cytnx { void SearchTree::search_order() { this->reset_search_order(); - if (this->base_nodes.size() == 1 || this->base_nodes.size() == 0 ) { + if (this->base_nodes.size() == 1 || this->base_nodes.size() == 0) { cytnx_error_msg(true, "[ERROR][SearchTree] need at least 2 nodes.%s", "\n"); } @@ -299,7 +297,5 @@ namespace cytnx { right = nullptr; } - - } // namespace cytnx #endif diff --git a/tests/Network_test.cpp b/tests/Network_test.cpp index 79f54987..6265b95c 100644 --- a/tests/Network_test.cpp +++ b/tests/Network_test.cpp @@ -19,8 +19,7 @@ // } TEST_F(NetworkTest, Network_dense_FromString) { - try { - std::cout << "Creating network..." << std::endl; + auto net = Network(); std::vector network_def = { @@ -30,18 +29,9 @@ TEST_F(NetworkTest, Network_dense_FromString) { "ORDER:(A,(B,C))", "TOUT: a,b;e" }; - - for(const auto& def : network_def) { - std::cout << "Processing: " << def << std::endl; - } - + net.FromString(network_def); - std::cout << "Network construction successful" << std::endl; - - } catch (const std::exception& e) { - std::cerr << "Exception: " << e.what() << std::endl; - throw; - } + } TEST_F(NetworkTest, Network_dense_no_order) { From 68d03f6c66de10a92a16d866751f092292be2949 Mon Sep 17 00:00:00 2001 From: yjkao Date: Fri, 13 Dec 2024 10:26:33 +0800 Subject: [PATCH 08/11] Fix Clang-format --- include/search_tree.hpp | 15 ++++++++------- src/RegularGncon.cpp | 6 +++--- tests/Network_test.cpp | 15 ++++----------- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/include/search_tree.hpp b/include/search_tree.hpp index fa8c7767..82415b1d 100644 --- a/include/search_tree.hpp +++ b/include/search_tree.hpp @@ -36,7 +36,11 @@ namespace cytnx { // Constructors explicit PseudoUniTensor(cytnx_uint64 index = 0) - : isLeaf(true), tensorIndex(index), is_assigned(false), cost(0), ID(1ULL << index), + : isLeaf(true), + tensorIndex(index), + is_assigned(false), + cost(0), + ID(1ULL << index), accu_str(std::to_string(index)) {} PseudoUniTensor(std::unique_ptr l, std::unique_ptr r) @@ -60,20 +64,17 @@ namespace cytnx { class SearchTree { public: - std::vector base_nodes; SearchTree() = default; void clear() { - root_ptr.reset(); - base_nodes.clear(); + root_ptr.reset(); + base_nodes.clear(); } void reset_search_order() { root_ptr.reset(); } void search_order(); - std::vector> get_root() const { - return {{root_ptr.get()}}; - } + std::vector> get_root() const { return {{root_ptr.get()}}; } private: std::unique_ptr root_ptr; diff --git a/src/RegularGncon.cpp b/src/RegularGncon.cpp index bb59109c..8060a95e 100644 --- a/src/RegularGncon.cpp +++ b/src/RegularGncon.cpp @@ -303,9 +303,9 @@ namespace cytnx { // Update node creation this->tensors.resize(this->names.size()); this->CtTree.base_nodes.clear(); - + // Create nodes using make_shared - for(size_t i = 0; i < this->names.size(); i++) { + for (size_t i = 0; i < this->names.size(); i++) { auto node = std::make_shared(); node->name = this->names[i]; this->CtTree.base_nodes.push_back(node); @@ -610,7 +610,7 @@ namespace cytnx { SearchTree Stree; Stree.base_nodes.clear(); Stree.base_nodes.resize(this->tensors.size()); - + for (cytnx_uint64 t = 0; t < this->tensors.size(); t++) { Stree.base_nodes[t].from_utensor(this->tensors[t]); Stree.base_nodes[t].accu_str = this->names[t]; diff --git a/tests/Network_test.cpp b/tests/Network_test.cpp index 6265b95c..dab32bfa 100644 --- a/tests/Network_test.cpp +++ b/tests/Network_test.cpp @@ -19,19 +19,12 @@ // } TEST_F(NetworkTest, Network_dense_FromString) { + auto net = Network(); - auto net = Network(); - - std::vector network_def = { - "A: a,b,c", - "B: c,d", - "C: d,e", - "ORDER:(A,(B,C))", - "TOUT: a,b;e" - }; - - net.FromString(network_def); + std::vector network_def = {"A: a,b,c", "B: c,d", "C: d,e", "ORDER:(A,(B,C))", + "TOUT: a,b;e"}; + net.FromString(network_def); } TEST_F(NetworkTest, Network_dense_no_order) { From 25bd012afe9ecea67f5dc15f9a24b916a14bcdff Mon Sep 17 00:00:00 2001 From: yjkao Date: Fri, 13 Dec 2024 11:37:25 +0800 Subject: [PATCH 09/11] Added missing include for cytnx.hpp in Det_test.cpp to ensure proper functionality of linear algebra tests. --- tests/linalg_test/Det_test.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/linalg_test/Det_test.cpp b/tests/linalg_test/Det_test.cpp index 3d8bba24..6d61c574 100644 --- a/tests/linalg_test/Det_test.cpp +++ b/tests/linalg_test/Det_test.cpp @@ -1,7 +1,8 @@ #include #include "../test_tools.h" - +#include "cytnx.hpp" + using namespace cytnx; using namespace testing; using namespace TestTools; From adc303fc2375c81b501713a26b83d05b5b4e252c Mon Sep 17 00:00:00 2001 From: yjkao Date: Fri, 13 Dec 2024 12:12:20 +0800 Subject: [PATCH 10/11] Fix Clang-format --- include/Network.hpp | 34 +++++++++++++++++----------------- tests/linalg_test/Det_test.cpp | 4 ++-- tests/search_tree_test.cpp | 6 +++--- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/include/Network.hpp b/include/Network.hpp index c5560aeb..2a240d9d 100644 --- a/include/Network.hpp +++ b/include/Network.hpp @@ -86,7 +86,7 @@ namespace cytnx { friend class FermionNetwork; friend class RegularNetwork; friend class Network; - Network_base() : nwrktype_id(NtType.Void) {}; + Network_base() : nwrktype_id(NtType.Void){}; bool HasPutAllUniTensor() { for (cytnx_uint64 i = 0; i < this->tensors.size(); i++) { @@ -138,7 +138,7 @@ namespace cytnx { virtual void PrintNet(std::ostream &os); virtual boost::intrusive_ptr clone(); virtual void Savefile(const std::string &fname); - virtual ~Network_base() {}; + virtual ~Network_base(){}; }; // Network_base @@ -196,7 +196,7 @@ namespace cytnx { } void PrintNet(std::ostream &os); void Savefile(const std::string &fname); - ~RegularNetwork() {}; + ~RegularNetwork(){}; }; // Under dev!! @@ -206,19 +206,19 @@ namespace cytnx { public: FermionNetwork() { this->nwrktype_id = NtType.Fermion; }; - void Fromfile(const std::string &fname) {}; - void FromString(const std::vector &contents) {}; - void RmUniTensor(const cytnx_uint64 &idx) {}; - void RmUniTensor(const std::string &name) {}; - void RmUniTensors(const std::vector &name) {}; - - void PutUniTensor(const std::string &name, const UniTensor &utensor) {}; - void PutUniTensor(const cytnx_uint64 &idx, const UniTensor &utensor) {}; + void Fromfile(const std::string &fname){}; + void FromString(const std::vector &contents){}; + void RmUniTensor(const cytnx_uint64 &idx){}; + void RmUniTensor(const std::string &name){}; + void RmUniTensors(const std::vector &name){}; + + void PutUniTensor(const std::string &name, const UniTensor &utensor){}; + void PutUniTensor(const cytnx_uint64 &idx, const UniTensor &utensor){}; void PutUniTensors(const std::vector &name, - const std::vector &utensors) {}; + const std::vector &utensors){}; void Contract_plan(const std::vector &utensors, const std::string &Tout, const std::vector &alias = {}, - const std::string &contract_order = "") {}; + const std::string &contract_order = ""){}; void clear() { this->name2pos.clear(); this->CtTree.clear(); @@ -245,9 +245,9 @@ namespace cytnx { boost::intrusive_ptr out(tmp); return out; } - void PrintNet(std::ostream &os) {}; - void Savefile(const std::string &fname) {}; - ~FermionNetwork() {}; + void PrintNet(std::ostream &os){}; + void Savefile(const std::string &fname){}; + ~FermionNetwork(){}; }; ///@endcond @@ -262,7 +262,7 @@ namespace cytnx { public: ///@cond boost::intrusive_ptr _impl; - Network() : _impl(new Network_base()) {}; + Network() : _impl(new Network_base()){}; Network(const Network &rhs) { this->_impl = rhs._impl; } Network &operator=(const Network &rhs) { this->_impl = rhs._impl; diff --git a/tests/linalg_test/Det_test.cpp b/tests/linalg_test/Det_test.cpp index 6d61c574..fb2eca79 100644 --- a/tests/linalg_test/Det_test.cpp +++ b/tests/linalg_test/Det_test.cpp @@ -1,8 +1,8 @@ #include #include "../test_tools.h" -#include "cytnx.hpp" - +#include "cytnx.hpp" + using namespace cytnx; using namespace testing; using namespace TestTools; diff --git a/tests/search_tree_test.cpp b/tests/search_tree_test.cpp index 8ab0c772..6133fc15 100644 --- a/tests/search_tree_test.cpp +++ b/tests/search_tree_test.cpp @@ -50,7 +50,7 @@ TEST_F(SearchTreeTest, BasicSearchOrder) { EXPECT_EQ(result->cost, 32); // 2*3*4 + 2*2*4 = 24 + 16 = 40 flops, cost = 32 // Verify contraction string format - EXPECT_EQ(result->accu_str, "(2,(0,1))"); + EXPECT_EQ(result->accu_str, "(2,(0,1))"); } TEST_F(SearchTreeTest, BasicSearchOrder2) { @@ -67,7 +67,7 @@ TEST_F(SearchTreeTest, BasicSearchOrder2) { // Create tensor 2 with shape [10,4,8] and labels ["j","k","m"] // This connects with t1 through j, t3 through k, and t4 through m - PseudoUniTensor t2(1); + PseudoUniTensor t2(1); t2.shape = {10, 4, 8}; t2.labels = {"j", "k", "m"}; t2.cost = 0; @@ -76,7 +76,7 @@ TEST_F(SearchTreeTest, BasicSearchOrder2) { // Create tensor 3 with shape [4,2] and labels ["k","l"] // This connects with t2 through k and t4 through l PseudoUniTensor t3(2); - t3.shape = {4, 2}; + t3.shape = {4, 2}; t3.labels = {"k", "l"}; t3.cost = 0; tensors.push_back(t3); From 3debb4ef827b68608fc3f61ed393d4b1c93284e0 Mon Sep 17 00:00:00 2001 From: yjkao Date: Fri, 13 Dec 2024 14:39:49 +0800 Subject: [PATCH 11/11] Remove debugging messages --- include/contraction_tree.hpp | 4 ++-- src/search_tree.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/contraction_tree.hpp b/include/contraction_tree.hpp index a0140005..95e9f9a9 100644 --- a/include/contraction_tree.hpp +++ b/include/contraction_tree.hpp @@ -70,13 +70,13 @@ namespace cytnx { auto self = shared_from_this(); if (left) { - std::cout << "Setting root for left child of " << name << std::endl; + //std::cout << "Setting root for left child of " << name << std::endl; left->root = self; left->set_root_ptrs(); } if (right) { - std::cout << "Setting root for right child of " << name << std::endl; + // std::cout << "Setting root for right child of " << name << std::endl; right->root = self; right->set_root_ptrs(); } diff --git a/src/search_tree.cpp b/src/search_tree.cpp index c4669e4c..c0e15675 100644 --- a/src/search_tree.cpp +++ b/src/search_tree.cpp @@ -215,7 +215,7 @@ namespace cytnx { } // Run optimal tree solver directly with base_nodes - root_ptr = OptimalTreeSolver::solve(base_nodes, true); + root_ptr = OptimalTreeSolver::solve(base_nodes, false); } PseudoUniTensor& PseudoUniTensor::operator=(const PseudoUniTensor& rhs) {