tensor.py
from __future__ import annotations


class Tensor:
    """Not really a tensor: a single scalar value that records the
    operations applied to it so gradients can be computed with backward()."""

    counter = 1  # Class-level counter used to give each tensor a unique name

    def __init__(self, value: float, op: str = '') -> None:
        self.value = value
        self.parents = []  # Tensors this one was computed from
        self.name = "Tensor_" + str(Tensor.counter)
        self.grad = 0.
        self.backward_fun = lambda: None  # Set by the op that creates this tensor
        self.op = op  # For displaying which op created the tensor
        Tensor.counter += 1
    def __str__(self) -> str:
        return str(self.value)

    def __repr__(self) -> str:
        return f"{self.name}: {self.value:.4f}"
    def __add__(self, obj2) -> Tensor:
        obj2 = obj2 if isinstance(obj2, Tensor) else Tensor(obj2)
        output = Tensor(self.value + obj2.value, op="+")
        output.parents.extend([self, obj2])

        def backward():
            # d(a + b)/da = d(a + b)/db = 1, so the upstream gradient
            # flows through unchanged to both operands.
            self.grad += output.grad
            obj2.grad += output.grad
        output.backward_fun = backward
        return output
    def __mul__(self, obj2) -> Tensor:
        obj2 = obj2 if isinstance(obj2, Tensor) else Tensor(obj2)
        output = Tensor(self.value * obj2.value, op="*")
        output.parents.extend([self, obj2])

        def backward():
            # Product rule: d(a*b)/da = b and d(a*b)/db = a.
            self.grad += obj2.value * output.grad
            obj2.grad += self.value * output.grad
        output.backward_fun = backward
        return output
    def relu(self) -> Tensor:
        output = Tensor(0 if self.value < 0 else self.value, op='ReLU')
        output.parents.extend([self])

        def backward():
            # The gradient passes through only where the ReLU was active.
            self.grad += (output.value > 0) * output.grad
        output.backward_fun = backward
        return output
    def backward(self) -> None:
        """Computes gradients from this tensor backwards through the graph."""
        def _topsort(t):
            # Depth-first topological sort, so each tensor's backward_fun
            # runs only after the gradients of all its dependents are known.
            visited = set()
            output = []

            def _run_topsort(t):
                if t not in visited:
                    visited.add(t)
                    for parent in t.parents:
                        _run_topsort(parent)
                    output.append(t)
            _run_topsort(t)
            return output

        topsort = _topsort(self)
        self.grad = 1  # Seed: d(self)/d(self) = 1
        for t in reversed(topsort):
            t.backward_fun()
    def __radd__(self, obj2):
        return self + obj2

    def __rmul__(self, obj2):
        return self * obj2

    def __sub__(self, obj2):
        return self + (-obj2)

    def __rsub__(self, obj2):
        return obj2 + (-self)

    def __neg__(self):
        return self * -1
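

# A minimal usage sketch, not part of the original file: it builds a small
# expression with illustrative variable names (x, y, r) and checks the
# gradients by hand. Note that .grad accumulates across backward() calls,
# so a fresh run assumes the leaf gradients start at zero.
if __name__ == "__main__":
    x = Tensor(2.0)
    y = Tensor(-3.0)
    r = (x * y + 10).relu()  # x*y + 10 = 4, and ReLU(4) = 4
    r.backward()
    # While the ReLU is active: dr/dx = y = -3 and dr/dy = x = 2.
    print(x.grad, y.grad)  # -3.0 2.0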