-
-
Notifications
You must be signed in to change notification settings - Fork 292
/
step17.py
executable file
·100 lines (74 loc) · 2.37 KB
/
step17.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import weakref
import numpy as np
class Variable:
    """A value in the computation graph together with its gradient.

    Instances are linked to the function that produced them through
    ``creator`` so that ``backward()`` can walk the graph from an output
    back to its inputs.
    """

    def __init__(self, data):
        """Wrap ``data`` as a graph variable.

        Args:
            data: A ``numpy.ndarray``, or ``None``.

        Raises:
            TypeError: If ``data`` is neither ``None`` nor an ndarray.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError('{} is not supported'.format(type(data)))

        self.data = data
        self.grad = None      # set by backward(); None until then
        self.creator = None   # function object that produced this variable
        self.generation = 0   # 0 for leaves; creator.generation + 1 otherwise

    def set_creator(self, func):
        """Record ``func`` as the producer and place this variable one
        generation after it."""
        self.creator = func
        self.generation = func.generation + 1

    def cleargrad(self):
        """Reset the stored gradient to ``None``."""
        self.grad = None

    def backward(self):
        """Propagate gradients from this variable to everything upstream.

        If ``grad`` is still ``None`` it is seeded with ones shaped like
        ``data``. Functions are then processed from the highest generation
        to the lowest, and gradients reaching the same input from several
        paths are added together.

        Calling this on a leaf (a variable without a creator) only seeds
        ``grad`` and returns.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)

        # A leaf has nothing upstream. Without this guard, None would be
        # queued as a function and fail with AttributeError below.
        if self.creator is None:
            return

        funcs = []
        seen_set = set()

        def add_func(f):
            # Queue each function once; keep the list ordered by generation
            # so that pop() always yields the most downstream function.
            if f not in seen_set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)

        add_func(self.creator)

        while funcs:
            f = funcs.pop()
            # f.outputs holds weak references; call each to get the Variable.
            gys = [output().grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Build a new array instead of updating in place, so an
                    # array shared with another variable is not modified.
                    x.grad = x.grad + gx

                if x.creator is not None:
                    add_func(x.creator)
def as_array(x):
    """Return ``x`` as an ndarray when it is a scalar, otherwise unchanged.

    Scalars (as judged by ``np.isscalar``) are wrapped with ``np.array``;
    any other object is returned as-is, without copying.
    """
    return np.array(x) if np.isscalar(x) else x
class Function:
    """Base class for operations that link input and output variables.

    Subclasses supply ``forward`` (computation on raw data) and
    ``backward`` (gradient of that computation).
    """

    def __call__(self, *inputs):
        """Apply the operation to ``inputs`` and record the graph links.

        The ``data`` of each input is passed to ``forward``; every result
        is wrapped in a ``Variable`` whose creator is this function.

        Returns:
            A single output variable, or a list of them when ``forward``
            produced more than one result.
        """
        raw_values = [var.data for var in inputs]
        result = self.forward(*raw_values)
        results = result if isinstance(result, tuple) else (result,)
        outputs = [Variable(as_array(value)) for value in results]

        # The generation must be known before set_creator() reads it.
        self.generation = max(var.generation for var in inputs)
        for out in outputs:
            out.set_creator(self)

        self.inputs = inputs
        # Outputs are held weakly; the outputs hold this function strongly
        # through their ``creator`` attribute.
        self.outputs = [weakref.ref(out) for out in outputs]

        if len(outputs) > 1:
            return outputs
        return outputs[0]

    def forward(self, xs):
        """Compute the result from raw input data; subclasses override."""
        raise NotImplementedError()

    def backward(self, gys):
        """Compute input gradients from output gradients; subclasses override."""
        raise NotImplementedError()
class Square(Function):
    """Squaring operation: ``y = x ** 2``."""

    def forward(self, x):
        """Return ``x`` raised to the power of two."""
        return x ** 2

    def backward(self, gy):
        """Return ``2 * x * gy``, where ``x`` is the data of the first
        input recorded when the function was called."""
        return 2 * self.inputs[0].data * gy
def square(x):
    """Apply a fresh ``Square`` function object to ``x`` and return the result."""
    op = Square()
    return op(x)
# Build a three-deep chain of squares on a fresh 10000-element random array,
# ten times over. x and y are rebound on every pass, so the previous pass's
# variables are no longer referenced from module scope.
# NOTE(review): presumably a memory-usage demo for the weak references kept in
# Function.outputs - confirm against the accompanying material.
for i in range(10):
    x = Variable(np.random.randn(10000))  # big data
    y = square(square(square(x)))