SINGA-392 Update autograd API to Pytorch style
Change some APIs to PyTorch style.
Update the corresponding test cases and example network.
xuewanqi committed Sep 26, 2018
1 parent eec0d52 commit db92c75
Showing 3 changed files with 27 additions and 20 deletions.
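Taken together, the user-facing change looks roughly like this. The sketch below is illustrative only and is assembled from the hunks that follow; the tensors x, w1, b1 and target are assumed to be prepared as in examples/autograd/mlp.py, which is not fully shown here.

    from singa import autograd

    # forward pass from the example net; soft_max is renamed to softmax
    x = autograd.matmul(x, w1)
    x = autograd.add_bias(x, b1)
    x = autograd.softmax(x)                    # was: autograd.soft_max(x)
    loss = autograd.cross_entropy(x, target)

    # other renamed pieces from python/singa/autograd.py
    mse = autograd.mse_loss(x, target)         # was: autograd.mean_square_error(x, t)
    prod = autograd.mul(x, x)                  # was: autograd.elemmatmul(x, y); pointwise multiply
    rnn = autograd.RNN(3, 2)                   # was: autograd.Vanilla_RNN(3, 2); the old RNN base class is now RNN_Base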
examples/autograd/mlp.py (2 changes: 1 addition & 1 deletion)
@@ -84,7 +84,7 @@ def to_categorical(y, num_classes):
x = autograd.relu(x)
x = autograd.matmul(x, w1)
x = autograd.add_bias(x, b1)
- x = autograd.soft_max(x)
+ x = autograd.softmax(x)
loss = autograd.cross_entropy(x, target)
for p, gp in autograd.backward(loss):
sgd.apply(0, gp, p, '')
python/singa/autograd.py (31 changes: 19 additions & 12 deletions)
@@ -448,7 +448,7 @@ def backward(self, dy):
return singa.DefaultTranspose(dx)


- def soft_max(x, axis=0):
+ def softmax(x, axis=0):
return SoftMax(axis)(x)[0]


@@ -540,7 +540,7 @@ def backward(self, dy=1.0):
pass # TODO, broadcast elementwise multiply seems not support


- def mean_square_error(x, t):
+ def mse_loss(x, t):
return MeanSquareError()(x, t)[0]
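The renamed loss is exercised in test_MeanSquareError at the bottom of this commit. A minimal sketch of the same call pattern on the default host device, assuming tensor.from_numpy and the autograd.training flag behave as in the test module (neither is part of this hunk):

    import numpy as np
    from singa import tensor, autograd

    autograd.training = True   # record creators so backward() is available

    x = tensor.from_numpy(np.random.random((3, 5)).astype(np.float32))
    t = tensor.from_numpy(np.random.random((3, 5)).astype(np.float32))

    loss = autograd.mse_loss(x, t)        # was: autograd.mean_square_error(x, t)
    dx = loss.creator.backward()[0]       # gradient w.r.t. x, as in test_MeanSquareError
    loss_np = tensor.to_numpy(loss)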


@@ -1076,7 +1076,8 @@ def backward(self, dy):
return dx1, dx2


- def elemmatmul(x, y):
+ def mul(x, y):
+     # do pointwise multiplication
return ElemMatmul()(x, y)[0]
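A quick way to check what the renamed op computes: it is an elementwise product, not a matrix product. A hedged sketch, assuming tensor.from_numpy/to_numpy work on the default host device (nothing here is part of this hunk):

    import numpy as np
    from singa import tensor, autograd

    a = np.array([[1., 2.], [3., 4.]], dtype=np.float32)
    b = np.array([[5., 6.], [7., 8.]], dtype=np.float32)

    z = autograd.mul(tensor.from_numpy(a), tensor.from_numpy(b))   # was: autograd.elemmatmul(...)
    np.testing.assert_allclose(tensor.to_numpy(z), a * b)          # pointwise multiply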


@@ -1088,7 +1089,7 @@ def add_all(*xs):
return


- class RNN(Layer):
+ class RNN_Base(Layer):

def __init__(self):
raise NotImplementedError
@@ -1100,7 +1101,7 @@ def step_forward(self):
raise NotImplementedError


- class Vanilla_RNN(RNN):
+ class RNN(RNN_Base):

def __init__(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0, bidirectional=False):
self.nonlinearity = nonlinearity
@@ -1119,7 +1120,10 @@ def __init__(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', b

self.params = (self.Wx, self.Wh, self.b)

- def __call__(self, h0, *xs):
+ def __call__(self, xs, h0):
+     # xs: a tuple or list of input tensors
+     if not isinstance(xs, tuple):
+         xs = tuple(xs)
inputs = xs + (h0,)
self.device_check(*inputs)
#self.device_check(inputs[0], *self.params)
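The argument order is the visible API change here: the input sequence now comes first and the initial hidden state second, mirroring the test updates below. A hedged sketch; the (3, 2) sizes and the inputs/h0/target values are assumed to come from prepare_inputs_targets_for_rnn_test() in test_operation.py:

    rnn = autograd.RNN(3, 2)            # was: autograd.Vanilla_RNN(3, 2)

    hs, _ = rnn(inputs, h0)             # was: hs, _ = rnn(h0, *inputs)
    loss = autograd.softmax_cross_entropy(hs[0], target[0])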
@@ -1148,7 +1152,7 @@ def step_forward(self, x, h, Wx, Wh, b):
return y


- class LSTM(RNN):
+ class LSTM(RNN_Base):

def __init__(self, input_size, hidden_size, nonlinearity='tanh', num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
self.nonlinearity = nonlinearity
@@ -1183,8 +1187,11 @@ def __init__(self, input_size, hidden_size, nonlinearity='tanh', num_layers=1, b

self.params = self.Wx + self.Wh + self.Bx + self.Bh

- def __call__(self, h0, c0, *xs):
-     inputs = xs + (h0, c0)
+ def __call__(self, xs, (h0, c0)):
+     # xs: a tuple or list of input tensors
+     if not isinstance(xs, list):
+         xs = list(xs)
+     inputs = xs + list((h0, c0))
self.device_check(*inputs)
#self.device_check(inputs[0], *self.params)
self.device_check(inputs[0], *(self.Wx + self.Wh + self.Bx + self.Bh))
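The LSTM call changes the same way, except that the initial hidden and cell states are now passed together as an (h0, c0) pair after the input sequence. A hedged sketch based on the test updates below; inputs, h0, c0 and target are assumed to be prepared as in test_operation.py:

    lstm = autograd.LSTM(3, 2)

    hs, _, _ = lstm(inputs, (h0, c0))   # was: hs, _, _ = lstm(h0, c0, *inputs)
    loss = autograd.softmax_cross_entropy(hs[0], target[0])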
@@ -1229,10 +1236,10 @@ def step_forward(self, x, h, c, Wx, Wh, Bx, Bh):
g = add(y1, y2)
g = tanh(g)

- cout1 = elemmatmul(f, c)
- cout2 = elemmatmul(i, g)
+ cout1 = mul(f, c)
+ cout2 = mul(i, g)
cout = add(cout1, cout2)

hout = tanh(cout)
- hout = elemmatmul(o, hout)
+ hout = mul(o, hout)
return hout, cout
test/python/test_operation.py (14 changes: 7 additions & 7 deletions)
@@ -141,9 +141,9 @@ def test_batchnorm2d_gpu(self):
def test_vanillaRNN_gpu_tiny_ops_shape_check(self):
# gradients shape check.
inputs, target, h0 = prepare_inputs_targets_for_rnn_test()
- rnn = autograd.Vanilla_RNN(3, 2)
+ rnn = autograd.RNN(3, 2)

- hs, _ = rnn(h0, *inputs)
+ hs, _ = rnn(inputs, h0)

loss = autograd.softmax_cross_entropy(hs[0], target[0])
for i in range(1, len(hs)):
@@ -162,7 +162,7 @@ def test_LSTM_gpu_tiny_ops_shape_check(self):

rnn = autograd.LSTM(3, 2)

- hs, _, _ = rnn(h0, c0, *inputs)
+ hs, _, _ = rnn(inputs, (h0, c0))
loss = autograd.softmax_cross_entropy(hs[0], target[0])

for i in range(1, len(hs)):
@@ -206,10 +206,10 @@ def gradients_check(self, func, param, autograds, h=0.0005, df=1):
def test_numerical_gradients_check_for_vallina_rnn(self):
inputs, target, h0 = prepare_inputs_targets_for_rnn_test()

- rnn = autograd.Vanilla_RNN(3, 2)
+ rnn = autograd.RNN(3, 2)

def valinna_rnn_forward():
- hs, _ = rnn(h0, *inputs)
+ hs, _ = rnn(inputs, h0)

loss = autograd.softmax_cross_entropy(hs[0], target[0])
for i in range(1, len(hs)):
@@ -234,7 +234,7 @@ def test_numerical_gradients_check_for_lstm(self):
rnn = autograd.LSTM(3, 2)

def lstm_forward():
- hs, _, _ = rnn(h0, c0, *inputs)
+ hs, _, _ = rnn(inputs, (h0, c0))

loss = autograd.softmax_cross_entropy(hs[0], target[0])
for i in range(1, len(hs)):
@@ -258,7 +258,7 @@ def test_MeanSquareError(self):
x.to_device(gpu_dev)
t.to_device(gpu_dev)

- loss= autograd.mean_square_error(x,t)
+ loss= autograd.mse_loss(x,t)
dx=loss.creator.backward()[0]

loss_np=tensor.to_numpy(loss)
