-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcalc_conf_CL.py
152 lines (130 loc) · 4.63 KB
/
calc_conf_CL.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/env python
import numpy as np
import nibabel as nib
from dipy.tracking.metrics import downsample
from dipy.tracking.distances import bundles_distances_mdf
import matplotlib.pyplot as plt
from dipy.tracking.utils import length
import sys
import os
def makehist(x, name='none', saveloc=None, show=False):
    """Plot a 50-bin histogram of ``x`` and optionally save and/or show it.

    Parameters
    ----------
    x : array-like
        Values to histogram; handed unchanged to ``Axes.hist``.
    name : str, optional
        Label interpolated into the plot title.
    saveloc : str or None, optional
        When truthy, the figure is written to this path with
        ``Figure.savefig``.
    show : bool, optional
        When true, the figure is displayed with ``plt.show()``.

    Returns
    -------
    None
    """
    num_bins = 50
    fig, ax = plt.subplots(nrows=1, ncols=1)  # one figure, one axis
    ax.hist(x, num_bins, facecolor='green', alpha=0.5)
    ax.set_title(r'Histogram of Confidence: %s' % name)

    # Write the file before any window is opened so the saved image never
    # depends on what the GUI backend does to the figure when it is shown.
    if saveloc:
        fig.savefig(saveloc)

    if show:
        plt.show()
    else:
        # Nothing else holds on to the figure; release it so repeated calls
        # do not pile up open figures inside pyplot.  A shown figure is left
        # open so interactive sessions keep their window.
        plt.close(fig)
def showsls(sls, values, outpath, show=False):
    """Render streamlines coloured by a per-streamline value.

    Parameters
    ----------
    sls : sequence of streamlines
        Streamlines handed to ``dipy.viz.actor.line``.
    values : numpy.ndarray
        One value per streamline, used as the colour of each line.  Must
        support ``.min()`` and must not be empty.
    outpath : str or None
        When truthy, a 600x600 snapshot is written to this path with
        ``window.record``.
    show : bool, optional
        When true, the scene is displayed with ``fvtk.show``.

    Returns
    -------
    None
    """
    # dipy.viz pulls in the rendering stack, so it is imported only when a
    # rendering is actually requested.
    from dipy.viz import window, actor

    renderer = window.Renderer()

    hue = [0.5, 1]  # white to purple to red
    saturation = [0.0, 1.0]  # black to white

    # The colour table spans the minimum up to the median of ``values``;
    # anything above the median lies outside the table's range.
    lut_cmap = actor.colormap_lookup_table(
        scale_range=(values.min(), np.percentile(values, 50)),
        hue_range=hue,
        saturation_range=saturation)

    stream_actor = actor.line(sls, values, linewidth=0.1,
                              lookup_colormap=lut_cmap)
    renderer.add(stream_actor)

    scalar_bar = actor.scalar_bar(lut_cmap)
    renderer.add(scalar_bar)

    if outpath:
        window.record(renderer, out_path=outpath, size=(600, 600))

    if show:
        # fvtk is needed for interactive display only; importing it here
        # keeps the save-only path working where fvtk is unavailable.
        from dipy.viz import fvtk
        fvtk.show(renderer)
