forked from dmlc/gluon-nlp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_gluon_block.py
91 lines (75 loc) · 2.59 KB
/
test_gluon_block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import pytest
import mxnet as mx
from mxnet import nd, np, npx
from mxnet.test_utils import assert_allclose
from mxnet.gluon import HybridBlock, Constant
from mxnet.gluon.data import DataLoader
import itertools
# Turn on MXNet's NumPy-compatible shape/array semantics once, at import time,
# so every test in this module runs under the `mx.np` / `npx` interface.
npx.set_np()
def test_const():
class Foo(HybridBlock):
def __init__(self):
super().__init__()
self.weight = Constant(np.ones((10, 10)))
def forward(self, x, weight):
return x, weight.astype(np.float32)
foo = Foo()
foo.hybridize()
foo.initialize()
def test_scalar():
    """Run a hybridized block on a 0-d array and check 1.0 * 1.0 * 2 == 2.0."""
    class DoubleSquare(HybridBlock):
        def forward(self, x):
            return x * x * 2

    block = DoubleSquare()
    block.hybridize()
    block.initialize()

    scalar_in = mx.np.array(1.0)
    result = block(scalar_in)
    assert_allclose(result.asnumpy(), np.array(2.0))
def test_gluon_nonzero_hybridize():
    """Hybridized ``npx.nonzero`` must work for non-empty and empty results.

    The output size of ``npx.nonzero`` depends on the input values, so the
    second call, on an all-zero input of the same shape, exercises the
    empty-result path through the already-cached graph.
    """
    class Foo(HybridBlock):
        def forward(self, x):
            dat = npx.nonzero(x)
            return dat.sum() + dat

    foo = Foo()
    foo.hybridize()

    out = foo(mx.np.array([1, 0, 2, 0, 3, 0]))
    out.wait_to_read()
    # npx.nonzero yields one row of indices per non-zero element: [[0], [2], [4]].
    # Their sum (6) is broadcast-added back onto each row.
    assert out.shape == (3, 1)
    assert out.asnumpy().tolist() == [[6], [8], [10]]

    out = foo(mx.np.array([0, 0, 0, 0, 0, 0]))
    out.wait_to_read()
    # No non-zero elements: zero rows, but the index column is still present.
    assert out.shape == (0, 1)
@pytest.mark.xfail(reason='Expected to fail due to MXNet bug https://github.com/apache/'
                          'incubator-mxnet/issues/19659')
def test_gluon_boolean_mask():
    """Reproducer: boolean-mask selection under autograd in a hybridized block.

    Kept as an ``xfail`` until the MXNet issue linked in the decorator is
    fixed; the op sequence below is the reproducer and should not be altered.
    """
    class Foo(HybridBlock):
        def forward(self, data, indices):
            # Element-wise boolean mask derived from ``indices``.
            mask = indices < 3
            # Reshape with npx special codes: -2 copies a dimension, -1 infers one,
            # and reverse=True matches them from the right. For the
            # (5, 5, 5, 5, 16) input below this presumably gives (625, 16).
            # NOTE(review): confirm against npx.reshape docs.
            data = npx.reshape(data, (-1, -2), reverse=True)
            # Flatten the mask to 1-D so it has one entry per row of ``data``.
            mask = np.reshape(mask, (-1,))
            # Private MXNet operator: keep the rows of ``data`` where ``mask`` is True.
            sel = nd.np._internal.boolean_mask(data, mask)
            return sel

    data = mx.np.random.normal(0, 1, (5, 5, 5, 5, 16))
    # Integer indices in [0, 5); together with ``< 3`` this selects a random subset of rows.
    indices = mx.np.random.randint(0, 5, (5, 5, 5, 5))
    data.attach_grad()
    # NOTE(review): gradient on an integer input looks unnecessary, but it is
    # part of the reproducer for the linked bug -- keep unless verified otherwise.
    indices.attach_grad()
    foo = Foo()
    foo.hybridize()
    with mx.autograd.record():
        out = foo(data, indices)
    out.backward()
    out.wait_to_read()
def test_basic_dataloader():
    """Smoke-test a multi-worker DataLoader whose batches are spread over contexts.

    1000 samples of shape (2,) are batched in pairs by a DataLoader with 4
    workers. Consecutive runs of ``len(ctx_l)`` batches are assigned
    round-robin to 8 CPU contexts. The test checks that every batch arrives
    with the expected shape and that none are lost.
    """
    def grouper(iterable, n, fillvalue=None):
        """Collect data into fixed-length chunks or blocks"""
        # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
        args = [iter(iterable)] * n
        return itertools.zip_longest(*args, fillvalue=fillvalue)

    ctx_l = [mx.cpu(i) for i in range(8)]
    num_samples = 1000
    batch_size = 2
    dataset = [mx.np.ones((2,)) * i for i in range(num_samples)]
    dataloader = DataLoader(dataset, batch_size, num_workers=4, prefetch=10)

    num_batches = 0
    for data_l in grouper(dataloader, len(ctx_l)):
        for data, ctx in zip(data_l, ctx_l):
            if data is None:
                # Padding inserted by zip_longest to fill the final, partial group.
                continue
            moved = data.as_in_ctx(ctx)
            assert moved.shape == (batch_size, 2)
            num_batches += 1
            mx.npx.waitall()
    # 1000 samples in batches of 2 divide evenly, so no batch is partial.
    assert num_batches == num_samples // batch_size