diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py index 24391cc50c4..afd27690981 100644 --- a/python/caffe/test/test_net.py +++ b/python/caffe/test/test_net.py @@ -25,11 +25,11 @@ def simple_net_file(num_output): bias_filler { type: 'constant' value: 2 } } param { decay_mult: 1 } param { decay_mult: 0 } } - layer { type: 'InnerProduct' name: 'ip' bottom: 'conv' top: 'ip' + layer { type: 'InnerProduct' name: 'ip' bottom: 'conv' top: 'ip_blob' inner_product_param { num_output: """ + str(num_output) + """ weight_filler { type: 'gaussian' std: 2.5 } bias_filler { type: 'constant' value: -3 } } } - layer { type: 'SoftmaxWithLoss' name: 'loss' bottom: 'ip' bottom: 'label' + layer { type: 'SoftmaxWithLoss' name: 'loss' bottom: 'ip_blob' bottom: 'label' top: 'loss' }""") f.close() return f.name @@ -71,6 +71,43 @@ def test_forward_backward(self): self.net.forward() self.net.backward() + def test_forward_start_end(self): + conv_blob=self.net.blobs['conv']; + ip_blob=self.net.blobs['ip_blob']; + sample_data=np.random.uniform(size=conv_blob.data.shape); + sample_data=sample_data.astype(np.float32); + conv_blob.data[:]=sample_data; + forward_blob=self.net.forward(start='ip',end='ip'); + self.assertIn('ip_blob',forward_blob); + + manual_forward=[]; + for i in range(0,conv_blob.data.shape[0]): + dot=np.dot(self.net.params['ip'][0].data, + conv_blob.data[i].reshape(-1)); + manual_forward.append(dot+self.net.params['ip'][1].data); + manual_forward=np.array(manual_forward); + + np.testing.assert_allclose(ip_blob.data,manual_forward,rtol=1e-3); + + def test_backward_start_end(self): + conv_blob=self.net.blobs['conv']; + ip_blob=self.net.blobs['ip_blob']; + sample_data=np.random.uniform(size=ip_blob.data.shape) + sample_data=sample_data.astype(np.float32); + ip_blob.diff[:]=sample_data; + backward_blob=self.net.backward(start='ip',end='ip'); + self.assertIn('conv',backward_blob); + + manual_backward=[]; + for i in range(0,conv_blob.data.shape[0]): + dot=np.dot(self.net.params['ip'][0].data.transpose(), + sample_data[i].reshape(-1)); + manual_backward.append(dot); + manual_backward=np.array(manual_backward); + manual_backward=manual_backward.reshape(conv_blob.data.shape); + + np.testing.assert_allclose(conv_blob.diff,manual_backward,rtol=1e-3); + def test_clear_param_diffs(self): # Run a forward/backward step to have non-zero diffs self.net.forward() @@ -90,13 +127,13 @@ def test_top_bottom_names(self): self.assertEqual(self.net.top_names, OrderedDict([('data', ['data', 'label']), ('conv', ['conv']), - ('ip', ['ip']), + ('ip', ['ip_blob']), ('loss', ['loss'])])) self.assertEqual(self.net.bottom_names, OrderedDict([('data', []), ('conv', ['data']), ('ip', ['conv']), - ('loss', ['ip', 'label'])])) + ('loss', ['ip_blob', 'label'])])) def test_save_and_read(self): f = tempfile.NamedTemporaryFile(mode='w+', delete=False)