Skip to content

Commit

Permalink
Test for python forward and backward with start and end layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
cdoersch authored and shelhamer committed Apr 14, 2017
1 parent e6b2ba5 commit c19c960
Showing 1 changed file with 41 additions and 4 deletions.
45 changes: 41 additions & 4 deletions python/caffe/test/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def simple_net_file(num_output):
bias_filler { type: 'constant' value: 2 } }
param { decay_mult: 1 } param { decay_mult: 0 }
}
layer { type: 'InnerProduct' name: 'ip' bottom: 'conv' top: 'ip'
layer { type: 'InnerProduct' name: 'ip' bottom: 'conv' top: 'ip_blob'
inner_product_param { num_output: """ + str(num_output) + """
weight_filler { type: 'gaussian' std: 2.5 }
bias_filler { type: 'constant' value: -3 } } }
layer { type: 'SoftmaxWithLoss' name: 'loss' bottom: 'ip' bottom: 'label'
layer { type: 'SoftmaxWithLoss' name: 'loss' bottom: 'ip_blob' bottom: 'label'
top: 'loss' }""")
f.close()
return f.name
Expand Down Expand Up @@ -71,6 +71,43 @@ def test_forward_backward(self):
self.net.forward()
self.net.backward()

def test_forward_start_end(self):
conv_blob=self.net.blobs['conv'];
ip_blob=self.net.blobs['ip_blob'];
sample_data=np.random.uniform(size=conv_blob.data.shape);
sample_data=sample_data.astype(np.float32);
conv_blob.data[:]=sample_data;
forward_blob=self.net.forward(start='ip',end='ip');
self.assertIn('ip_blob',forward_blob);

manual_forward=[];
for i in range(0,conv_blob.data.shape[0]):
dot=np.dot(self.net.params['ip'][0].data,
conv_blob.data[i].reshape(-1));
manual_forward.append(dot+self.net.params['ip'][1].data);
manual_forward=np.array(manual_forward);

np.testing.assert_allclose(ip_blob.data,manual_forward,rtol=1e-3);

def test_backward_start_end(self):
conv_blob=self.net.blobs['conv'];
ip_blob=self.net.blobs['ip_blob'];
sample_data=np.random.uniform(size=ip_blob.data.shape)
sample_data=sample_data.astype(np.float32);
ip_blob.diff[:]=sample_data;
backward_blob=self.net.backward(start='ip',end='ip');
self.assertIn('conv',backward_blob);

manual_backward=[];
for i in range(0,conv_blob.data.shape[0]):
dot=np.dot(self.net.params['ip'][0].data.transpose(),
sample_data[i].reshape(-1));
manual_backward.append(dot);
manual_backward=np.array(manual_backward);
manual_backward=manual_backward.reshape(conv_blob.data.shape);

np.testing.assert_allclose(conv_blob.diff,manual_backward,rtol=1e-3);

def test_clear_param_diffs(self):
# Run a forward/backward step to have non-zero diffs
self.net.forward()
Expand All @@ -90,13 +127,13 @@ def test_top_bottom_names(self):
self.assertEqual(self.net.top_names,
OrderedDict([('data', ['data', 'label']),
('conv', ['conv']),
('ip', ['ip']),
('ip', ['ip_blob']),
('loss', ['loss'])]))
self.assertEqual(self.net.bottom_names,
OrderedDict([('data', []),
('conv', ['data']),
('ip', ['conv']),
('loss', ['ip', 'label'])]))
('loss', ['ip_blob', 'label'])]))

def test_save_and_read(self):
f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
Expand Down

0 comments on commit c19c960

Please sign in to comment.