-
Notifications
You must be signed in to change notification settings - Fork 85
/
Copy pathdense_mnist.f90
68 lines (53 loc) · 1.89 KB
/
dense_mnist.f90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
program dense_mnist
  !! Train a small fully-connected network (784 -> 30 -> 10) on MNIST and
  !! report validation accuracy, mean loss and mean Pearson correlation
  !! after every epoch.
  use nf, only: dense, input, network, sgd, label_digits, load_mnist, corr
  implicit none

  ! Training hyperparameters (previously inline magic numbers).
  integer, parameter :: num_epochs = 10
  integer, parameter :: batch_size = 100
  real, parameter :: learning_rate = 3.

  type(network) :: net
  real, allocatable :: training_images(:,:), training_labels(:)
  real, allocatable :: validation_images(:,:), validation_labels(:)
  ! Targets as produced by label_digits (rank-2, one sample per column, as
  ! consumed by accuracy below). Built once instead of once or more per epoch.
  real, allocatable :: training_targets(:,:), validation_targets(:,:)
  integer :: n

  call load_mnist(training_images, training_labels, &
                  validation_images, validation_labels)

  ! The label arrays are never reassigned inside the epoch loop, so the
  ! conversion is loop-invariant (assumes label_digits has no side effects).
  training_targets = label_digits(training_labels)
  validation_targets = label_digits(validation_labels)

  print '("MNIST")'
  print '(60("="))'

  ! 784 inputs (presumably flattened 28x28 images), one hidden layer of 30,
  ! and 10 outputs.
  net = network([ &
    input(784), &
    dense(30), &
    dense(10) &
  ])

  call net % print_info()

  ! f6.2 leaves room for "100.00"; f5.2 would print asterisks at 100 %.
  print '(a,f6.2,a)', 'Initial accuracy: ', &
    accuracy(net, validation_images, validation_targets) * 100, ' %'

  epochs: do n = 1, num_epochs

    ! One pass over the training set per loop iteration, so metrics can be
    ! reported after each epoch.
    call net % train( &
      training_images, &
      training_targets, &
      batch_size=batch_size, &
      epochs=1, &
      optimizer=sgd(learning_rate=learning_rate) &
    )

    block
      real, allocatable :: output_metrics(:,:)
      real, allocatable :: mean_metrics(:)
      ! 2 metrics; 1st is default loss function (quadratic), other is Pearson corr.
      output_metrics = net % evaluate(validation_images, validation_targets, metric=corr())
      ! Column-wise mean over the rows of output_metrics.
      mean_metrics = sum(output_metrics, 1) / size(output_metrics, 1)
      ! f7.3 leaves room for "100.000" in the accuracy field.
      print '(a,i2,a,f7.3,2(a,f6.3))', 'Epoch ', n, ' done, Accuracy: ', &
        accuracy(net, validation_images, validation_targets) * 100, &
        '%, Loss: ', mean_metrics(1), ', Pearson correlation: ', mean_metrics(2)
    end block

  end do epochs

contains

  function accuracy(net, x, y) result(acc)
    !! Fraction (0 to 1) of samples for which the index of the largest network
    !! output equals the index of the largest entry of the target column.
    !!
    !! net: network used for prediction.
    !! x:   inputs, one sample per column.
    !! y:   targets, one sample per column (same number of columns as x).
    !! Returns 0 for an empty sample set instead of evaluating 0/0.
    ! NOTE(review): intent(in out) kept because predict is invoked on net;
    ! presumably it updates per-layer state. Confirm against nf.
    type(network), intent(in out) :: net
    real, intent(in) :: x(:,:), y(:,:)
    real :: acc
    integer :: i, good, num_samples

    num_samples = size(x, dim=2)
    if (num_samples == 0) then
      acc = 0.
      return
    end if

    good = 0
    do i = 1, num_samples
      ! maxloc with dim=1 on a rank-1 array yields a scalar index.
      if (maxloc(net % predict(x(:,i)), dim=1) == maxloc(y(:,i), dim=1)) then
        good = good + 1
      end if
    end do

    acc = real(good) / num_samples

  end function accuracy

end program dense_mnist