Skip to content

Commit

Permalink
Update tensor.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mankasto authored Nov 27, 2020
1 parent 77eadf9 commit b08383e
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(self, data, gpu=None):
self.cuda_()

# internal variables used for autograd graph construction
# Tensor列表
# _ctx是一个Function,当前环境上下文是在哪个函数内部
self._ctx = None

def __repr__(self):
Expand Down Expand Up @@ -147,15 +147,19 @@ def deepwalk(node):

for t0 in reversed(nodes):
assert (t0.grad is not None)
# __class__.__name__ 类名
# 在调试环境下进行反向传播(梯度计算)
with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True):
grads = t0._ctx.backward(t0._ctx, t0.grad.data)
# grads放到列表中,以便下面应用zip方法
if len(t0._ctx.parents) == 1:
grads = [grads]
for t,g in zip(t0._ctx.parents, grads):
if g is None:
continue
assert g.shape == t.data.shape, \
"grad shape must match tensor shape in %r, %r != %r" % (self._ctx, g.shape, t.data.shape)
# 在backward中作了链式的乘法运算,这里将不同parents对该变量的梯度(偏导数)再求和
t.grad = Tensor(g) if t.grad is None else (t.grad + Tensor(g))

# ***** tinygrad supports CPU and GPU *****
Expand Down Expand Up @@ -202,6 +206,7 @@ def detach(self):
opsgpu = {}

# ***** non first class ops *****
# ***** 一些稍复杂的算子 ****

def mean(self):
div = Tensor(np.array([1/np.prod(self.shape)], dtype=self.data.dtype), gpu=self.gpu)
Expand Down Expand Up @@ -229,7 +234,7 @@ def __init__(self, *tensors):
self.parents = tensors
self.saved_tensors = []

# 保存x到saved_tensors列表中
# 保存变量x(参数)到saved_tensors列表中
def save_for_backward(self, *x):
self.saved_tensors.extend(x)

Expand All @@ -245,6 +250,7 @@ def apply(self, *x, **kwargs):
# overwrite with passed params
for k, v in kwargs.items():
setattr(ctx, k, v)
# 前向运算
with ProfileOp(ctx.__class__.__name__, x):
ret = Tensor(op.forward(ctx, *[t.data for t in x], **kwargs))
ret._ctx = ctx
Expand Down

0 comments on commit b08383e

Please sign in to comment.