# routines.py (forked from vipasu/addseds)
import numpy as np
import matplotlib.pyplot as plt
from fast3tree import fast3tree
from scipy.stats import chisquare
from sklearn import preprocessing
import seaborn as sns
#sns.set_context('poster')
#sns.set(font_scale=3, style='whitegrid')
from CorrelationFunction import projected_correlation
from sklearn.tree import DecisionTreeRegressor
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from scipy.optimize import curve_fit
from itertools import chain
from collections import defaultdict
import os
import errno
import pandas as pd
import matplotlib.ticker
from decimal import Decimal  # used by y_tick_formatter below
sns.set(font_scale=3.5, rc={'xtick.labelsize': 25,'ytick.labelsize': 25,'legend.fontsize': 25})
#sns.set_style('whitegrid')
sns.set_style('ticks')
# Box parameters
h = 0.7
L = 250.0/h
zmax = 40.0
box_size = 250.0
# Parameters used for correlation function plots
rpmin = 0.1
rpmax = 20.0
Nrp = 25
rbins = np.logspace(np.log10(rpmin), np.log10(rpmax), Nrp+1)
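# bin centers: geometric mean of adjacent edges, i.e. log-centered r_p values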
r = np.sqrt(rbins[1:]*rbins[:-1])
png = '.png'
image_prefix = ''  # prefix (e.g. an output directory) for saved figure names; set as needed
red_col, blue_col = sns.xkcd_rgb['reddish'], sns.xkcd_rgb['blue']
########################################
# Tools for smaller physics calculations
########################################
def get_distance(center, pos, box_size=-1):
"""
Computes distance between points in 1D.
Parameters
----------
center : float
central point
pos : array-like
other points computing distance to
    box_size : float (Default: -1)
if box_size > 0, assumes periodic BCs
Returns
-------
d : array-like
distance between center and pos
"""
d = pos - np.array(center)
if box_size > 0:
half_box_size = box_size*0.5
d[d > half_box_size] -= box_size
d[d < -half_box_size] += box_size
return d
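# A minimal sketch of the periodic wrapping (illustrative values): points
# separated by nearly the full box map to small signed offsets.
# get_distance(1.0, np.array([249.0, 3.0]), box_size=250.0) -> array([-2., 2.])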
def get_nearest_nbr_periodic(center, tree, box_size, num_neighbors=1,
exclude_self=False):
"""
This function is massively inefficient. It makes two calls to the
tree code because right now the tree code only returns the distance
to the nearest neighbor, not the index or a pointer. The first call
gets a fiducial distance in the primary simulation image. Then you
find all points within that distance in all images and get the closest.
Modified to be able to return the nth_nearest object specified by
num_neighbors.
"""
half_box_size = box_size/2.0
tree.set_boundaries(0.0, box_size) ##!! important
rfid = tree.query_nearest_distance(center)
if rfid == 0.0:
rfid = box_size/np.power(tree.data.shape[0], 1.0/3.0)*10.0
if rfid > half_box_size:
rfid = half_box_size - 2e-6
rfid += 1e-6
while True:
assert rfid < half_box_size
idx, pos = tree.query_radius(center, rfid, periodic=box_size, output='both')
# if len(idx) <= 1:
if len(idx) <= num_neighbors:
rfid *= 2.0
else:
break
dx = get_distance(center[0], pos[:, 0], box_size=box_size)
dy = get_distance(center[1], pos[:, 1], box_size=box_size)
dz = get_distance(center[2], pos[:, 2], box_size=box_size)
r2 = dx*dx + dy*dy + dz*dz
if exclude_self:
msk = r2 > 0.0
r2 = r2[msk]
idx = idx[msk]
if num_neighbors < 0:
q = np.argsort(r2)
else:
q = np.argsort(r2)[num_neighbors - 1]
return np.sqrt(r2[q]), idx[q]
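# A hedged usage sketch (not part of the original pipeline): build a tree on
# uniform random points and query the nearest neighbor of a few of them,
# excluding the point itself.
def _example_nearest_nbr(n_points=1000):
    pts = np.random.uniform(0.0, box_size, size=(n_points, 3))
    with fast3tree(pts) as tree:
        for center in pts[:3]:
            d_nbr, idx = get_nearest_nbr_periodic(center, tree, box_size,
                                                  num_neighbors=1,
                                                  exclude_self=True)
            print d_nbr, idx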
## TODO: remove this
def count_neighbors_within_r(center, tree, box_size, r):
"""
Queries the tree for all objects within r of a given center and returns the
count of objects
"""
half_box_size = box_size/2.0
tree.set_boundaries(0.0, box_size)
rfid = tree.query_nearest_distance(center)
#if rfid == 0.0:
#rfid = box_size/np.power(tree.data.shape[0], 1.0/3.0)*10.0
#if rfid > half_box_size:
#rfid = half_box_size - 2e-6
#rfid += 1e-6
if r > half_box_size:
r = half_box_size
idx, pos = tree.query_radius(center, r, periodic=box_size, output='both')
return len(idx) - 1 #exclude self
def calculate_xi(cat):
"""
Given a catalog of galaxies, compute the correlation function using
    appropriate helper functions from CorrelationFunction.py
"""
rbins = np.logspace(np.log10(rpmin), np.log10(rpmax), Nrp+1)
pos = np.zeros((len(cat), 3), order='C')
pos[:, 0] = cat['x']/h
pos[:, 1] = cat['y']/h
pos[:, 2] = cat['z']/h + cat['vz']/h/100.0
xi, cov = projected_correlation(pos, rbins, zmax, L, jackknife_nside=3)
return xi, cov
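# Note on the line-of-sight coordinate above: x, y, z are converted from
# Mpc/h to Mpc, and z is shifted into redshift space by the peculiar velocity,
# z -> z + vz/(100 h) in Mpc, since H0 = 100h km/s/Mpc. For example,
# vz = 500 km/s with h = 0.7 displaces a galaxy by 500/70 ~ 7.1 Mpc.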
def calculate_r_hill(halo_set, massive_halos, rmax=5):
"""
    calculate_r_hill iterates over all halos in halo_set. For each halo, it
    computes the Hill radius induced by every more massive neighbor found in
    massive_halos (querying outward until at least 10 are in range), then
    saves the minimum r_hill together with the distance and mass of the
    neighbor that sets it.
"""
half_box_size = box_size/2
massive_halos = massive_halos.reset_index(drop=True)
halo_set = halo_set.reset_index(drop=True)
    r_hills = []
    halo_dists = []
    halo_masses = []
# Iterate over the halos and compare
for i, halo in halo_set.iterrows():
if i % 5000 == 0:
print i
m_sec = halo['mvir']
center = [halo['x'], halo['y'], halo['z']]
larger_halos = massive_halos[massive_halos['mvir'] > m_sec] #.values?
pos = np.zeros((len(larger_halos), 3))
        for k, tag in enumerate(['x', 'y', 'z']):
            pos[:, k] = larger_halos[tag][:]
num_tries = 0
with fast3tree(pos) as tree:
tree.set_boundaries(0.0, box_size)
# First find the nearest neighbor to get a sense of the density
rmax = tree.query_nearest_distance(center) + 5
if rmax > half_box_size:
rmax = half_box_size - 1e-6
while True:
if num_tries > 3:
break
idxs, pos = tree.query_radius(center, rmax, periodic=box_size, output='both')
# repeat with a larger radius if there are fewer than 10 neighbors
if len(idxs) < 10:
rmax *= 3.0
num_tries += 1
else:
break
dx = get_distance(center[0], pos[:, 0], box_size=box_size)
dy = get_distance(center[1], pos[:, 1], box_size=box_size)
dz = get_distance(center[2], pos[:, 2], box_size=box_size)
r2 = dx*dx + dy*dy + dz*dz
msk = r2 > 0.0
rs = np.sqrt(r2[msk])
idxs = idxs[msk]
if len(rs) > 0:
            # tree indices are positional into pos (built from larger_halos),
            # so neighbor masses must be looked up positionally in
            # larger_halos, not label-based in massive_halos
            mvir_nbrs = larger_halos['mvir'].values[idxs]
            rhill_candidates = rs * (m_sec/(3.0 * mvir_nbrs))**(1./3)
            idx_idx = np.argmin(rhill_candidates)
            rhill = rhill_candidates[idx_idx]
            halo_dist = rs[idx_idx]
            halo_mass = mvir_nbrs[idx_idx]
else:
rhill = half_box_size
halo_dist = np.nan
halo_mass = np.nan
r_hills.append(rhill)
halo_dists.append(halo_dist)
halo_masses.append(halo_mass)
return r_hills, halo_dists, halo_masses
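# The quantity computed above is the Hill radius,
#     r_hill = d * (m_sec / (3 m_primary))**(1/3),
# the distance from the secondary halo within which its own gravity dominates
# the tidal field of a more massive neighbor. Illustrative example: a 1e12
# Msun halo 5 Mpc from a 1e14 Msun neighbor has
# r_hill = 5 * (1e12 / 3e14)**(1./3) ~ 0.75 Mpc.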
def r_hill_pdfs(rhills, dists, masses):
dists = np.array(dists)
masses = np.array(masses)
rhills = np.array(rhills)
plt.figure()
sns.distplot(np.log10(rhills[~np.isnan(rhills)]), kde=False, norm_hist=True)
ax = plt.gca()
ax.set_yscale('log')
plt.ylabel('$PDF$')
plt.xlabel('$Log(Rhill)$')
labels = [item.get_text() for item in ax.get_xticklabels()]
labels = ['$10^{' + str(label) + '}$' for label in np.arange(-5,4)]
#ax.set_xticklabels(labels)
plt.figure()
sns.distplot(np.log10(dists[~np.isnan(dists)]), kde=False, norm_hist=True)
ax = plt.gca()
ax.set_yscale('log')
labels = [item.get_text() for item in ax.get_xticklabels()]
labels = ['$10^{' + str(label) + '}$' for label in np.arange(-4,4)]
#ax.set_xticklabels(labels)
plt.ylabel('$PDF$')
plt.xlabel('$Log(r_{Halo_{Rhill}})$')
plt.figure()
sns.distplot(np.log10(masses[~np.isnan(masses)]), kde=False, norm_hist=True)
ax = plt.gca()
ax.set_yscale('log')
plt.ylabel('$PDF$')
plt.xlabel('$Log(M_{Halo_{Rhill}})$')
labels = [item.get_text() for item in ax.get_xticklabels()]
labels = ['$10^{' + str(label) + '}$' for label in np.arange(9,17)]
#ax.set_xticklabels(labels)
########################################
# Tools for exploratory data analysis
########################################
# code to do nearest nbr's with fast3tree
# pull code from here
# https://bitbucket.org/beckermr/fast3tree
def catalog_selection(d0, m, msmin, msmax):
"""
    Create a parent halo set with mvir >= m
    and a galaxy set with mstar in the range [msmin, msmax]
"""
# get parent halo set
dp = d0[(d0['upid'] == -1) & (d0['mvir'] >= m)]
dp = dp.reset_index(drop=True)
# make stellar mass bin
d = d0[(d0['mstar'] >= msmin) & (d0['mstar'] <= msmax)]
d = d.reset_index(drop=True)
return d, dp
def get_dist_and_attrs(dp, d, nn, attrs):
"""
    dp - parent set of halos
    d - galaxy set
    nn - number of neighbors
    attrs - list of attributes (e.g. ['mvir', 'vmax'])
"""
pos = np.zeros((len(dp), 3))
for i, tag in enumerate(['x', 'y', 'z']):
pos[:, i] = dp[tag][:]
dnbr = np.zeros(len(d))
res = [np.zeros(len(d)) for attr in attrs]
with fast3tree(pos) as tree:
for i in xrange(len(d)):
if i % 10000 == 0: print i, len(d)
center = [d['x'].values[i], d['y'].values[i], d['z'].values[i]]
r, ind = get_nearest_nbr_periodic(center, tree, box_size, num_neighbors=nn, exclude_self=True)
dnbr[i] = np.log10(r)
for j, attr in enumerate(attrs):
res[j][i] = dp[attr].values[ind]
return dnbr, res
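# A hedged usage sketch (the cut values are supplied by the caller, not fixed
# here): build the parent and galaxy sets, then record each galaxy's distance
# to its nearest parent along with that parent's mvir and vmax.
def _example_dist_and_attrs(d0, m_halo, msmin, msmax):
    d, dp = catalog_selection(d0, m_halo, msmin, msmax)
    dnbr, (mvir, vmax) = get_dist_and_attrs(dp, d, 1, ['mvir', 'vmax'])
    return dnbr, mvir, vmax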
def scale_data(data):
scaler = preprocessing.StandardScaler().fit(data)
return scaler.transform(data), scaler
def split_octant(d, box_size):
d_octant = d[(d['x'] < box_size/2) & (d['y'] < box_size/2) & (d['z'] < box_size/2)]
d_rest = d[~d.index.isin(d_octant.index)]
d_octant = d_octant.reset_index(drop=True)
d_rest = d_rest.reset_index(drop=True)
return d_octant, d_rest
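# Note (inference from the code): the octant holds 1/8 of the box volume, so
# training there leaves a test sample of ~7/8 of the catalog, consistent with
# the p_scale = 8./7 rescaling applied in plot_HOD below.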
def select_features(features, dataset, scaled=True):
x_cols = [dataset[feature].values for feature in features]
Xtot = np.column_stack(x_cols)
y = dataset['ssfr'].values
if scaled:
Xtot, x_scaler = scale_data(Xtot)
y, y_scaler = scale_data(y)
return Xtot, y, x_scaler, y_scaler
else:
return Xtot, y
def pre_process(features, target, d, seed=5432):
"""
    Given a list of feature arrays and the galaxy catalog d, scale and shuffle
    the data, then split it 50/50 into training and testing sets.
"""
# machine learning bit
N_points = len(d)
# get the features X and the outputs y
Xtot = np.column_stack(features)
Xtot, x_scaler = scale_data(Xtot)
y, y_scaler = scale_data(target)
np.random.seed(seed)
shuffle = np.random.permutation(N_points)
Xtot = Xtot[shuffle]
y = y[shuffle]
split = int(N_points * .5)
Xtrain, Xtest = Xtot[:split, :], Xtot[split:, :]
ytrain, ytest = y[:split], y[split:]
d_train, d_test = d.ix[shuffle[:split]], d.ix[shuffle[split:]]
return Xtrain, Xtest, ytrain, ytest, d_train, d_test, x_scaler, y_scaler
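# A hedged usage sketch of the training pipeline (the feature choice and
# max_depth value are illustrative, not fixed by this module): combine the
# neighbor distance and host mass into features, split, fit, and predict.
def _example_pre_process(d, dnbr, mvir):
    out = pre_process([dnbr, np.log10(mvir)], d['ssfr'].values, d)
    Xtrain, Xtest, ytrain, ytest, d_train, d_test, x_scaler, y_scaler = out
    model = DecisionTreeRegressor(max_depth=8)
    model.fit(Xtrain, ytrain)
    return model.predict(Xtest), ytest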
def plot_differences_3d(x1, x2, predicted, actual, name):
"""
Given two arrays of input data, create a 3D scatterplot with marker sizes
proportional to the difference in predicted and actual values.
"""
num_samples = len(x1)
diffs = predicted - actual
fig = plt.figure(1, figsize=(12, 8))
ax = fig.add_subplot(121, projection='3d')
ax.scatter(x1, x2, predicted, c='r')
ax.scatter(x1, x2, actual, c='b', alpha=0.3)
plt.title(name + " Actual vs Predicted ({})".format(num_samples))
plt.legend(["Predicted", "Actual"], fontsize=6, loc='lower left')
ax.set_xlabel('distance')
ax.set_ylabel('mvir')
ax.set_zlabel('ssfr')
# Scatterplot with size indicating how far it is from the true estimation
ax = fig.add_subplot(122, projection='3d')
ax.scatter(x1, x2, predicted, c=red_col, s=5 * np.exp(2 * np.abs(diffs)), alpha=0.7)
plt.title(name + " Predictions with errors")
ax.set_xlabel('distance')
ax.set_ylabel('mvir')
ax.set_zlabel('ssfr')
def plot_kdes(predicted, actual, name):
"""
    Smoothed histograms (KDEs) of the actual and predicted ssfr distributions.
    Serves as a sanity check that the red vs blue galaxy balance is reasonable.
"""
fig = plt.figure()
with sns.color_palette("pastel"):
sns.kdeplot(actual, shade=True, label='input')
sns.kdeplot(predicted, shade=True, label='Predicted')
title = 'KDE of ssfr ({})'.format(name)
plt.title(title)
plt.xlabel('Scaled ssfr Value')
plt.savefig(image_prefix + title + png)
def cross_scatter_plot(predicted, actual, name=''):
"""
Heatmap of how well individual predictions do. Correlation coefficient r
is included in the plot.
"""
lims = [min(actual), max(actual)]
print lims
g = sns.jointplot(actual, predicted, color=sns.xkcd_rgb['jade'], xlim=lims,
ylim=lims, kind='hex')
g.set_axis_labels("Actual", "Predicted")
plt.colorbar()
plt.savefig(image_prefix + 'Scatter ' + name + png)
def sample_model(model, name, Xtrain, ytrain, Xtest, ytest, num_samples=500):
"""
Combines a handful of the sanity checks and plots.
"""
model.fit(Xtrain, ytrain)
sel = np.random.permutation(len(Xtest))[:num_samples]
# dist, mvir = Xtest[sel, 0], Xtest[sel, 1]
s = ytest[sel]
# Plot the predictions and compare
results = model.predict(Xtest[sel])
# diffs = results - s
# plot_differences_3d(dist, mvir, results, s, name)
name += ' ' + str(len(sel))
plot_kdes(results, s, name)
cross_scatter_plot(results, s, name)
# Scatterplot comparison between actual and predicted
plt.show()
def y_tick_formatter(x, pos):
s = '%s' % Decimal("%.1g" % x)
return s
########################################
# Tests aka Heavy Lifters
########################################
def plot_wprp(actual_xis, actual_cov, pred_xis, pred_cov, set_desc, num_splits):
"""
    Plots the calculated correlation functions with error bars, plus a
    secondary figure with a fixed-slope power-law fit for each group.
"""
n_groups = len(actual_xis)
# create a range of colors from red to blue
colors = sns.blend_palette([red_col, blue_col], n_groups)
fig = plt.figure()
ax = plt.gca()
ax.set_xscale("log")
ax.set_yscale('log')
for i, xi_pred, cov_pred, xi_actual, cov_actual in \
zip(xrange(n_groups), pred_xis, pred_cov, actual_xis, actual_cov):
print str(i) + 'th bin'
print 'chi square is:', chisquare(xi_pred, xi_actual)
var1 = np.sqrt(np.diag(cov_pred))
var2 = np.sqrt(np.diag(cov_actual))
plt.errorbar(r, xi_actual, var2, fmt='-o', label=str(i+1), color=colors[i])
plt.errorbar(r, xi_pred, var1, fmt='--o', color=colors[i], alpha=0.6)
y_format = matplotlib.ticker.FuncFormatter(y_tick_formatter)
ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
ax.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
ax.tick_params(pad=20)
#plt.ticklabel_format(axis='y', style='plain')
title = 'wp(rp) for ' + set_desc
#plt.title(title)
plt.xlabel('$r$ $[Mpc$ $h^{-1}]$')
plt.xlim(1e-1, 30)
plt.ylabel('$w_p(r_p)$')
#plt.legend()
# Fits power laws of the form c(x^-1.5) + y0
plt.figure()
plt.subplot(121)
plt.hold(True)
ax = plt.gca()
ax.set_xscale("log")
ax.set_yscale('log')
fit_region, = np.where(r > 2)
r_fit = r[fit_region]
normalizations = []
for i, xi_pred in zip(xrange(num_splits + 1), pred_xis):
popt, pcov = curve_fit(fixed_power_law, r_fit, xi_pred[fit_region],
p0= [500,20])
normalizations.append(popt[1])
plt.plot(r_fit, fixed_power_law(r_fit, *popt), color=colors[i], label=str(i+1))
plt.legend()
plt.subplot(122)
sns.barplot(np.arange(1,num_splits + 2), np.array(normalizations), palette=colors)
plt.savefig(image_prefix + title + png)
plt.show()
return
def wprp_comparison(gals, set_desc, num_splits, mix=False):
"""
Takes in a data frame of galaxies and bins galaxies by ssfr. The CF is
calculated for both predicted and actual ssfr values and passed to a helper
function for plotting.
"""
plt.figure()
percentiles = [np.round(100. * i/(num_splits + 1)) for i in xrange(1, num_splits + 1)]
splits = [np.percentile(gals['ssfr'].values, p) for p in percentiles]
# Create empty lists for 2pt functions of multiple groups
pred_xis, pred_cov = [], []
actual_xis, actual_cov = [], []
# When passed a subset of galaxies, compute the CF and cov and append it
# to the lists passed as arguments
def append_xi_cov(xi_list, cov_list, subset):
xi, cov = calculate_xi(subset)
xi_list.append(xi)
cov_list.append(cov)
if mix:
# keep a list of subsets to shuffle later
pred_cats = []
append_xi_cov(pred_xis, pred_cov, gals[gals['pred'] < splits[0]])
append_xi_cov(actual_xis, actual_cov, gals[gals['ssfr'] < splits[0]])
for i in xrange(1, len(splits)):
cat_sub_pred = gals[(gals['pred'] < splits[i]) & (gals['pred'] > splits[i-1])]
cat_sub_actual = gals[(gals['ssfr'] < splits[i]) & (gals['ssfr'] > splits[i-1])]
if mix:
pred_cats.append(cat_sub_pred)
else:
append_xi_cov(pred_xis, pred_cov, cat_sub_pred)
append_xi_cov(actual_xis, actual_cov, cat_sub_actual)
    if mix:
        # shuffle galaxies among the predicted ssfr bins before computing the
        # CF; np.concatenate would drop the DataFrame structure, so use
        # pd.concat now that the catalogs are pandas objects
        pred_cats.append(gals[gals['pred'] > splits[-1]])
        collection = pd.concat(pred_cats, ignore_index=True)
        perm = np.random.permutation(len(collection))
        idx = 0
        for pred_cat in pred_cats:
            n = len(pred_cat)
            shuffled = collection.iloc[perm[idx:idx + n]]
            idx += n
            append_xi_cov(pred_xis, pred_cov, shuffled)
    else:
        append_xi_cov(pred_xis, pred_cov, gals[gals['pred'] > splits[-1]])
    append_xi_cov(actual_xis, actual_cov, gals[gals['ssfr'] > splits[-1]])
# TODO: comment this out
#plot_wprp(actual_xis, actual_cov, pred_xis, pred_cov, set_desc, num_splits)
return actual_xis, actual_cov, pred_xis, pred_cov
def wprp_fraction(gals, set_desc):
actual_xis, actual_cov, pred_xis, pred_cov = wprp_comparison(gals, set_desc, 1)
# how to do the error, is it just the sum?
combined_actual = actual_xis[0]/actual_xis[1]
combined_pred = pred_xis[0]/pred_xis[1]
return combined_actual, combined_pred
def plot_richness_scatter(gals, name, full_set):
log_counts_a, scatter_a = richness_scatter(gals[gals['ssfr'] < -11.0], full_set)
log_counts_p, scatter_p = richness_scatter(gals[gals['pred'] < -11.0], full_set)
#fig1 = plt.figure(figsize=(12,7))
#frame1=fig1.add_axes((.1,.3,.8,.6))
#plt.subplot(121)
plt.plot(log_counts_a, scatter_a, 'o', label='input', color='k', markersize=7)
plt.plot(log_counts_p, scatter_p, 'o', label='predicted', color=red_col, markersize=7)
#plt.title('Scatter in richness ' + name)
    #plt.xlabel('Log Number of red satellites')  # superseded by the label below
    plt.xlabel('$<log N_{red sat}>$')
plt.xlim(-.1,2.6)
plt.ylim(0, np.max([np.nanmax(scatter_a),np.nanmax(scatter_p)]) +.1)
plt.ylabel('Scatter in $M_{halo}$')
plt.legend(loc='best')
#plt.subplot(122)
#frame2=fig1.add_axes((.1,.1,.8,.2))
#series_a = pd.Series(scatter_a, index=counts_a)
#series_p = pd.Series(scatter_p, index=counts_p)
# scat_diff = (series_a - series_p)/series_a
#scat_ratio = series_p/series_a
#plt.plot(scat_diff.index, scat_diff.values, 'ob')
#plt.plot(scat_ratio.index, scat_ratio.values, 'ob')
#plt.title("Scatter ratios in richness for actual vs predicted")
#plt.axhline(0)
#plt.ylabel('Error')
#plt.xlabel('Number of red satellites')
return
def richness_scatter(gals, full):
mass_id_dict = {}
hosts = full[full['upid'] == -1]
for _, host in hosts.iterrows():
mass_id_dict[host['id']] = host['mvir']
subs = gals[gals['upid'] != -1]
halo_children = subs.groupby('upid').count() # number of satellites which share the same host
halo_ids = halo_children.index.values # id of the host
    halo_masses = pd.Series(halo_ids).map(mass_id_dict)  # find the mass of each host halo
num_children = halo_children.values.T[0]
    #plt.plot(num_children)
    #plt.ylim(0,10)  # leftover debug plotting
log_num_children = np.log10(num_children)
# Potential to make more generalized version
# Main priority is to be able to distinguish between log(2)
# and log(3)
nbins = 26
bins = np.linspace(-0.05,2.45,nbins)
bin_ids = np.digitize(log_num_children, bins, right=True)
bincounts = np.bincount(bin_ids) # to know length of data in each bin
scatters = np.ones(nbins-1) * np.nan
nb_in_use = len(bincounts)
    # Relatively inefficient, but avoids ragged multidimensional arrays
    for i, count in zip(range(1, nb_in_use + 1), bincounts[1:]):
        if count == 0:
            continue
        data = np.zeros(count)  # exactly one slot per halo in this bin
        j = 0
        for bin_id, mass in zip(bin_ids, halo_masses):
            if bin_id == i:
                data[j] = mass
                j += 1
        data = pd.Series(data)
        scatters[i-1] = data.std() / data.mean()
return (bins[1:] + bins[:-1])/2, scatters
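# Why 26 edges of width 0.1 starting at -0.05: the bin centers land on 0.0,
# 0.1, ..., 2.4 in log10(N), so log10(2) ~ 0.301 and log10(3) ~ 0.477 fall in
# different bins (edges 0.25-0.35 vs 0.45-0.55), keeping small integer
# richnesses distinguishable as the comment above intends.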
def correlation_ratio(d_test, name):
c_actual, c_pred = wprp_fraction(d_test, name + 'mvir + dist')
plt.semilogx(r, c_actual, lw=4, label='input', color='k')
plt.plot(r, c_pred, '--', label='predicted', color=blue_col)
#plt.title('wprp quenched vs starforming ' + name)
plt.xlabel('$r$')
plt.ylabel('$w_p, Q / w_p, SF$')
plt.legend()
# TODO: test by using test_gals == d_gals
def plot_density_profile(d0, d_gals, test_gals, name):
"""
    Binning host halos by mass and neighbors by radius, this function queries
    the tree directly to show radial profiles of red/blue galaxies around
    parent halos.
"""
    rmin, rmax, Nrp = 0.1, 5.0, 10
    # use the local rmin/rmax here (not the module-level rpmin/rpmax) so the
    # bins match the rmax actually queried below
    rbins = np.logspace(np.log10(rmin), np.log10(rmax), Nrp+1)
r = np.sqrt(rbins[1:]*rbins[:-1])
dp = d0[d0['upid'] == -1]
# Set up a tree with the galaxy set
dp = dp.reset_index(drop=True)
d_gals = d_gals.reset_index(drop=True)
pos = np.zeros((len(d_gals), 3))
for i, tag in enumerate(['x', 'y', 'z']):
pos[:, i] = d_gals[tag][:]
mvir = dp['mvir'].values
mmin, mmax = np.min(mvir), np.max(mvir)
nmbins = 3
mbins = np.logspace(np.log10(mmin), np.log10(mmax), nmbins+1)
num_halos, _ = np.histogram(mvir, mbins)
dp['mbin_idx'] = np.digitize(mvir, mbins)
with fast3tree(pos) as tree:
tree.set_boundaries(0.0, box_size)
for i in xrange(nmbins):
mass_select = dp[dp['mbin_idx'] == i]
num_blue_actual = np.zeros(len(rbins))
num_blue_pred = np.zeros(len(rbins))
num_red_actual = np.zeros(len(rbins))
num_red_pred = np.zeros(len(rbins))
# change the order of the loop after querying the radius
for j, halo in mass_select.iterrows():
center = halo['x'], halo['y'], halo['z']
idxs, pos = tree.query_radius(center, rmax, periodic=box_size, output='both')
dx = get_distance(center[0], pos[:, 0], box_size=box_size)
dy = get_distance(center[1], pos[:, 1], box_size=box_size)
dz = get_distance(center[2], pos[:, 2], box_size=box_size)
r2 = dx*dx + dy*dy + dz*dz
msk = r2 > 0.0
q = np.argsort(r2[msk])
rs = np.sqrt(r2[msk][q])
idxs = idxs[msk][q]
for dist, sat_idx in zip(rs, idxs):
#print dist
#print rbins
rbin = np.digitize([dist], rbins)
# query for large radius and then do processing in here
if d_gals['ssfr'].values[sat_idx] < -11:
num_red_actual[rbin] += 1
else:
num_blue_actual[rbin] += 1
test_gal = test_gals[test_gals['id'] == d_gals['id'].values[sat_idx]]
if len(test_gal):
if test_gal['pred'].values[0] < -11:
num_red_pred[rbin] += 1
else:
num_blue_pred[rbin] += 1
            volumes = [4./3 * np.pi * rr**3 for rr in rbins]  # sphere volume at each bin edge
num_red_actual /= num_halos[i]
num_blue_actual /= num_halos[i]
num_red_pred /= num_halos[i]
num_blue_pred /= num_halos[i]
num_red_actual /= volumes
num_blue_actual /= volumes
num_red_pred /= volumes
num_blue_pred /= volumes
plt.figure(i)
plt.loglog(rbins, num_red_actual, color=red_col, lw=4, label='input')
plt.loglog(rbins, num_red_pred, color=red_col, label='pred', alpha=0.5)
plt.loglog(rbins, num_blue_actual, color=blue_col, lw=4, label='input')
plt.loglog(rbins, num_blue_pred, color=blue_col, label='pred', alpha=0.5)
plt.legend(loc='best')
plt.xlabel('distance')
plt.ylabel('<centrals + satellites>')
plt.title('Radial density {:.2f}'.format(mbins[i]) + ' < mvir < {:.2f}'.format(mbins[i+1]))
return
def plot_HOD(d0, test_gals, name, msmin, msmax=None):
# TODO: predictions are only on a limited bin not the entire range
# response: it's okay to plot on a range
# response: plot using all the predictors
"""
    Iterates through the parent catalog and counts how many centrals and
    satellites are red/blue in each host-mass bin.
"""
# need full set of upid==-1, dp should only be used for the training
# dp = pd.DataFrame.from_records(dp.byteswap().newbyteorder())
# if type(d) != pd.core.frame.DataFrame:
# d = pd.DataFrame.from_records(d.byteswap().newbyteorder())
mvir = d0['mvir'].values
red_cut = -11.0
# create equal spacing on log scale
log_space = np.arange(np.log10(np.min(mvir)), np.log10(np.max(mvir)),.2)
edges = 10**log_space
centers = 10 ** ((log_space[1:] + log_space[:-1])/2)
halos = d0[d0['upid'] == -1]
centrals = halos[halos['mstar'] > msmin]
satellites = d0[(d0['upid'] != -1) & (d0['mstar'] > msmin)]
if msmax:
satellites = satellites[satellites['mstar'] < msmax]
centrals = centrals[centrals['mstar'] < msmax]
# count the number of parents in each bin
num_halos, _ = np.histogram(halos['mvir'], edges)
num_halos = num_halos.astype(np.float)
nbins = len(centers)
# create map from upid to host mass to bin
halo_id_to_bin = {}
for _, halo in halos.iterrows():
bin_id = np.digitize([halo['mvir']], edges, right=True)[0]
halo_id_to_bin[halo['id']] = min(bin_id, nbins-1)
num_actual_blue_s, num_actual_red_s = np.zeros(nbins), np.zeros(nbins)
satellite_bins = pd.Series(satellites['upid']).map(halo_id_to_bin)
satellite_cols = pd.Series(satellites['ssfr']).map(lambda x: x < red_cut)
for bin_id, red in zip(satellite_bins, satellite_cols):
if not np.isnan(bin_id):
if red:
num_actual_red_s[bin_id] += 1
else:
num_actual_blue_s[bin_id] += 1
num_actual_blue_c, num_actual_red_c = np.zeros(nbins), np.zeros(nbins)
central_bins = pd.Series(centrals['id']).map(halo_id_to_bin)
central_cols = pd.Series(centrals['ssfr']).map(lambda x: x < red_cut)
for bin_id, red in zip(central_bins, central_cols):
if not np.isnan(bin_id):
if red:
num_actual_red_c[bin_id] += 1
else:
num_actual_blue_c[bin_id] += 1
pred_sats = test_gals[test_gals['upid'] != -1]
pred_cents = test_gals[test_gals['upid'] == -1]
num_pred_blue_s, num_pred_red_s = np.zeros(nbins), np.zeros(nbins)
pred_sat_bins = pd.Series(pred_sats['upid']).map(halo_id_to_bin)
pred_sat_cols = pd.Series(pred_sats['pred']).map(lambda x: x < red_cut)
for bin_id, red in zip(pred_sat_bins, pred_sat_cols):
if not np.isnan(bin_id):
if red:
num_pred_red_s[bin_id] += 1
else:
num_pred_blue_s[bin_id] += 1
num_pred_blue_c, num_pred_red_c = np.zeros(nbins), np.zeros(nbins)
pred_cent_bins = pd.Series(pred_cents['id']).map(halo_id_to_bin)
pred_cent_cols = pd.Series(pred_cents['pred']).map(lambda x: x < red_cut)
for bin_id, red in zip(pred_cent_bins, pred_cent_cols):
if not np.isnan(bin_id):
if red:
num_pred_red_c[bin_id] += 1
else:
num_pred_blue_c[bin_id] += 1
plt.figure(figsize=(14,14))
plt.hold(True)
plt.grid(True)
plt.subplot(221)
# comparison of combined HODS
total_occupants_actual = num_actual_blue_s + num_actual_red_s + num_actual_red_c + num_actual_blue_c
total_occupants_pred = num_pred_blue_s + num_pred_red_s + num_pred_red_c + num_pred_blue_c
p_scale = 8./7
#plt.title('Combined HOD (' + name + ')')
plt.loglog(centers, (total_occupants_actual)/num_halos, color='k', lw=4, label='input')
    # Rescale predicted counts: the 8/7 factor matches a test sample covering
    # 7/8 of the box (the training octant from split_octant excluded)
plt.loglog(centers, p_scale * (total_occupants_pred)/num_halos, color='k', label='predicted', alpha = 0.6)
plt.xlabel('$M_{halo}$')
plt.ylabel('$<N_{tot}>$')
plt.xlim(1e10, 1e15)
plt.ylim(1e-2, 1e3)
plt.legend(loc='best')
plt.subplot(222)
#plt.title('HOD red v blue (' + name + ')')
# comparison of red and blues
plt.loglog(centers, (num_actual_blue_s + num_actual_blue_c)/num_halos, color=blue_col, lw=4, label='input')
plt.loglog(centers, (num_actual_red_s + num_actual_red_c)/num_halos, color=red_col, lw=4, label='input')
plt.loglog(centers, p_scale * (num_pred_blue_s + num_pred_blue_c)/num_halos, '--', color=blue_col, label='predicted', alpha=0.6)
plt.loglog(centers, p_scale * (num_pred_red_s + num_pred_red_c)/num_halos, '--', color=red_col, label='predicted', alpha=0.6)
plt.xlabel('$M_{halo}$')
plt.ylabel('$<N_{tot}>$')
plt.xlim(1e10, 1e15)
plt.ylim(1e-2, 1e3)
plt.legend(loc='best')
plt.subplot(223)
#plt.title('HOD red v blue centrals (' + name + ')')
plt.loglog(centers, (num_actual_blue_c)/num_halos, color=blue_col, lw=4, label='input')
plt.loglog(centers, (num_actual_red_c)/num_halos, color=red_col, lw=4, label='input')
plt.loglog(centers, p_scale * (num_pred_blue_c)/num_halos, '--', color=blue_col, label='predicted', alpha=0.6)
plt.loglog(centers, p_scale * (num_pred_red_c)/num_halos, '--', color=red_col, label='predicted', alpha=0.6)
#plt.xscale('symlog')
#plt.yscale('symlog')
plt.xlabel('$M_{halo}$')
plt.ylabel('$<N_{cen}>$')
plt.xlim(1e10, 1e15)
plt.ylim(1e-2, 1e3)
plt.legend(loc='best')
plt.subplot(224)
#plt.title('HOD red v blue satellites (' + name + ')')
plt.loglog(centers, (num_actual_blue_s)/num_halos, color=blue_col, lw=4, label='input')
plt.loglog(centers, (num_actual_red_s )/num_halos, color=red_col, lw=4, label='input')
plt.loglog(centers, p_scale * (num_pred_blue_s)/num_halos, '--', color=blue_col, label='predicted', alpha=0.6)
plt.loglog(centers, p_scale * (num_pred_red_s)/num_halos, '--', color=red_col, label='predicted', alpha=0.6)
plt.xlabel('$M_{halo}$')
plt.ylabel('$<N_{sat}>$')
plt.xlim(1e10, 1e15)
plt.ylim(1e-2, 1e3)
plt.legend(loc='best')
plt.tight_layout()
return
def plot_p_red(masses, ytest, y_hat, name):
"""
For every bin in mass, calculate the fraction of red vs blue galaxies.
Additionally, the distribution of color is plotted vs mvir for both the
predicted and actual ssfr's.
"""
nbins = 15
bins = np.logspace(np.log10(np.min(masses)), np.log10(np.max(masses)), nbins)
red_cut = -11.0
actual_red, = np.where(ytest < red_cut)
pred_red, = np.where(y_hat < red_cut)
actual_blue, = np.where(ytest > red_cut)
pred_blue, = np.where(y_hat > red_cut)
actual_red_counts, _ = np.histogram(masses[actual_red], bins)
actual_blue_counts, _ = np.histogram(masses[actual_blue], bins)
pred_red_counts, _ = np.histogram(masses[pred_red], bins)
pred_blue_counts, _ = np.histogram(masses[pred_blue], bins)
p_red_test = 1.0 * actual_red_counts / (actual_red_counts + actual_blue_counts)
p_red_predicted = 1.0 * pred_red_counts / (pred_red_counts + pred_blue_counts)
print "Chi square is: ", chisquare(filter_nans(p_red_predicted), filter_nans(p_red_test))
plt.hold(True)
center = (bins[:-1] + bins[1:]) / 2
plt.plot(center, p_red_test, lw=4, label='input', color='k', alpha=0.6)
plt.plot(center, p_red_predicted, '--', label='predicted', color='red', alpha=0.8)
title = 'Fraction of Quenched Galaxies {}'.format(name)
#plt.title(title)
plt.legend(loc='best')
plt.xlabel('$M_*$')
plt.gca().set_xscale("log")
plt.ylabel('$F_Q$')
plt.ylim(0,1.1)
#plt.xlim(1e10, 1e13)
plt.savefig(image_prefix + title + png)
lm = np.log10(masses)
plt.figure(2)
sns.kdeplot(lm, ytest, shade=True)
title = 'Heatmap of mstar vs ssfr (Actual) ({})'.format(name)
plt.ylim(-13,-8)
plt.title(title)
#plt.gca().set_xscale("log")
plt.figure(3)
sns.kdeplot(lm, y_hat, shade=True)
title = 'Heatmap of mstar vs ssfr (Predicted) ({})'.format(name)
plt.title(title)
#plt.gca().set_xscale("log")
plt.show()
########################################
# Miscellaneous functions
########################################
def fits_to_pandas(df):
return pd.DataFrame.from_records(df.byteswap().newbyteorder())
def fixed_power_law(x, intercept, c):
return intercept + c * (x ** -1.5)
def power_law(x, intercept, c, power=-1):
return intercept + c * x ** power
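# A hedged usage sketch of the fixed-slope fit performed in plot_wprp: fit
# xi ~ intercept + c * r**-1.5 over the large-scale region r > 2 Mpc.
def _example_power_fit(xi):
    fit_region, = np.where(r > 2)
    popt, pcov = curve_fit(fixed_power_law, r[fit_region], xi[fit_region],
                           p0=[500, 20])
    return popt  # [intercept, c]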
def filter_nans(arr):
return np.ma.masked_array(arr, np.isnan(arr))
def mkdir_p(path):
try:
os.makedirs(path)
except OSError as exc: # Python >2.5
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else: raise