From f519baa2a19b50aec7322e347961411a67069a6d Mon Sep 17 00:00:00 2001 From: Kai Date: Tue, 13 Aug 2024 01:11:45 -0400 Subject: [PATCH] finished contract --- src/cytnx_torch/linalg/matmul_dg.py | 26 +++++++++ .../unitensor/regular_unitensor.py | 58 ++++++++++++++++++- 2 files changed, 81 insertions(+), 3 deletions(-) create mode 100644 src/cytnx_torch/linalg/matmul_dg.py diff --git a/src/cytnx_torch/linalg/matmul_dg.py b/src/cytnx_torch/linalg/matmul_dg.py new file mode 100644 index 0000000..5cc756c --- /dev/null +++ b/src/cytnx_torch/linalg/matmul_dg.py @@ -0,0 +1,26 @@ +import torch + + +def _tensordot_dg( + A: torch.Tensor, + B: torch.Tensor, + is_diag_left: bool = False, + is_diag_right: bool = False, +): + """ + Matrix multiplication of two tensors A and B. + + Args: + A (torch.Tensor): the first tensor + B (torch.Tensor): the second tensor + + Returns: + torch.Tensor: the result of the matrix multiplication + """ + + if is_diag_left and is_diag_right: + return A * B + + if is_diag_left: + # check shape: + pass diff --git a/src/cytnx_torch/unitensor/regular_unitensor.py b/src/cytnx_torch/unitensor/regular_unitensor.py index c6d8973..5d76483 100644 --- a/src/cytnx_torch/unitensor/regular_unitensor.py +++ b/src/cytnx_torch/unitensor/regular_unitensor.py @@ -1,7 +1,8 @@ -from dataclasses import dataclass, field from beartype.typing import List, Optional, Union, Tuple -import torch +from dataclasses import dataclass, field from numbers import Number +import string +import torch import numpy as np from ..bond import Bond, BondType from ..converter import RegularUniTensorConverter @@ -294,7 +295,58 @@ def contract( ) -> "RegularUniTensor": match rhs: case RegularUniTensor(): - raise NotImplementedError("TODO") + if self.is_diag or rhs.is_diag: + # TODO + raise ValueError( + "contract with diagonal tensor is under development." 
+ ) + + # TODO optimize this: + self_lbls = set(self.labels) + rhs_lbls = set(rhs.labels) + + # get common labels: + unique_labels = self_lbls | rhs_lbls + contracted_labels = self_lbls & rhs_lbls + mapper = {lbl: s for s, lbl in zip(string.ascii_letters, unique_labels)} + + lhs_str = "".join([mapper[lbl] for lbl in self.labels]) + rhs_str = "".join([mapper[lbl] for lbl in rhs.labels]) + + # keep the order of labels: + lhs_remain_idx = [ + i + for i, lbl in enumerate(self.labels) + if lbl not in contracted_labels + ] + rhs_remain_idx = [ + i + for i, lbl in enumerate(rhs.labels) + if lbl not in contracted_labels + ] + + lhs_remain_str = "".join( + [mapper[self.labels[i]] for i in lhs_remain_idx] + ) + rhs_remain_str = "".join( + [mapper[rhs.labels[i]] for i in rhs_remain_idx] + ) + res_str = f"{lhs_remain_str}{rhs_remain_str}" + + new_data = torch.einsum( + f"{lhs_str},{rhs_str}->{res_str}", self.data, rhs.data + ) + + # construct new ut: + return RegularUniTensor( + labels=[self.labels[i] for i in lhs_remain_idx] + + [rhs.labels[i] for i in rhs_remain_idx], + bonds=[self.bonds[i] for i in lhs_remain_idx] + + [rhs.bonds[i] for i in rhs_remain_idx], + backend_args=self.backend_args, + data=new_data, + ) + case RegularUniTensorConverter(): return rhs._contract(is_lhs=False, utensor=self) case _: