forked from jurajHasik/peps-torch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtn_interface_abelian.py
74 lines (60 loc) · 2.49 KB
/
tn_interface_abelian.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# def tensordot_complex(t1, t2, *args):
# return torch.tensordot(t1.real, t2.real, *args) \
# - torch.tensordot(t1.imag, t2.imag, *args) \
# + (torch.tensordot(t1.real, t2.imag, *args) \
# + torch.tensordot(t1.imag, t2.real, *args)) * 1.0j
# def contract(t1, t2, *args):
# if t1.is_complex() and t2.is_complex():
# return tensordot_complex(t1, t2, *args)
# elif not t1.is_complex() and not t2.is_complex():
# return torch.tensordot(t1, t2, *args)
# else:
# raise NotImplementedError(f"Tensors t1 {t1.dtype} and t2 {t2.dtype}"\
# +" are not either both complex or both real")
def contract(t1, t2, *args, **kwargs):
    """Contract tensor ``t1`` with tensor ``t2``.

    Thin dispatch layer: the work is done by ``t1``'s own ``tensordot``
    method, so the accepted axes specification and keyword options are
    whatever that method supports.

    Args:
        t1: Left operand; must expose a ``tensordot`` method.
        t2: Right operand, handed to ``t1.tensordot`` as its first argument.
        *args: Extra positional arguments forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Whatever ``t1.tensordot(t2, *args, **kwargs)`` returns.
    """
    contraction = t1.tensordot
    return contraction(t2, *args, **kwargs)
# def mm_complex(m1, m2):
# return torch.mm(m1.real, m2.real) - torch.mm(m1.imag, m2.imag) \
# + (torch.mm(m1.real, m2.imag) + torch.mm(m1.imag, m2.real)) * 1.0j
# def mm(m1, m2):
# if m1.is_complex() and m2.is_complex():
# return mm_complex(m1, m2)
# elif not m1.is_complex() and not m2.is_complex():
# return torch.mm(m1, m2)
# else:
# raise NotImplementedError(f"Tensors m1 {m1.dtype} and m2 {m2.dtype} "\
# +" are not either both complex or both real")
def mm(m1, m2, **kwargs):
    """Matrix-matrix product of two rank-2 tensors.

    Delegates to ``m1.tensordot(m2, (1, 0), **kwargs)``: axis 1 of ``m1``
    is paired with axis 0 of ``m2`` (standard tensordot axes convention;
    confirm for the tensor class in use).

    Args:
        m1: Left matrix; must expose ``ndim`` and ``tensordot``.
        m2: Right matrix; must expose ``ndim``.
        **kwargs: Forwarded unchanged to ``m1.tensordot``.

    Returns:
        Whatever ``m1.tensordot`` returns for the contraction.

    Raises:
        AssertionError: If ``m1`` or ``m2`` is not rank-2.
    """
    # Explicit raises rather than `assert` statements: asserts are removed
    # under `python -O`, which would let non-matrix input slip through.
    # AssertionError and the messages are kept so existing callers see the
    # same failure as before.
    if m1.ndim != 2:
        raise AssertionError("m1 is not a matrix")
    if m2.ndim != 2:
        raise AssertionError("m2 is not a matrix")
    # The former spelling ((1),(0)) is not a tuple of tuples; the inner
    # parentheses are plain grouping, so it always evaluated to (1, 0).
    return m1.tensordot(m2, (1, 0), **kwargs)
# def einsum_complex(op, *ts):
# if len(ts)!=2: raise NotImplementedError("einsum implementation limited to two tensors")
# return torch.einsum(op, ts[0].real, ts[1].real) \
# - torch.einsum(op, ts[0].imag, ts[1].imag) \
# + (torch.einsum(op, ts[0].real, ts[1].imag) \
# + torch.einsum(op, ts[0].imag, ts[1].real)) * 1.0j
# def einsum(op, *ts):
# assert isinstance(op, str), "invalid operation"
# if False not in [t.is_complex() for t in ts]:
# return einsum_complex(op, *ts)
# elif True not in [t.is_complex() for t in ts]:
# return torch.einsum(op, *ts)
# else:
# raise NotImplementedError(f"Tensors are not either all "\
# +"complex or all real")
# def view(t, *args):
# return t.view(*args)
# def permute(t, *args):
# return t.permute(*args)
def permute(t, *args):
    """Reorder the axes of tensor ``t``.

    Forwards to the tensor's own ``transpose`` method, so the expected form
    of the axis order (one tuple or unpacked integers) is whatever that
    method accepts.

    Args:
        t: Tensor exposing a ``transpose`` method.
        *args: Axis-order arguments forwarded unchanged.

    Returns:
        Whatever ``t.transpose(*args)`` returns.
    """
    reorder = t.transpose
    return reorder(*args)
# def contiguous(t):
# return t.contiguous()
# def transpose(t):
# return torch.transpose(t, 0, 1)
def transpose(m):
    """Swap the two axes of a rank-2 tensor.

    Args:
        m: Matrix-like tensor exposing ``ndim`` and a ``transpose`` method
            that accepts an axis-order tuple.

    Returns:
        Whatever ``m.transpose((1, 0))`` returns.

    Raises:
        AssertionError: If ``m`` is not rank-2.
    """
    # Explicit raise rather than an `assert` statement: asserts are removed
    # under `python -O`, which would silently disable this rank check.
    # AssertionError and the message are kept for backward compatibility.
    if m.ndim != 2:
        raise AssertionError("m is not a matrix")
    return m.transpose((1, 0))
def conj(t):
    """Return the conjugate of tensor ``t``.

    Args:
        t: Object exposing a ``conj`` method.

    Returns:
        Whatever ``t.conj()`` returns.
    """
    conjugated = t.conj()
    return conjugated