-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_getitem.py
44 lines (36 loc) · 1.34 KB
/
test_getitem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import unittest
import numpy as np
from dezero import Variable
import dezero.functions as F
from dezero.utils import gradient_check, array_allclose
class TestGetitem(unittest.TestCase):
    """Tests for ``F.get_item`` and ``Variable.__getitem__``.

    Forward tests compare the indexed result with plain NumPy indexing of the
    same data. Backward tests compare backprop gradients with numerical
    gradients via ``gradient_check``.
    """

    @staticmethod
    def _make_input():
        """Return ``(x_data, x)``: the (2, 2, 3) array 0..11 and its Variable."""
        x_data = np.arange(12).reshape((2, 2, 3))
        return x_data, Variable(x_data)

    def _assert_get_item_grad(self, x_data, slices):
        """Assert that ``F.get_item(x, slices)`` passes ``gradient_check``.

        The result of ``gradient_check`` is asserted rather than ignored:
        it reports a gradient mismatch through its return value, so a bare
        call would let a wrong gradient pass silently.
        NOTE(review): assumes dezero's gradient_check returns True/False
        (as in upstream DeZero) rather than raising — confirm.
        """
        self.assertTrue(gradient_check(lambda x: F.get_item(x, slices), x_data))

    def test_forward1(self):
        # Integer index on the first axis.
        x_data, x = self._make_input()
        y = F.get_item(x, 0)
        self.assertTrue(array_allclose(y.data, x_data[0]))

    def test_forward1a(self):
        # Same as test_forward1, but through the ``x[...]`` operator.
        x_data, x = self._make_input()
        y = x[0]
        self.assertTrue(array_allclose(y.data, x_data[0]))

    def test_forward2(self):
        # Tuple mixing integer indices with an explicit slice object.
        x_data, x = self._make_input()
        y = F.get_item(x, (0, 0, slice(0, 2, 1)))
        self.assertTrue(array_allclose(y.data, x_data[0, 0, 0:2:1]))

    def test_forward3(self):
        # Ellipsis followed by an index on the last axis.
        x_data, x = self._make_input()
        y = F.get_item(x, (Ellipsis, 2))
        self.assertTrue(array_allclose(y.data, x_data[..., 2]))

    def test_backward1(self):
        # Gradient through an integer row index on a 2x3 array.
        x_data = np.array([[1, 2, 3], [4, 5, 6]])
        self._assert_get_item_grad(x_data, 1)

    def test_backward2(self):
        # Gradient through a row slice on a 4x3 array.
        x_data = np.arange(12).reshape(4, 3)
        self._assert_get_item_grad(x_data, slice(1, 3))