-
Notifications
You must be signed in to change notification settings - Fork 85
/
Copy pathtest_insert_flatten.f90
64 lines (50 loc) · 1.44 KB
/
test_insert_flatten.f90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
program test_insert_flatten

  ! Verifies that the network constructor inserts a flatten layer between
  ! a layer whose output is multi-dimensional and a following dense layer.
  ! Each case builds a small network and checks the layer name at the index
  ! where the flatten layer is expected to appear.

  use iso_fortran_env, only: stderr => error_unit
  use nf, only: network, input, conv2d, maxpool2d, flatten, dense, reshape

  implicit none

  type(network) :: net
  logical :: all_passed

  all_passed = .true.

  ! Case 1: 3-d input feeding dense directly -> flatten expected at index 2.
  net = network([ &
    input([3, 32, 32]), &
    dense(10) &
  ])
  call expect_flatten_at(net, 2, 'flatten layer inserted after input3d.. failed')

  ! Case 2: conv2d feeding dense -> flatten expected at index 3.
  net = network([ &
    input([3, 32, 32]), &
    conv2d(filters=1, kernel_size=3), &
    dense(10) &
  ])
  call expect_flatten_at(net, 3, 'flatten layer inserted after conv2d.. failed')

  ! Case 3: maxpool2d feeding dense -> flatten expected at index 4.
  net = network([ &
    input([3, 32, 32]), &
    conv2d(filters=1, kernel_size=3), &
    maxpool2d(pool_size=2, stride=2), &
    dense(10) &
  ])
  call expect_flatten_at(net, 4, 'flatten layer inserted after maxpool2d.. failed')

  ! Case 4: reshape (1-d input to 3-d) feeding dense -> flatten expected at index 3.
  ! Note: `reshape` here is the nf layer constructor, not the intrinsic.
  net = network([ &
    input(4), &
    reshape([1, 2, 2]), &
    dense(4) &
  ])
  call expect_flatten_at(net, 3, 'flatten layer inserted after reshape.. failed')

  ! Report the overall result; a non-zero stop code signals failure to the runner.
  if (.not. all_passed) then
    write(stderr, '(a)') 'test_insert_flatten: One or more tests failed.'
    stop 1
  end if

  print '(a)', 'test_insert_flatten: All tests passed.'

contains

  !> Record a failure (and print `failure_msg` to stderr) unless the layer
  !> at `position` in `model` is named 'flatten'.
  subroutine expect_flatten_at(model, position, failure_msg)
    type(network), intent(in) :: model
    integer, intent(in) :: position
    character(*), intent(in) :: failure_msg

    if (model % layers(position) % name /= 'flatten') then
      all_passed = .false.
      write(stderr, '(a)') failure_msg
    end if
  end subroutine expect_flatten_at

end program test_insert_flatten