diff --git a/include/contraction_tree.hpp b/include/contraction_tree.hpp index 2c4f8522..95e9f9a9 100644 --- a/include/contraction_tree.hpp +++ b/include/contraction_tree.hpp @@ -8,97 +8,145 @@ #include #include #include +#include +#include // Add for debug output #ifdef BACKEND_TORCH #else namespace cytnx { /// @cond - class Node { + class Node : public std::enable_shared_from_this { public: - UniTensor utensor; // don't worry about copy, because everything are references in cytnx! + UniTensor utensor; bool is_assigned; - Node *left; - Node *right; + std::shared_ptr left; + std::shared_ptr right; + std::weak_ptr root; std::string name; - Node *root; - - Node() : is_assigned(false), left(nullptr), right(nullptr), root(nullptr){}; - Node(const Node &rhs) { - this->left = rhs.left; - this->right = rhs.right; - this->root = rhs.root; - this->utensor = rhs.utensor; - this->is_assigned = rhs.is_assigned; + + Node() : is_assigned(false) {} + + Node(const Node& rhs) + : utensor(rhs.utensor), + is_assigned(rhs.is_assigned), + left(rhs.left), + right(rhs.right), + name(rhs.name) { + // Only copy root if it exists + if (auto r = rhs.root.lock()) { + root = r; + } } - Node &operator==(const Node &rhs) { - this->left = rhs.left; - this->right = rhs.right; - this->root = rhs.root; - this->utensor = rhs.utensor; - this->is_assigned = rhs.is_assigned; + + Node& operator=(const Node& rhs) { + if (this != &rhs) { + utensor = rhs.utensor; + is_assigned = rhs.is_assigned; + left = rhs.left; + right = rhs.right; + name = rhs.name; + if (auto r = rhs.root.lock()) { + root = r; + } + } return *this; } - Node(Node *in_left, Node *in_right, const UniTensor &in_uten = UniTensor()) - : is_assigned(false), left(nullptr), right(nullptr), root(nullptr) { - this->left = in_left; - this->right = in_right; - in_left->root = this; - in_right->root = this; - if (in_uten.uten_type() != UTenType.Void) this->utensor = in_uten; + + Node(std::shared_ptr in_left, std::shared_ptr in_right, + const UniTensor& in_uten = UniTensor()) + : is_assigned(false), left(in_left), right(in_right) { + // Set name based on children + if (left && right) { + name = "(" + left->name + "," + right->name + ")"; + } + + if (in_uten.uten_type() != UTenType.Void) { + utensor = in_uten; + } } - void assign_utensor(const UniTensor &in_uten) { - this->utensor = in_uten; - this->is_assigned = true; + + void set_root_ptrs() { + try { + auto self = shared_from_this(); + + if (left) { + //std::cout << "Setting root for left child of " << name << std::endl; + left->root = self; + left->set_root_ptrs(); + } + + if (right) { + // std::cout << "Setting root for right child of " << name << std::endl; + right->root = self; + right->set_root_ptrs(); + } + } catch (const std::bad_weak_ptr& e) { + std::cerr << "Failed to set root ptrs for node " << name << ": " << e.what() << std::endl; + throw; + } } + void clear_utensor() { - this->is_assigned = false; - this->utensor = UniTensor(); + if (left) { + left->clear_utensor(); + left->root.reset(); + } + if (right) { + right->clear_utensor(); + right->root.reset(); + } + is_assigned = false; + utensor = UniTensor(); + } + + void assign_utensor(const UniTensor& in_uten) { + utensor = in_uten; + is_assigned = true; } }; class ContractionTree { public: - std::vector nodes_container; // this contains intermediate layer. - std::vector base_nodes; // this is the button layer. + std::vector> nodes_container; // intermediate layer + std::vector> base_nodes; // bottom layer - ContractionTree(){}; - ContractionTree(const ContractionTree &rhs) { - this->nodes_container = rhs.nodes_container; - this->base_nodes = rhs.base_nodes; - } - ContractionTree &operator==(const ContractionTree &rhs) { - this->nodes_container = rhs.nodes_container; - this->base_nodes = rhs.base_nodes; - return *this; - } + ContractionTree() = default; + ContractionTree(const ContractionTree&) = default; + ContractionTree& operator=(const ContractionTree&) = default; - // clear all the elements in the whole tree. void clear() { nodes_container.clear(); base_nodes.clear(); - // nodes_container.reserve(1024); } - // clear all the intermediate layer, leave all the base_nodes intact. - // and reset the root pointer on the base ondes + void reset_contraction_order() { - nodes_container.clear(); - for (cytnx_uint64 i = 0; i < base_nodes.size(); i++) { - base_nodes[i].root = nullptr; + // First clear all root pointers + for (auto& node : base_nodes) { + if (node) node->root.reset(); } - // nodes_container.reserve(1024); + // Then clear the container + nodes_container.clear(); } + void reset_nodes() { - // reset all nodes but keep the skeleton - for (cytnx_uint64 i = 0; i < this->nodes_container.size(); i++) { - this->nodes_container[i].clear_utensor(); + // Clear from root down if we have nodes + if (!nodes_container.empty() && nodes_container.back()) { + nodes_container.back()->clear_utensor(); } - for (cytnx_uint64 i = 0; i < this->base_nodes.size(); i++) { - this->base_nodes[i].clear_utensor(); + nodes_container.clear(); + + // Reset base nodes + for (auto& node : base_nodes) { + if (node) { + node->is_assigned = false; + node->utensor = UniTensor(); + } } } + void build_default_contraction_tree(); - void build_contraction_tree_by_tokens(const std::map &name2pos, - const std::vector &tokens); + void build_contraction_tree_by_tokens(const std::map& name2pos, + const std::vector& tokens); }; /// @endcond } // namespace cytnx diff --git a/include/search_tree.hpp b/include/search_tree.hpp index a042706b..82415b1d 100644 --- a/include/search_tree.hpp +++ b/include/search_tree.hpp @@ -7,98 +7,83 @@ #include #include #include +#include +#include +#include #include "UniTensor.hpp" -#ifdef BACKEND_TORCH -#else - namespace cytnx { - /// @cond - class PsudoUniTensor { + + using IndexSet = std::bitset<128>; + + class PseudoUniTensor { public: - // UniTensor utensor; //don't worry about copy, because everything are references in cytnx! + bool isLeaf; + + // Leaf node data std::vector labels; std::vector shape; bool is_assigned; - PsudoUniTensor *left; - PsudoUniTensor *right; - PsudoUniTensor *root; + cytnx_uint64 tensorIndex; + + // Internal node data + std::unique_ptr left; + std::unique_ptr right; + cytnx_float cost; cytnx_uint64 ID; - std::string accu_str; - PsudoUniTensor() - : is_assigned(false), left(nullptr), right(nullptr), root(nullptr), cost(0), ID(0){}; - PsudoUniTensor(const PsudoUniTensor &rhs) { - this->left = rhs.left; - this->right = rhs.right; - this->root = rhs.root; - this->labels = rhs.labels; - this->shape = rhs.shape; - this->is_assigned = rhs.is_assigned; - this->cost = rhs.cost; - this->accu_str = rhs.accu_str; - this->ID = rhs.ID; - } - PsudoUniTensor &operator==(const PsudoUniTensor &rhs) { - this->left = rhs.left; - this->right = rhs.right; - this->root = rhs.root; - this->labels = rhs.labels; - this->shape = rhs.shape; - this->is_assigned = rhs.is_assigned; - this->cost = rhs.cost; - this->accu_str = rhs.accu_str; - this->ID = rhs.ID; - return *this; - } - void from_utensor(const UniTensor &in_uten) { - this->labels = in_uten.labels(); - this->shape = in_uten.shape(); - this->is_assigned = true; - } - void clear_utensor() { - this->is_assigned = false; - this->labels.clear(); - this->shape.clear(); - this->ID = 0; - this->cost = 0; - this->accu_str = ""; - } - void set_ID(const cytnx_int64 &ID) { this->ID = ID; } + // Constructors + explicit PseudoUniTensor(cytnx_uint64 index = 0) + : isLeaf(true), + tensorIndex(index), + is_assigned(false), + cost(0), + ID(1ULL << index), + accu_str(std::to_string(index)) {} + + PseudoUniTensor(std::unique_ptr l, std::unique_ptr r) + : isLeaf(false), left(std::move(l)), right(std::move(r)), cost(0), ID(0) {} + + // Copy and move constructors and assignment operators + PseudoUniTensor(const PseudoUniTensor& rhs); + PseudoUniTensor(PseudoUniTensor&& rhs) noexcept; + PseudoUniTensor& operator=(const PseudoUniTensor& rhs); + PseudoUniTensor& operator=(PseudoUniTensor&& rhs) noexcept; + ~PseudoUniTensor() = default; + + void from_utensor(const UniTensor& in_uten); + void clear_utensor(); }; + namespace OptimalTreeSolver { + std::unique_ptr solve(const std::vector& tensors, + bool verbose = false); + } + class SearchTree { public: - std::vector> nodes_container; - // std::vector nodes_container; // this contains intermediate layer. - std::vector base_nodes; // this is the button layer. - - SearchTree(){}; - SearchTree(const SearchTree &rhs) { - this->nodes_container = rhs.nodes_container; - this->base_nodes = rhs.base_nodes; - } - SearchTree &operator==(const SearchTree &rhs) { - this->nodes_container = rhs.nodes_container; - this->base_nodes = rhs.base_nodes; - return *this; - } + std::vector base_nodes; - // clear all the elements in the whole tree. + SearchTree() = default; void clear() { - nodes_container.clear(); + root_ptr.reset(); base_nodes.clear(); - // nodes_container.reserve(1024); } - // clear all the intermediate layer, leave all the base_nodes intact. - // and reset the root pointer on the base ondes - void reset_search_order() { nodes_container.clear(); } + void reset_search_order() { root_ptr.reset(); } void search_order(); + + std::vector> get_root() const { return {{root_ptr.get()}}; } + + private: + std::unique_ptr root_ptr; }; - /// @endcond + + // Helper functions declarations + cytnx_float get_cost(const PseudoUniTensor& t1, const PseudoUniTensor& t2); + PseudoUniTensor pContract(PseudoUniTensor& t1, PseudoUniTensor& t2); + } // namespace cytnx -#endif #endif // CYTNX_SEARCH_TREE_H_ diff --git a/src/RegularGncon.cpp b/src/RegularGncon.cpp index 9a95ba8b..8060a95e 100644 --- a/src/RegularGncon.cpp +++ b/src/RegularGncon.cpp @@ -299,6 +299,17 @@ namespace cytnx { // // put tensor: // for (int i = 0; i < utensors.size(); i++) this->tensors[i] = utensors[i]; + + // Update node creation + this->tensors.resize(this->names.size()); + this->CtTree.base_nodes.clear(); + + // Create nodes using make_shared + for (size_t i = 0; i < this->names.size(); i++) { + auto node = std::make_shared(); + node->name = this->names[i]; + this->CtTree.base_nodes.push_back(node); + } } void RegularGncon::FromString(const std::vector &contents) { @@ -597,14 +608,15 @@ namespace cytnx { string RegularGncon::getOptimalOrder() { // Creat a SearchTree to search for optim contraction order. SearchTree Stree; + Stree.base_nodes.clear(); Stree.base_nodes.resize(this->tensors.size()); + for (cytnx_uint64 t = 0; t < this->tensors.size(); t++) { - // Stree.base_nodes[t].from_utensor(this->tensors[t]); //create psudotensors from base tensors - Stree.base_nodes[t].from_utensor(CtTree.base_nodes[t].utensor); + Stree.base_nodes[t].from_utensor(this->tensors[t]); Stree.base_nodes[t].accu_str = this->names[t]; } Stree.search_order(); - return Stree.nodes_container.back()[0].accu_str; + return Stree.get_root().back()[0]->accu_str; } UniTensor RegularGncon::Launch(const bool &optimal, const string &contract_order /*default ""*/) { @@ -636,10 +648,10 @@ namespace cytnx { // modify the label of unitensor (shared): // this->tensors[idx].set_labels(this->label_arr[idx]);//this conflict - this->CtTree.base_nodes[idx].utensor = + this->CtTree.base_nodes[idx]->utensor = this->tensors[idx].relabels(this->label_arr[idx]); // this conflict // this->CtTree.base_nodes[idx].name = this->tensors[idx].name(); - this->CtTree.base_nodes[idx].is_assigned = true; + this->CtTree.base_nodes[idx]->is_assigned = true; // cout << this->tensors[idx].name() << " " << idx << "from dict:" << // this->name2pos[this->tensors[idx].name()] << endl; @@ -667,8 +679,8 @@ namespace cytnx { // 2. contract using postorder traversal: // cout << this->CtTree.nodes_container.size() << endl; - stack stk; - Node *root = &(this->CtTree.nodes_container.back()); + stack> stk; + std::shared_ptr root = this->CtTree.nodes_container.back(); int ly = 0; bool ict; @@ -720,7 +732,7 @@ namespace cytnx { } while (!stk.empty()); // 3. get result: - UniTensor out = this->CtTree.nodes_container.back().utensor; + UniTensor out = this->CtTree.nodes_container.back()->utensor; // std::cout << out << std::endl; // out.print_diagram(); @@ -734,7 +746,7 @@ namespace cytnx { // 6. permute accroding to pre-set labels: if (TOUT_labels.size()) { - out.permute_(TOUT_labels, TOUT_iBondNum, true); + out.permute_(TOUT_labels, TOUT_iBondNum); } // UniTensor out; diff --git a/src/RegularNetwork.cpp b/src/RegularNetwork.cpp index e8c8e1fb..049b6b8e 100644 --- a/src/RegularNetwork.cpp +++ b/src/RegularNetwork.cpp @@ -230,8 +230,8 @@ namespace cytnx { vector> CtTree_to_eisumpath(ContractionTree CtTree, vector tns) { vector> path; - stack stk; - Node *root = &(CtTree.nodes_container.back()); + stack> stk; + std::shared_ptr root = CtTree.nodes_container.back(); int ly = 0; bool ict; do { @@ -536,6 +536,14 @@ namespace cytnx { this->tensors.resize(this->names.size()); this->CtTree.base_nodes.resize(this->names.size()); + this->CtTree.base_nodes.clear(); + + // Create base nodes properly + for (size_t i = 0; i < this->names.size(); i++) { + auto node = std::make_shared(); + node->name = this->names[i]; + this->CtTree.base_nodes.push_back(node); + } // checking if all TN are set in ORDER. if (isORDER_exist) { @@ -632,7 +640,7 @@ namespace cytnx { vector names; for (int i = 0; i < this->names.size(); i++) { names.push_back(this->names[i]); - CtTree.base_nodes[i].name = this->names[i]; + CtTree.base_nodes[i]->name = this->names[i]; } if (ORDER_tokens.size() != 0) { CtTree.build_contraction_tree_by_tokens(this->name2pos, ORDER_tokens); @@ -640,7 +648,7 @@ namespace cytnx { CtTree.build_default_contraction_tree(); } this->einsum_path = CtTree_to_eisumpath(CtTree, names); - } + } // end of FromString void RegularNetwork::Fromfile(const string &fname) { const cytnx_uint64 MAXLINES = 1024; @@ -943,7 +951,7 @@ namespace cytnx { Stree.base_nodes[t].accu_str = this->names[t]; } Stree.search_order(); - return Stree.nodes_container.back()[0].accu_str; + return Stree.get_root().back()[0]->accu_str; } UniTensor RegularNetwork::Launch() { @@ -955,9 +963,9 @@ namespace cytnx { // cpu workflow for (cytnx_uint64 idx = 0; idx < this->tensors.size(); idx++) { - this->CtTree.base_nodes[idx].utensor = + this->CtTree.base_nodes[idx]->utensor = this->tensors[idx].relabels(this->label_arr[idx]); // this conflict - this->CtTree.base_nodes[idx].is_assigned = true; + this->CtTree.base_nodes[idx]->is_assigned = true; } // 1.5 contraction order: if (ORDER_tokens.size() != 0) { @@ -969,59 +977,46 @@ namespace cytnx { // 2. contract using postorder traversal: // cout << this->CtTree.nodes_container.size() << endl; - stack stk; - Node *root = &(this->CtTree.nodes_container.back()); + stack> stk; + std::shared_ptr root = this->CtTree.nodes_container.back(); + root->set_root_ptrs(); // Add this line int ly = 0; bool ict; do { - // move the lmost - while ((root != nullptr)) { - if (root->right != nullptr) stk.push(root->right); + // move the leftmost + while (root != nullptr) { + if (root->right) stk.push(root->right); stk.push(root); root = root->left; } root = stk.top(); stk.pop(); - // cytnx_error_msg(stk.size()==0,"[eRROR]","\n"); + ict = true; - if ((root->right != nullptr) && !stk.empty()) { - if (stk.top() == root->right) { + if (root->right && !stk.empty()) { + if (stk.top() == root->right) { // This comparison now works with shared_ptr stk.pop(); stk.push(root); root = root->right; ict = false; } } - if (ict) { - // process! - // cout << "OK" << endl; - if ((root->right != nullptr) && (root->left != nullptr)) { - // cout << "L,R::\n"; - // root->left->utensor.print_diagram(1); - // root->right->utensor.print_diagram(1); + if (ict) { + if (root->right && root->left) { root->utensor = Contract(root->left->utensor, root->right->utensor); - // root->left->utensor.print_diagram(); root->right->utensor.print_diagram(); - // root->utensor.print_diagram(); root->utensor.set_name(root->left->utensor.name() + - // root->right->utensor.name()); - root->left->clear_utensor(); // remove intermediate unitensor to save heap space - root->right->clear_utensor(); // remove intermediate unitensor to save heap space + root->left->clear_utensor(); + root->right->clear_utensor(); root->is_assigned = true; - // cout << "contract!" << endl; } - root = nullptr; } - - // cout.flush(); - // break; - } while (!stk.empty()); // 3. get result: - UniTensor out = this->CtTree.nodes_container.back().utensor; + UniTensor out = this->CtTree.nodes_container.back()->utensor; // cout << out << endl; // out.print_diagram(); @@ -1087,9 +1082,9 @@ namespace cytnx { } #else for (cytnx_uint64 idx = 0; idx < this->tensors.size(); idx++) { - this->CtTree.base_nodes[idx].utensor = + this->CtTree.base_nodes[idx]->utensor = this->tensors[idx].relabels(this->label_arr[idx]); // this conflict - this->CtTree.base_nodes[idx].is_assigned = true; + this->CtTree.base_nodes[idx]->is_assigned = true; } // 1.5 contraction order: if (ORDER_tokens.size() != 0) { @@ -1100,36 +1095,35 @@ namespace cytnx { } // 2. contract using postorder traversal: // cout << this->CtTree.nodes_container.size() << endl; - stack stk; - Node *root = &(this->CtTree.nodes_container.back()); + stack> stk; + std::shared_ptr root = std::make_shared(this->CtTree.nodes_container.back()); + root->set_root_ptrs(); // Add this line int ly = 0; bool ict; do { - // move the lmost - while ((root != nullptr)) { - if (root->right != nullptr) stk.push(root->right); + // move the leftmost + while (root != nullptr) { + if (root->right) stk.push(root->right); stk.push(root); root = root->left; } root = stk.top(); stk.pop(); - // cytnx_error_msg(stk.size()==0,"[eRROR]","\n"); + ict = true; - if ((root->right != nullptr) && !stk.empty()) { - if (stk.top() == root->right) { + if (root->right && !stk.empty()) { + if (stk.top() == root->right) { // This comparison now works with shared_ptr stk.pop(); stk.push(root); root = root->right; ict = false; } } - if (ict) { - // process! - // cout << "OK" << endl; - if ((root->right != nullptr) && (root->left != nullptr)) { + if (ict) { + if (root->right && root->left) { root->utensor = Contract(root->left->utensor, root->right->utensor); root->left->clear_utensor(); // remove intermediate unitensor to save heap space root->right->clear_utensor(); // remove intermediate unitensor to save heap space @@ -1139,7 +1133,7 @@ namespace cytnx { } } while (!stk.empty()); // 3. get result: - UniTensor out = this->CtTree.nodes_container.back().utensor; + UniTensor out = this->CtTree.nodes_container.back()->utensor; // 4. reset nodes: this->CtTree.reset_nodes(); // 6. permute accroding to pre-set labels: @@ -1201,7 +1195,14 @@ namespace cytnx { "\n"); this->tensors.resize(this->names.size()); - this->CtTree.base_nodes.resize(this->names.size()); + this->CtTree.base_nodes.clear(); + + // Create nodes using make_shared + for (size_t i = 0; i < this->names.size(); i++) { + auto node = std::make_shared(); + node->name = this->names[i]; + this->CtTree.base_nodes.push_back(node); + } // checking label matching: map labelcnt; @@ -1282,7 +1283,7 @@ namespace cytnx { vector names; for (int i = 0; i < this->names.size(); i++) { names.push_back(this->names[i]); - CtTree.base_nodes[i].name = this->names[i]; + CtTree.base_nodes[i]->name = this->names[i]; } if (ORDER_tokens.size() != 0) { CtTree.build_contraction_tree_by_tokens(this->name2pos, ORDER_tokens); diff --git a/src/contraction_tree.cpp b/src/contraction_tree.cpp index fc565d1d..e4c3667b 100644 --- a/src/contraction_tree.cpp +++ b/src/contraction_tree.cpp @@ -10,22 +10,31 @@ using namespace std; namespace cytnx { void ContractionTree::build_default_contraction_tree() { this->reset_contraction_order(); - cytnx_error_msg(this->base_nodes.size() < 2, - "[ERROR][ContractionTree][build_default_contraction_order] contraction tree " - "should contain >=2 tensors in order to build contraction order.%s", + + cytnx_error_msg(this->base_nodes.size() < 2, "[ERROR] Need at least 2 tensors for contraction", "\n"); - Node *left = &(this->base_nodes[0]); - Node *right; - this->nodes_container.reserve( - this->base_nodes.size()); // reserve a contiguous memeory address to prevent re-allocate that - // change address. + std::shared_ptr left = this->base_nodes[0]; + std::shared_ptr right; + + this->nodes_container.reserve(this->base_nodes.size()); + for (cytnx_uint64 i = 1; i < this->base_nodes.size(); i++) { - right = &(this->base_nodes[i]); - this->nodes_container.push_back(Node(left, right)); - left = &(this->nodes_container.back()); + right = this->base_nodes[i]; + + auto new_node = std::make_shared(left, right); + + this->nodes_container.push_back(new_node); + left = new_node; + } + + if (!nodes_container.empty()) { + auto root = nodes_container.back(); + std::cout << "Setting root pointers from " << root->name << std::endl; + root->set_root_ptrs(); } } + void ContractionTree::build_contraction_tree_by_tokens( const std::map &name2pos, const std::vector &tokens) { this->reset_contraction_order(); @@ -38,9 +47,8 @@ namespace cytnx { "[ERROR][ContractionTree][build_contraction_order_by_tokens] cannot have empty tokens.%s", "\n"); - stack stk; - Node *left; - Node *right; + stack> stk; + std::shared_ptr left, right; stack operators; char topc; size_t pos = 0; @@ -68,10 +76,9 @@ namespace cytnx { stk.pop(); left = stk.top(); stk.pop(); - this->nodes_container.push_back(Node(left, right)); - // cout << right->name << " " << left->name <<">"; - // this->nodes_container.back().name = right->name + left->name; - stk.push(&this->nodes_container.back()); + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); if (!operators.empty()) topc = operators.top(); else @@ -90,10 +97,9 @@ namespace cytnx { stk.pop(); left = stk.top(); stk.pop(); - this->nodes_container.push_back(Node(left, right)); - // cout << right->name << " " << left->name << ">"; - // this->nodes_container.back().name = right->name + left->name; - stk.push(&this->nodes_container.back()); + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); if (!operators.empty()) topc = operators.top(); else @@ -112,8 +118,7 @@ namespace cytnx { "contain invalid TN name: %s ,which is not previously defined. \n", tok.c_str()); } - stk.push(&this->base_nodes[idx]); - // cout << "TN" << this->base_nodes[idx].name << endl; + stk.push(this->base_nodes[idx]); } } // for each token @@ -125,8 +130,9 @@ namespace cytnx { left = stk.top(); stk.pop(); // this->nodes_container.back().name = right->name + left->name; - this->nodes_container.push_back(Node(left, right)); - stk.push(&this->nodes_container.back()); + auto new_node = std::make_shared(left, right); + this->nodes_container.push_back(new_node); + stk.push(this->nodes_container.back()); } /* cout << "============" << endl; diff --git a/src/search_tree.cpp b/src/search_tree.cpp index a3cc9817..c0e15675 100644 --- a/src/search_tree.cpp +++ b/src/search_tree.cpp @@ -7,8 +7,8 @@ using namespace std; #else namespace cytnx { - - cytnx_float get_cost(const PsudoUniTensor &t1, const PsudoUniTensor &t2) { + // helper functions + cytnx_float get_cost(const PseudoUniTensor& t1, const PseudoUniTensor& t2) { cytnx_float cost = 1; vector shape1 = t1.shape; vector shape2 = t2.shape; @@ -30,88 +30,271 @@ namespace cytnx { return cost + t1.cost + t2.cost; } - PsudoUniTensor pContract(PsudoUniTensor &t1, PsudoUniTensor &t2) { - PsudoUniTensor t3; - t3.ID = t1.ID ^ t2.ID; - t3.cost = get_cost(t1, t2); + PseudoUniTensor pContract(PseudoUniTensor& t1, PseudoUniTensor& t2) { + PseudoUniTensor t3(0); // Initialize with index 0 + + t3.ID = t1.ID ^ t2.ID; // XOR of IDs to track contracted tensors + t3.cost = get_cost(t1, t2); // Calculate contraction cost + + // Find common labels between t1 and t2 vector loc1, loc2; vector comm_lbl; vec_intersect_(comm_lbl, t1.labels, t2.labels, loc1, loc2); + + // New shape is concatenation of non-contracted dimensions t3.shape = vec_concatenate(vec_erase(t1.shape, loc1), vec_erase(t2.shape, loc2)); + + // New labels are concatenation of non-contracted labels t3.labels = vec_concatenate(vec_erase(t1.labels, loc1), vec_erase(t2.labels, loc2)); + + // Set accumulation string using the original accu_str if available + if (t1.accu_str.empty()) t1.accu_str = std::to_string(t1.tensorIndex); + if (t2.accu_str.empty()) t2.accu_str = std::to_string(t2.tensorIndex); t3.accu_str = "(" + t1.accu_str + "," + t2.accu_str + ")"; + + // Set as internal node + t3.isLeaf = false; + t3.left = std::make_unique(t1); + t3.right = std::make_unique(t2); + return t3; } + namespace OptimalTreeSolver { + // Helper function to find connected components using DFS + void dfs(size_t node, const std::vector& adjacencyMatrix, IndexSet& visited, + std::vector& component) { + visited.set(node); + component.push_back(node); + + for (size_t i = 0; i < adjacencyMatrix.size(); ++i) { + if (adjacencyMatrix[node].test(i) && !visited.test(i)) { + dfs(i, adjacencyMatrix, visited, component); + } + } + } + + // Find connected components in the tensor network + std::vector> findConnectedComponents( + const std::vector& adjacencyMatrix) { + std::vector> components; + IndexSet visited; + + for (size_t i = 0; i < adjacencyMatrix.size(); ++i) { + if (!visited.test(i)) { + std::vector component; + dfs(i, adjacencyMatrix, visited, component); + components.push_back(component); + } + } + + return components; + } + + std::unique_ptr solve(const std::vector& tensors, + bool verbose) { + if (tensors.empty()) { + return nullptr; + } + + // Initialize nodes with copies of input tensors + std::vector> nodes; + nodes.reserve(tensors.size()); + for (size_t i = 0; i < tensors.size(); ++i) { + auto node = std::make_unique(i); + *node = tensors[i]; + node->ID = 1ULL << i; + nodes.push_back(std::move(node)); + } + + const size_t n = nodes.size(); + // Build adjacency matrix with proper size + std::vector adjacencyMatrix(n); + + // Fill adjacency matrix + for (size_t i = 0; i < n; ++i) { + for (size_t j = i + 1; j < n; ++j) { + // Find common labels + vector common_lbl; + vector comm_idx1, comm_idx2; + vec_intersect_(common_lbl, nodes[i]->labels, nodes[j]->labels, comm_idx1, comm_idx2); + + if (!common_lbl.empty()) { + adjacencyMatrix[i].set(j); + adjacencyMatrix[j].set(i); + } + } + } + + // Find connected components + auto components = findConnectedComponents(adjacencyMatrix); + if (verbose && components.size() > 1) { + std::cout << "Found " << components.size() << " disconnected components" << std::endl; + } + + // Process each component separately + std::vector> component_results; + for (const auto& component : components) { + // Extract nodes for this component + std::vector> component_nodes; + std::vector remaining_indices = component; + + while (remaining_indices.size() > 1) { + // Find best contraction pair within component + size_t best_i = 0, best_j = 1; + cytnx_float min_cost = std::numeric_limits::max(); + + for (size_t ii = 0; ii < remaining_indices.size(); ++ii) { + size_t i = remaining_indices.at(ii); + for (size_t jj = ii + 1; jj < remaining_indices.size(); ++jj) { + size_t j = remaining_indices.at(jj); + if (adjacencyMatrix[i].test(j)) { + cytnx_float cost = get_cost(*nodes[i], *nodes[j]); + if (cost < min_cost) { + min_cost = cost; + best_i = ii; + best_j = jj; + } + } + } + } + + if (verbose) { + std::cout << "Contracting nodes " << remaining_indices[best_i] << " and " + << remaining_indices[best_j] << " with cost " << min_cost << std::endl; + } + + // Contract best pair + auto left = std::move(nodes[remaining_indices[best_i]]); + auto right = std::move(nodes[remaining_indices[best_j]]); + auto result = pContract(*left, *right); + auto result_ptr = std::make_unique(std::move(result)); + + // Update remaining indices + remaining_indices.erase(remaining_indices.begin() + best_j); + remaining_indices.erase(remaining_indices.begin() + best_i); + + // Store result in original nodes vector + size_t new_idx = nodes.size(); + nodes.push_back(std::move(result_ptr)); + remaining_indices.push_back(new_idx); + } + + // Store the component result + component_results.push_back(std::move(nodes[remaining_indices[0]])); + } + + // If there were multiple components, combine them + while (component_results.size() > 1) { + // Create new node for combining components + auto new_node = std::make_unique(); + new_node->isLeaf = false; + + // Move the first two components as children + new_node->left = std::move(component_results[0]); + new_node->right = std::move(component_results[1]); + + // Calculate cost and set properties + new_node->cost = get_cost(*new_node->left, *new_node->right); + new_node->accu_str = "(" + new_node->left->accu_str + "," + new_node->right->accu_str + ")"; + new_node->ID = new_node->left->ID ^ new_node->right->ID; + + // Update component list + component_results.erase(component_results.begin(), component_results.begin() + 2); + component_results.insert(component_results.begin(), std::move(new_node)); + } + + return std::move(component_results[0]); + } + } // namespace OptimalTreeSolver + void SearchTree::search_order() { this->reset_search_order(); - if (this->base_nodes.size() == 0) { - cytnx_error_msg(true, "[ERROR][SearchTree] no base node exist.%s", "\n"); + if (this->base_nodes.size() == 1 || this->base_nodes.size() == 0) { + cytnx_error_msg(true, "[ERROR][SearchTree] need at least 2 nodes.%s", "\n"); } - cytnx_int64 pid = 0; - this->nodes_container.resize(this->base_nodes.size()); - //[Regiving each base nodes it's own ID]: - for (cytnx_uint64 i = 0; i < this->base_nodes.size(); i++) { - this->base_nodes[i].set_ID(pow(2, i)); - this->nodes_container[i].reserve(this->base_nodes.size() * 2); // try - } + // Run optimal tree solver directly with base_nodes + root_ptr = OptimalTreeSolver::solve(base_nodes, false); + } - // init first layer - for (cytnx_uint64 t = 0; t < this->base_nodes.size(); t++) { - this->nodes_container[0].push_back(this->base_nodes[t]); + PseudoUniTensor& PseudoUniTensor::operator=(const PseudoUniTensor& rhs) { + if (this != &rhs) { + isLeaf = rhs.isLeaf; + labels = rhs.labels; + shape = rhs.shape; + is_assigned = rhs.is_assigned; + tensorIndex = rhs.tensorIndex; + cost = rhs.cost; + ID = rhs.ID; + accu_str = rhs.accu_str; + + if (!isLeaf) { + if (rhs.left) + left = std::make_unique(*rhs.left); + else + left = nullptr; + + if (rhs.right) + right = std::make_unique(*rhs.right); + else + right = nullptr; + } } + return *this; + } - bool secondtimescan = 0; - while (this->nodes_container.back().size() == - 0) { // I can't see the need of this while loop before using secondtimescan - for (int c = 1; c < this->base_nodes.size(); c++) { - for (int d1 = 0; d1 < (c + 1) / 2; d1++) { - int d2 = c - d1 - 1; - int n1 = this->nodes_container[d1].size(); - int n2 = this->nodes_container[d2].size(); - for (int i1 = 0; i1 < n1; i1++) { - int i2_start = (d1 == d2) ? i1 + 1 : 0; - for (int i2 = i2_start; i2 < n2; i2++) { - PsudoUniTensor &t1 = this->nodes_container[d1][i1]; - PsudoUniTensor &t2 = this->nodes_container[d2][i2]; - - // No common labels - // If it's the secondtimescan, that's probably because there're need of Kron - // operations. - if (!secondtimescan and cytnx::vec_intersect(t1.labels, t2.labels).size() == 0) - continue; - // overlap - if ((t1.ID & t2.ID) > 0) continue; - - PsudoUniTensor t3 = pContract(t1, t2); - bool exist = false; - for (int i = 0; i < nodes_container[c].size(); i++) { - if (t3.ID == nodes_container[c][i].ID) { - exist = true; - if (t3.cost < nodes_container[c][i].cost) { - nodes_container[c][i] = t3; - t1.root = &nodes_container[c][i]; - t2.root = &nodes_container[c][i]; - } - break; - } - } // i + PseudoUniTensor::PseudoUniTensor(const PseudoUniTensor& rhs) + : isLeaf(rhs.isLeaf), + labels(rhs.labels), + shape(rhs.shape), + is_assigned(rhs.is_assigned), + tensorIndex(rhs.tensorIndex), + cost(rhs.cost), + ID(rhs.ID), + accu_str(rhs.accu_str) { + if (!isLeaf) { + if (rhs.left) left = std::make_unique(*rhs.left); + if (rhs.right) right = std::make_unique(*rhs.right); + } + } - if (!exist) { - nodes_container[c].push_back(t3); - t1.root = &nodes_container[c].back(); - t2.root = &nodes_container[c].back(); - } + cytnx::PseudoUniTensor::PseudoUniTensor(PseudoUniTensor&& rhs) noexcept + : isLeaf(rhs.isLeaf), + labels(std::move(rhs.labels)), + shape(std::move(rhs.shape)), + is_assigned(rhs.is_assigned), + tensorIndex(rhs.tensorIndex), + left(std::move(rhs.left)), + right(std::move(rhs.right)), + cost(rhs.cost), + ID(rhs.ID), + accu_str(std::move(rhs.accu_str)) {} - } // for i2 - } // for i1 - } // for d1 - } // for c - secondtimescan = 1; - } // while + PseudoUniTensor& PseudoUniTensor::operator=(PseudoUniTensor&& rhs) noexcept { + if (this != &rhs) { + isLeaf = rhs.isLeaf; + labels = std::move(rhs.labels); + shape = std::move(rhs.shape); + is_assigned = rhs.is_assigned; + tensorIndex = rhs.tensorIndex; + left = std::move(rhs.left); + right = std::move(rhs.right); + cost = rhs.cost; + ID = rhs.ID; + accu_str = std::move(rhs.accu_str); + } + return *this; + } - // cout << nodes_container.back()[0].accu_str << endl; + void PseudoUniTensor::from_utensor(const UniTensor& in_uten) { + isLeaf = true; + labels = in_uten.labels(); + shape = in_uten.shape(); + is_assigned = true; + // Other members keep their default/current values + left = nullptr; + right = nullptr; } } // namespace cytnx diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7ebb0686..46ce1397 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -22,6 +22,7 @@ add_executable( Accessor_test.cpp Tensor_test.cpp Storage_test.cpp + search_tree_test.cpp utils_test/vec_concatenate.cpp utils_test/vec_unique.cpp utils/getNconParameter.cpp @@ -55,6 +56,10 @@ target_link_libraries( gtest ) target_link_libraries(test_main cytnx) + +add_compile_options(-fsanitize=address) +add_link_options(-fsanitize=address) + #target_link_libraries(test_main PUBLIC "-lgcov --coverage") include(GoogleTest) gtest_discover_tests(test_main diff --git a/tests/Network_test.cpp b/tests/Network_test.cpp index a13763c9..dab32bfa 100644 --- a/tests/Network_test.cpp +++ b/tests/Network_test.cpp @@ -20,7 +20,11 @@ TEST_F(NetworkTest, Network_dense_FromString) { auto net = Network(); - net.FromString({"A: a,b,c", "B: c,d", "C: d,e", "ORDER:(A,(B,C))", "TOUT: a,b;e"}); + + std::vector network_def = {"A: a,b,c", "B: c,d", "C: d,e", "ORDER:(A,(B,C))", + "TOUT: a,b;e"}; + + net.FromString(network_def); } TEST_F(NetworkTest, Network_dense_no_order) { diff --git a/tests/search_tree_test.cpp b/tests/search_tree_test.cpp new file mode 100644 index 00000000..6133fc15 --- /dev/null +++ b/tests/search_tree_test.cpp @@ -0,0 +1,156 @@ +#include "cytnx.hpp" +#include "search_tree.hpp" +#include + +using namespace cytnx; +using namespace std; + +class SearchTreeTest : public ::testing::Test { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(SearchTreeTest, BasicSearchOrder) { + // Create a simple network of 3 tensors + vector tensors; + + // Create tensor 1 with shape [2,3] and labels ["i","j"] + PseudoUniTensor t1(0); + t1.shape = {2, 3}; + t1.labels = {"i", "j"}; + t1.cost = 0; + tensors.push_back(t1); + + // Create tensor 2 with shape [3,4] and labels ["j","k"] + PseudoUniTensor t2(1); + t2.shape = {3, 4}; + t2.labels = {"j", "k"}; + t2.cost = 0; + tensors.push_back(t2); + + // Create tensor 3 with shape [4,2] and labels ["k","i"] + PseudoUniTensor t3(2); + t3.shape = {4, 2}; + t3.labels = {"k", "i"}; + t3.cost = 0; + tensors.push_back(t3); + + // Create search tree + SearchTree tree; + tree.base_nodes = tensors; + + // Find optimal contraction order + tree.search_order(); + + // Get result node + auto result = tree.get_root().back()[0]; + + // Verify final cost is optimal + EXPECT_EQ(result->cost, 32); // 2*3*4 + 2*2*4 = 24 + 16 = 40 flops, cost = 32 + + // Verify contraction string format + EXPECT_EQ(result->accu_str, "(2,(0,1))"); +} + +TEST_F(SearchTreeTest, BasicSearchOrder2) { + // Create a network of 4 tensors to test more complex contraction ordering + vector tensors; + + // Create tensor 1 with shape [2,10] and labels ["i","j"] + // This will connect with tensor 2 through index j + PseudoUniTensor t1(0); + t1.shape = {2, 10}; + t1.labels = {"i", "j"}; + t1.cost = 0; + tensors.push_back(t1); + + // Create tensor 2 with shape [10,4,8] and labels ["j","k","m"] + // This connects with t1 through j, t3 through k, and t4 through m + PseudoUniTensor t2(1); + t2.shape = {10, 4, 8}; + t2.labels = {"j", "k", "m"}; + t2.cost = 0; + tensors.push_back(t2); + + // Create tensor 3 with shape [4,2] and labels ["k","l"] + // This connects with t2 through k and t4 through l + PseudoUniTensor t3(2); + t3.shape = {4, 2}; + t3.labels = {"k", "l"}; + t3.cost = 0; + tensors.push_back(t3); + + // Create tensor 4 with shape [2,8,7] and labels ["l","m","n"] + // This connects with t3 through l and t2 through m + // The n index remains uncontracted as an external index + PseudoUniTensor t4(3); + t4.shape = {2, 8, 7}; + t4.labels = {"l", "m", "n"}; + t4.cost = 0; + tensors.push_back(t4); + + // Create search tree and set base nodes + SearchTree tree; + tree.base_nodes = tensors; + + // Find optimal contraction order using search algorithm + tree.search_order(); + + // Get the final contracted result node + auto result = tree.get_root().back()[0]; + + // Verify the final contraction cost is optimal + // The optimal sequence contracts (t1,t2) first, then t3, then t4 + EXPECT_EQ(result->cost, 1536); + + // Verify the contraction sequence string matches expected optimal order + // Format is (tensor_id,(tensor_id,tensor_id)) showing order of pairwise contractions + cout << result->accu_str << endl; + EXPECT_EQ(result->accu_str, "((2,3),(0,1))"); +} + +TEST_F(SearchTreeTest, EmptyTree) { + SearchTree tree; + + // Should throw error when searching empty tree + EXPECT_THROW(tree.search_order(), std::logic_error); +} + +TEST_F(SearchTreeTest, SingleNode) { + SearchTree tree; + + // Add single node + PseudoUniTensor t1(0); + t1.shape = {2, 2}; + t1.labels = {"i", "i"}; + t1.cost = 0; + tree.base_nodes.push_back(t1); + + // Should throw error - need at least 2 nodes + EXPECT_THROW(tree.search_order(), std::logic_error); +} + +TEST_F(SearchTreeTest, DisconnectedNetwork) { + SearchTree tree; + + // Create two tensors with no common indices + PseudoUniTensor t1(0); + t1.shape = {2, 3}; + t1.labels = {"i", "j"}; + t1.cost = 0; + tree.base_nodes.push_back(t1); + + PseudoUniTensor t2(1); + t2.shape = {4, 5}; + t2.labels = {"k", "l"}; + t2.cost = 0; + tree.base_nodes.push_back(t2); + + // Should still work but with higher cost due to direct product + tree.search_order(); + auto result = tree.get_root().back()[0]; + + EXPECT_EQ(result->cost, 120); // 2*3*4*5 = 120 + EXPECT_EQ(result->accu_str, "(0,1)"); +}