forked from syncrostone/neurocatcher
-
Notifications
You must be signed in to change notification settings - Fork 1
/
example.py
74 lines (50 loc) · 1.8 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
## generate a batch of training data from 'data' generated by fakearray
from showit import image
import matplotlib.pyplot as plot
import numpy as np
from fakearray import calcium_imaging
from neurocatcher import data_train
# Synthesize a small fake movie with ground truth (n=12, t=10 are presumably
# the source count and frame count -- TODO confirm against fakearray).
data, series, truth = calcium_imaging(shape=(100, 100), n=12, t=10, withparams=True)

# Rescale intensities so the brightest value becomes 255.
peak = data.max()
data = data * 255 / peak

# Move axis 0 to the end ((a, b, c) -> (b, c, a)) and wrap the single movie in
# a list, since data_train is handed a list of movies and a list of truths.
data = [np.transpose(data, (1, 2, 0))]

# Cut a training batch. The positional arguments 10, 20 and 20 - 6 are passed
# straight through; their meaning is defined by neurocatcher.data_train
# (NOTE(review): document what each one controls).
batch_data, batch_truth = data_train(data, [truth], 10, 20, 20 - 6)

# For every sample, show its mean over axis 2, then the matching label slice.
for k, sample in enumerate(batch_data):
    image(np.mean(sample, axis=2))
    plot.show()
    image(batch_truth[k, :, :, 0])
    plot.show()
##train a network and check how it performs
import neurocatcher as nc
import fakearray as fa
from showit import image
import numpy as np
import matplotlib.pyplot as plot
# Synthesize a faux calcium-imaging movie (fakearray defaults) with ground truth.
data, series, truth = fa.calcium_imaging(withparams=True)

# The network is trained on a single-channel mean image: average over axis 0
# (presumably time) and append a trailing channel axis of length 1.
mean_image = data.mean(axis=0)
data = [mean_image[..., np.newaxis]]
truth = [truth]

# Each layer is given as (filter footprint, number of features).
layers = [(3, 10), (3, 10)]
# Input patches are 15-by-15 with a single channel.
input_shape = (15, 1)

# Train the network.
acc, network = nc.network.train_conv_net(
    layers, input_shape, data, truth, batch_size=30, steps=1000
)

# Plot the first column of `acc` over training steps (the original notes this
# is the loss -- confirm against train_conv_net's return value).
plot.plot(acc[:, 0])
plot.show()
##predict output using trained network
# Generate new data to predict on: shape=(240, 240), n=75, t=10, with noise
# disabled. This calls through the `fa` alias bound by the training section,
# which this section already depends on for `nc` and `network`. It does not
# use the bare `calcium_imaging` name, which only the first section imports.
data, series, truth = fa.calcium_imaging(
    shape=(240, 240), n=75, t=10, noise=0.0, withparams=True
)

# Build the same input format used for training: the mean over axis 0 with
# a trailing single-channel axis, wrapped in a one-element list.
data = [data.mean(axis=0)[..., np.newaxis]]
truth = [truth]

# Get a prediction from the full data set.
yhat, ytarget = nc.network.predict_conv_net(network, data, truth)

# Visualize the results side by side: `ytarget` on the left, the prediction
# `yhat` on the right.
plot.figure(figsize=(20, 15))
plot.subplot(1, 2, 1)
plot.imshow(ytarget[0], cmap='bone')
plot.subplot(1, 2, 2)
plot.imshow(yhat[0], cmap='bone')
plot.show()