def doit(basedir, rootname, trkloc, debug=False):
    """Score the long streamlines of a TrackVis file by cluster confidence.

    Streamlines read from ``trkloc`` are kept when they have more than 5
    points and a length (as reported by ``length``) above 50.  Each kept
    streamline is resampled to 12 points with ``downsample``; its score is
    the sum of ``1 / d`` over every resampled streamline whose MDF distance
    ``d`` to it satisfies ``0 < d < 5``.

    Files written into ``basedir``:

    * ``<rootname>_mtrx.txt`` -- one score per kept streamline
    * ``<rootname>_hist.png`` -- histogram of the scores
    * ``<rootname>.trk``      -- kept streamlines, score stored as the
      ``'Cluster Confidence'`` property
    * ``<rootname>_SS.trk``   -- the resampled streamlines, same property
    * ``<rootname>.png``      -- rendering of the kept streamlines

    Parameters
    ----------
    basedir : str
        Existing directory that receives all outputs.
    rootname : str
        Stem shared by the output file names.
    trkloc : str
        Path of the input TrackVis ``.trk`` file.
    debug : bool, optional
        When true, only the first 800 kept streamlines are scored/written.

    Returns
    -------
    None
    """
    savename = '%s/%s_mtrx.txt' % (basedir, rootname)
    savetrk = '%s/%s.trk' % (basedir, rootname)
    savetrkss = '%s/%s_SS.trk' % (basedir, rootname)
    # Built directly rather than via str.replace on savetrk, which would
    # also rewrite any '.trk' occurring inside basedir or rootname.
    savetrkpic = '%s/%s.png' % (basedir, rootname)
    savehist = '%s/%s_hist.png' % (basedir, rootname)

    subsampN = 12        # points per resampled streamline
    min_points = 5       # streamlines need more points than this
    lenthr = 50          # streamlines need to be longer than this
    max_mdf = 5          # neighbours must be closer than this
    power = 1            # exponent applied to each distance
    subset = 800         # streamlines kept in debug mode
    showhist = False
    showsls_var = False

    trk, hdr = nib.trackvis.read(trkloc)
    sls = [item[0] for item in trk]
    lengths = list(length(sls))

    print("cheerio")
    print(len(lengths))
    print(len(sls))
    print("two")
    if lengths:  # an empty file has no first length to report
        print(lengths[0])

    # NOTE(review): the indentation of the original was lost, so it is
    # ambiguous which of the two tests the "BAD" branch belonged to; every
    # rejected streamline is reported here -- confirm this is acceptable.
    sls_long = []
    for sl, sl_length in zip(sls, lengths):
        if len(sl) > min_points and sl_length > lenthr:
            print(sl_length)
            sls_long.append(sl)
        else:
            print("BAD")
            print(sl_length)

    print("CHECKME")
    print(len(sls))
    print(len(sls_long))

    subsamp = [downsample(sl, subsampN) for sl in sls_long]
    if debug:
        subsamp = subsamp[:subset]
        sls_long = sls_long[:subset]

    n_sls = len(subsamp)
    score_mtrx = np.zeros([n_sls])
    print("what")
    print(n_sls)
    print(len(sls))

    for i, sl in enumerate(subsamp):
        print(str(i) + '/' + str(n_sls))
        # MDF needs every streamline to have the same number of points, so
        # distances are computed on the resampled set, not on sls_long.
        # One row at a time keeps memory at O(n) instead of O(n^2).
        dists = bundles_distances_mdf([sl], subsamp)
        # Drop the streamline itself (distance 0), distant neighbours and
        # any NaN entries.
        near = (dists > 0) & (dists < max_mdf) & ~np.isnan(dists)
        score = np.sum(np.divide(1, np.power(dists[near], power)))
        score_mtrx[i] = score
        print(score)

    print(savename)
    print(basedir)
    print(rootname)
    np.savetxt(savename, score_mtrx)
    makehist(score_mtrx, name=rootname, saveloc=savehist, show=showhist)

    # Reuse the input header, declaring a single per-streamline property.
    newhdr = hdr.copy()
    newhdr['n_properties'] = 1
    proplist = newhdr['property_name']
    proplist[0] = 'Cluster Confidence'
    newhdr['property_name'] = proplist

    # TrackVis records are (points, per-point scalars, properties).
    newtrkss = ((sl, None, score) for sl, score in zip(subsamp, score_mtrx))
    newtrk = ((sl, None, score) for sl, score in zip(sls_long, score_mtrx))
    nib.trackvis.write(savetrkss, newtrkss, newhdr)
    nib.trackvis.write(savetrk, newtrk, newhdr)

    showsls(sls_long, score_mtrx, savetrkpic, show=showsls_var)
if __name__ == '__main__':
    # Usage: <script> <tracks.trk>
    # Results are written to a 'cci' directory created next to the input.
    if len(sys.argv) < 2:
        sys.exit('usage: %s <tracks.trk>' % sys.argv[0])

    trkloc = os.path.abspath(sys.argv[1])
    if not os.path.exists(trkloc):
        # Fail loudly instead of exiting successfully with nothing done.
        sys.exit('input file not found: %s' % trkloc)

    putloc = os.path.join(os.path.dirname(trkloc), 'cci')
    if not os.path.exists(putloc):
        os.mkdir(putloc)

    rootname = os.path.basename(trkloc).replace('.trk', '_cci')
    doit(putloc, rootname, trkloc, False)