# train_models.py -- forked from epierson9/pain-disparities
import matplotlib
import argparse
#matplotlib.use('Agg')
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable  # NB: this file uses pre-0.4 pytorch idioms (Variable, loss.data[0], nn.init.constant).
import torch.nn.functional as F
import numpy as np
import pandas as pd  # pd.isnull is used below (create_ses_weights); make the dependency explicit rather than relying on the star import.
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import seaborn as sns
import time
import random
import os
import copy
import sklearn
import pickle  # plain `pickle` is used below (train_one_model saves configs/results).
import pickle as pkl
from scipy.stats import pearsonr
from constants_and_util import *
import image_processing
import json
import sys
from image_processing import PytorchImagesDataset
from torch.utils.data import Dataset, DataLoader
from non_image_data_processing import NonImageData
import datetime
import cv2
#import mura_predict
import analysis
from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated; the top-level name is equivalent.
from scipy.special import expit
from matplotlib import ticker
import gc
import statsmodels.api as sm
#import my_modified_resnet
# https://github.com/pytorch/pytorch/issues/973
torch.multiprocessing.set_sharing_strategy('file_system')
def create_ses_weights(d, ses_col, covs, p_high_ses, use_propensity_scores):
"""
    Used for training preferentially on high- or low-SES people. If use_propensity_scores is True, uses inverse propensity weighting on covs.
    Note: this samples individual images, not individual people. That is okay as long as we're clear about what's being done; if p_high_ses is 0 or 1, the two sampling schemes are equivalent. One reason to sample images rather than people is that, under propensity score weighting, covs may change for a given person over time.
"""
assert p_high_ses >= 0 and p_high_ses <= 1
high_ses_idxs = (d[ses_col] == True).values
n_high_ses = high_ses_idxs.sum()
n_low_ses = len(d) - n_high_ses
assert pd.isnull(d[ses_col]).sum() == 0
n_to_sample = min(n_high_ses, n_low_ses) # want to make sure train set size doesn't change as we change p_high_ses from 0 to 1 so can't have a train set size larger than either n_high_ses or n_low_ses
n_high_ses_to_sample = int(p_high_ses * n_to_sample)
n_low_ses_to_sample = n_to_sample - n_high_ses_to_sample
all_idxs = np.arange(len(d))
high_ses_samples = np.array(random.sample(list(all_idxs[high_ses_idxs]), n_high_ses_to_sample))
low_ses_samples = np.array(random.sample(list(all_idxs[~high_ses_idxs]), n_low_ses_to_sample))
print("%i high SES samples and %i low SES samples drawn with p_high_ses=%2.3f" %
(len(high_ses_samples), len(low_ses_samples), p_high_ses))
# create weights.
weights = np.zeros(len(d))
if len(high_ses_samples) > 0:
weights[high_ses_samples] = 1.
if len(low_ses_samples) > 0:
weights[low_ses_samples] = 1.
if not use_propensity_scores:
assert covs is None
weights = weights / weights.sum()
return weights
else:
assert covs is not None
# fit probability model
propensity_model = sm.Logit.from_formula('%s ~ %s' % (ses_col, '+'.join(covs)), data=d).fit()
print("Fit propensity model")
print(propensity_model.summary())
# compute inverse propensity weights.
# "A subject's weight is equal to the inverse of the probability of receiving the treatment that the subject actually received"
# The treatment here is whether they are high SES,
# and we are matching them on the other covariates.
high_ses_propensity_scores = propensity_model.predict(d).values
high_ses_weights = 1 / high_ses_propensity_scores
low_ses_weights = 1 / (1 - high_ses_propensity_scores)
propensity_weights = np.zeros(len(d))
propensity_weights[high_ses_idxs] = high_ses_weights[high_ses_idxs]
propensity_weights[~high_ses_idxs] = low_ses_weights[~high_ses_idxs]
assert np.isnan(propensity_weights).sum() == 0
        # multiply indicator vector by propensity weights.
weights = weights * propensity_weights
# normalize weights so that high and low SES sum to the right things.
print(n_high_ses_to_sample, n_low_ses_to_sample)
if n_high_ses_to_sample > 0:
weights[high_ses_idxs] = n_high_ses_to_sample * weights[high_ses_idxs] / weights[high_ses_idxs].sum()
if n_low_ses_to_sample > 0:
weights[~high_ses_idxs] = n_low_ses_to_sample * weights[~high_ses_idxs] / weights[~high_ses_idxs].sum()
assert np.isnan(weights).sum() == 0
# normalize whole vector, just to keep things clean
weights = weights / weights.sum()
return weights
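# A minimal sketch of create_ses_weights in the simple (non-propensity) case. The toy
# DataFrame and the 'high_ses' column are illustrative assumptions, not pipeline data:
#
#   toy = pd.DataFrame({'high_ses': [True, True, True, False, False]})
#   w = create_ses_weights(toy, ses_col='high_ses', covs=None, p_high_ses=0.5,
#                          use_propensity_scores=False)
#   # n_to_sample = min(3, 2) = 2, so one high-SES and one low-SES row get weight 0.5
#   # each and the remaining rows get 0; the vector always sums to 1.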
def reweight_to_remove_correlation_between_pain_and_ses(d, ses_col, pain_col):
"""
Robustness check: train on dataset where we've removed the correlation between pain and SES to verify that the model isn't just learning to predict SES.
"""
d = copy.deepcopy(d)
high_ses_idxs = (d[ses_col] == True).values
d['discretized_pain_score'] = analysis.cut_into_deciles(d[pain_col].values
+ .0001 * np.random.random(len(d)), # small hack to break ties
pain_col)
predict_high_ses_given_pain = sm.Logit.from_formula('%s ~ C(discretized_pain_score)' % (
ses_col), data=d).fit()
high_ses_propensity_scores = predict_high_ses_given_pain.predict(d).values
high_ses_weights = 1 / high_ses_propensity_scores
low_ses_weights = 1 / (1 - high_ses_propensity_scores)
propensity_weights = np.zeros(len(d))
propensity_weights[high_ses_idxs] = high_ses_weights[high_ses_idxs]
propensity_weights[~high_ses_idxs] = low_ses_weights[~high_ses_idxs]
propensity_weights = propensity_weights / propensity_weights.sum()
r, p = pearsonr(d[pain_col], d[ses_col])
print("Original correlation between SES and pain: %2.3f" % r)
samples = np.random.choice(range(len(d)), p=propensity_weights, size=[50000,])
r, p = pearsonr(d[pain_col].iloc[samples], d[ses_col].iloc[samples])
print("Correlation after inverse propensity weighting: %2.3f" % r)
return propensity_weights
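# Worked intuition for the weighting above (numbers are illustrative): if 80% of the images
# in some pain decile are high SES, each high-SES image in that decile gets weight
# 1 / 0.8 = 1.25 and each low-SES image gets 1 / (1 - 0.8) = 5. Resampling by these weights
# therefore balances the two SES groups within every pain decile, which is what removes
# the pain/SES correlation that the pearsonr check above verifies.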
def load_real_data_in_transfer_learning_format(batch_size,
downsample_factor_on_reload,
normalization_method,
y_col,
max_horizontal_translation,
max_vertical_translation,
seed_to_further_shuffle_train_test_val_sets,
additional_features_to_predict,
crop_to_just_the_knee=False,
show_both_knees_in_each_image=False,
weighted_ses_sampler_kwargs=None,
increase_diversity_kwargs=None,
hold_out_one_imaging_site_kwargs=None,
train_on_single_klg_kwargs=None,
remove_correlation_between_pain_and_ses_kwargs=None,
alter_train_set_size_sampler_kwargs=None,
use_very_very_small_subset=False,
blur_filter=None):
"""
    Load the dataset a few images at a time using the DataLoader class, as shown in the pytorch dataset tutorial.
Checked.
"""
    load_only_single_klg = None
    if train_on_single_klg_kwargs is not None:
        if ('make_train_set_smaller' in train_on_single_klg_kwargs) and train_on_single_klg_kwargs['make_train_set_smaller']:
            raise Exception("Should not be using this option.")
        load_only_single_klg = train_on_single_klg_kwargs['klg_to_use']
train_dataset = PytorchImagesDataset(dataset='train',
downsample_factor_on_reload=downsample_factor_on_reload,
normalization_method=normalization_method,
show_both_knees_in_each_image=show_both_knees_in_each_image,
y_col=y_col,
seed_to_further_shuffle_train_test_val_sets=seed_to_further_shuffle_train_test_val_sets,
transform='random_translation_and_then_random_horizontal_flip' if not show_both_knees_in_each_image else 'random_translation',
additional_features_to_predict=additional_features_to_predict,
max_horizontal_translation=max_horizontal_translation,
max_vertical_translation=max_vertical_translation,
use_very_very_small_subset=use_very_very_small_subset,
crop_to_just_the_knee=crop_to_just_the_knee,
load_only_single_klg=load_only_single_klg,
blur_filter=blur_filter)
if weighted_ses_sampler_kwargs is not None:
assert train_on_single_klg_kwargs is None
assert alter_train_set_size_sampler_kwargs is None
assert remove_correlation_between_pain_and_ses_kwargs is None
ses_weights = create_ses_weights(copy.deepcopy(train_dataset.non_image_data), **weighted_ses_sampler_kwargs)
print(ses_weights)
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(ses_weights, len(ses_weights)) # https://discuss.pytorch.org/t/balanced-sampling-between-classes-with-torchvision-dataloader/2703/3
shuffle = False
# per the pytorch documentation,
# "sampler defines the strategy to draw samples from the dataset. If specified, ``shuffle`` must be False."
# Shuffling is already taken care of by the sampler. Cool.
elif hold_out_one_imaging_site_kwargs is not None:
print("Hold out one train site kwargs are")
print(hold_out_one_imaging_site_kwargs)
weights = 1.*(train_dataset.non_image_data['v00site'].values != hold_out_one_imaging_site_kwargs['site_to_remove'])
assert weights.mean() < 1
print("After removing site %s, %i/%i train datapoints remaining" % (hold_out_one_imaging_site_kwargs['site_to_remove'], int(weights.sum()), len(weights)))
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
shuffle = False
elif increase_diversity_kwargs is not None:
print("Increase diversity kwargs are")
print(increase_diversity_kwargs)
assert weighted_ses_sampler_kwargs is None
assert remove_correlation_between_pain_and_ses_kwargs is None
minority_idxs = (train_dataset.non_image_data[increase_diversity_kwargs['ses_col']].values == increase_diversity_kwargs['minority_val'])
n_minority_people = len(set(train_dataset.non_image_data.loc[minority_idxs, 'id'].values))
majority_ids = sorted(list(set(train_dataset.non_image_data.loc[~minority_idxs, 'id'].values)))
n_majority_people = len(majority_ids)
assert n_majority_people > n_minority_people
if increase_diversity_kwargs['exclude_minority_group']:
# remove all minorities.
weights = (~minority_idxs) * 1.
else:
# remove a random sample of majority people.
rng = random.Random(increase_diversity_kwargs['majority_group_seed'])
majority_ids_to_keep = set(rng.sample(majority_ids, n_majority_people - n_minority_people))
majority_idxs = train_dataset.non_image_data['id'].map(lambda x:x in majority_ids_to_keep).values
assert ((majority_idxs == 1) & (minority_idxs == 1)).sum() == 0
weights = ((minority_idxs == 1) | (majority_idxs == 1)) * 1.
print("Number of people with %s=%i in train set: %i; number in majority set: %i; total number of points with nonzero weights %i; exclude minority group %s; random seed %s" % (
increase_diversity_kwargs['ses_col'],
increase_diversity_kwargs['minority_val'],
n_minority_people,
n_majority_people,
int(weights.sum()),
increase_diversity_kwargs['exclude_minority_group'],
increase_diversity_kwargs['majority_group_seed']))
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
shuffle = False
elif remove_correlation_between_pain_and_ses_kwargs is not None:
assert train_on_single_klg_kwargs is None
assert alter_train_set_size_sampler_kwargs is None
assert weighted_ses_sampler_kwargs is None
weights = reweight_to_remove_correlation_between_pain_and_ses(copy.deepcopy(train_dataset.non_image_data),
**remove_correlation_between_pain_and_ses_kwargs)
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
shuffle = False
elif train_on_single_klg_kwargs is not None:
assert weighted_ses_sampler_kwargs is None
assert remove_correlation_between_pain_and_ses_kwargs is None
sample_weights = 1.*(train_dataset.non_image_data['xrkl'].values == train_on_single_klg_kwargs['klg_to_use'])
# See note above for weighted_ses_sampler_kwargs
if 'make_train_set_smaller' in train_on_single_klg_kwargs and train_on_single_klg_kwargs['make_train_set_smaller']:
n_train_points = int(sample_weights.sum())
else:
n_train_points = len(sample_weights)
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(sample_weights, n_train_points)
shuffle = False
elif alter_train_set_size_sampler_kwargs is not None:
assert train_on_single_klg_kwargs is None
assert weighted_ses_sampler_kwargs is None
assert remove_correlation_between_pain_and_ses_kwargs is None
train_set_frac = alter_train_set_size_sampler_kwargs['fraction_of_train_set_to_use']
assert train_set_frac > 0 and train_set_frac <= 1
all_train_ids = sorted(list(set(train_dataset.non_image_data['id'])))
n_train_ids_to_use = int(len(all_train_ids) * train_set_frac)
print("Total number of people in train set: %i. Taking fraction %2.3f of them (%i ids)" %
(len(all_train_ids), train_set_frac, n_train_ids_to_use))
# ensure train set ids we take are always the same (we don't want to take best result across multiple datasets).
rng = random.Random(42)
rng.shuffle(all_train_ids)
train_ids_to_use = set(all_train_ids[:n_train_ids_to_use])
# Use WeightedRandomSampler exactly in analogy to SES weights above. This will make training slightly slower, but it seems more important to make the code consistent and functional.
sample_weights = 1.*train_dataset.non_image_data['id'].map(lambda x:x in train_ids_to_use).values
# See note above for weighted_ses_sampler_kwargs
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(sample_weights, len(sample_weights))
shuffle = False
else:
train_sampler = None
shuffle = True
NUM_WORKERS_TO_USE = 8
# Note: if you are using WeightedRandomSampler for train_sampler, and only selecting a small subset of the datapoints (eg, just those with KLG=4) you may
# quickly overtrain, since at each epoch you run through the dataset many times by sampling with replacement (replacement=True by default). This doesn't appear to be a major problem at present...tried modifying it to train on a single KLG, and it didn't improve results.
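    # For instance, if only 1 of N weights is nonzero and num_samples=N, the sampler yields
    # that single index N times per epoch; the single-KLG subset above is a milder version of this.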
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=NUM_WORKERS_TO_USE, sampler=train_sampler)
val_dataset = PytorchImagesDataset(dataset='val',
downsample_factor_on_reload=downsample_factor_on_reload,
normalization_method=normalization_method,
show_both_knees_in_each_image=show_both_knees_in_each_image,
y_col=y_col,
additional_features_to_predict=additional_features_to_predict,
seed_to_further_shuffle_train_test_val_sets=seed_to_further_shuffle_train_test_val_sets,
transform=None,
crop_to_just_the_knee=crop_to_just_the_knee,
use_very_very_small_subset=use_very_very_small_subset,
blur_filter=blur_filter)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS_TO_USE)
test_dataset = PytorchImagesDataset(dataset='test',
downsample_factor_on_reload=downsample_factor_on_reload,
normalization_method=normalization_method,
show_both_knees_in_each_image=show_both_knees_in_each_image,
y_col=y_col,
additional_features_to_predict=additional_features_to_predict,
seed_to_further_shuffle_train_test_val_sets=seed_to_further_shuffle_train_test_val_sets,
transform=None,
crop_to_just_the_knee=crop_to_just_the_knee,
use_very_very_small_subset=use_very_very_small_subset,
blur_filter=blur_filter)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS_TO_USE)
dataloaders = {'train':train_dataloader, 'val':val_dataloader, 'test':test_dataloader}
if use_very_very_small_subset:
dataset_sizes = {'train':500, 'val':500, 'test':500}
else:
dataset_sizes = {'train':len(train_dataset), 'val':len(val_dataset), 'test':len(test_dataset)}
datasets = {'train':train_dataset, 'val':val_dataset, 'test':test_dataset}
return dataloaders, datasets, dataset_sizes
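# A hedged usage sketch. Argument values here are illustrative assumptions (the tuned
# configs come from generate_random_config / generate_config_that_performs_well below),
# and 'normalization_method' must be one of the strings PytorchImagesDataset accepts:
#
#   dataloaders, datasets, dataset_sizes = load_real_data_in_transfer_learning_format(
#       batch_size=8,
#       downsample_factor_on_reload=None,
#       normalization_method='zscore',  # placeholder value
#       y_col='koos_pain_subscore',
#       max_horizontal_translation=0.1,
#       max_vertical_translation=0.1,
#       seed_to_further_shuffle_train_test_val_sets=None,
#       additional_features_to_predict=None)  # placeholder value
#   for batch in dataloaders['train']:
#       images, labels = batch['image'], batch['y']
#       break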
class TransferLearningPytorchModel():
"""
Load and fine-tune a pretrained pytorch model.
pretrained_model_name: one of the resnets or MURA.
binary_prediction: whether the prediction task is binary or continuous
conv_layers_before_end_to_unfreeze: how many conv layers from the end we want to fine-tune.
optimizer_name, optimizer_kwargs: whether eg we're using Adam or SGD.
scheduler_kwargs: how we change the learning rate.
num_epochs: how many epochs to train for.
y_col: what we're trying to predict.
    n_additional_image_features_to_predict: should be 0, 3, 19, or 22 (see the assert below). Used for regularization.
additional_loss_weighting: how much we weight this additional loss.
mura_initialization_path: if training one of the MURA pretrained models, path to load weights from.
"""
def __init__(self,
pretrained_model_name,
binary_prediction,
conv_layers_before_end_to_unfreeze,
optimizer_name,
optimizer_kwargs,
scheduler_kwargs,
num_epochs,
y_col,
where_to_add_klg,
fully_connected_bias_initialization=None,
n_additional_image_features_to_predict=0,
additional_loss_weighting=0,
mura_initialization_path=None):
assert where_to_add_klg is None
self.pretrained_model_name = pretrained_model_name
self.binary_prediction = binary_prediction
self.conv_layers_before_end_to_unfreeze = conv_layers_before_end_to_unfreeze
self.optimizer_name = optimizer_name
self.optimizer_kwargs = optimizer_kwargs
self.scheduler_kwargs = scheduler_kwargs
self.num_epochs = num_epochs
self.y_col = y_col
self.where_to_add_klg = where_to_add_klg
self.fully_connected_bias_initialization = fully_connected_bias_initialization
self.n_additional_image_features_to_predict = n_additional_image_features_to_predict # if we have additional features, they are just concatenated onto linear layer.
self.additional_loss_weighting = additional_loss_weighting
if self.binary_prediction:
self.metric_to_use_as_stopping_criterion = 'val_auc'
else:
self.metric_to_use_as_stopping_criterion = 'val_negative_rmse'
assert (self.n_additional_image_features_to_predict > 0) == (self.additional_loss_weighting > 0)
assert self.n_additional_image_features_to_predict in [0, 3, 19, 22]
assert pretrained_model_name in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'pretrained_mura_densenet']
if self.where_to_add_klg in ['before_layer4', 'before_layer3', 'before_layer2']:
            resnet_source = my_modified_resnet # only use fancy resnet if we have to. Requires uncommenting `import my_modified_resnet` above; currently unreachable because where_to_add_klg is asserted to be None.
else:
resnet_source = models # pytorch library.
if pretrained_model_name == 'resnet18':
self.model = resnet_source.resnet18(pretrained=True)
self.finalconv_name = 'layer4' # used for loading the final embedding for CAM.
elif pretrained_model_name == 'resnet34':
self.model = resnet_source.resnet34(pretrained=True)
self.finalconv_name = 'layer4'
elif pretrained_model_name == 'resnet50':
self.model = resnet_source.resnet50(pretrained=True)
raise Exception("Not sure what final conv name is")
elif pretrained_model_name == 'resnet101':
self.model = resnet_source.resnet101(pretrained=True)
raise Exception("Not sure what final conv name is")
elif pretrained_model_name == 'resnet152':
self.model = resnet_source.resnet152(pretrained=True)
raise Exception("Not sure what final conv name is")
elif pretrained_model_name == 'pretrained_mura_densenet':
            params = mura_predict.get_model_params()  # requires uncommenting `import mura_predict` above.
print("Mura params are", params.__dict__)
assert mura_initialization_path is not None
self.model = mura_predict.load_model(mura_initialization_path,
params=params,
use_gpu=True).model_ft
else:
raise Exception("Not a valid model name!")
self.model.avgpool = nn.AdaptiveAvgPool2d(1)
num_ftrs = self.model.fc.in_features
if binary_prediction:
self.model.fc = nn.Linear(in_features=num_ftrs, out_features=2 + n_additional_image_features_to_predict) # reset final fully connected layer for two-class prediction. If we have additional features, include those as outputs from the fully connected layer.
self.loss_criterion = nn.CrossEntropyLoss()
assert self.fully_connected_bias_initialization is None
else:
self.model.fc = nn.Linear(in_features=num_ftrs, out_features=1 + n_additional_image_features_to_predict) # reset final fully connected layer and make it so it has a single output.
self.loss_criterion = nn.MSELoss()
if self.fully_connected_bias_initialization is not None:
# we do this for Koos pain subscore because otherwise the final layer ends up with all positive weights, and that's weird/hard to interpret.
nn.init.constant(self.model.fc.bias.data[:1], self.fully_connected_bias_initialization)
print("Bias has been initialized to")
print(self.model.fc.bias)
# loop over layers from beginning and freeze a couple. First we need to get the layers.
def is_conv_layer(name):
# If a layer is a conv layer, returns a substring which uniquely identifies it. Otherwise, returns None.
if pretrained_model_name == 'pretrained_mura_densenet':
if 'conv' in name:
return name
else:
# this logic is probably more complex than it needs to be but it works.
if name[:5] == 'layer':
sublayer_substring = '.'.join(name.split('.')[:3])
if 'conv' in sublayer_substring:
return sublayer_substring
return None
param_idx = 0
all_conv_layers = []
for name, param in self.model.named_parameters():
print("Param %i: %s" % (param_idx, name), param.data.shape)
param_idx += 1
conv_layer_substring = is_conv_layer(name)
if conv_layer_substring is not None and conv_layer_substring not in all_conv_layers:
all_conv_layers.append(conv_layer_substring)
print("All conv layers", all_conv_layers)
# now look conv_layers_before_end_to_unfreeze conv layers before the end, and unfreeze all layers after that.
start_unfreezing = False
layers_modified_for_klg = 0
assert conv_layers_before_end_to_unfreeze <= len(all_conv_layers)
if conv_layers_before_end_to_unfreeze > 0:
conv_layers_to_unfreeze = all_conv_layers[-conv_layers_before_end_to_unfreeze:]
else:
conv_layers_to_unfreeze = []
for name, param in self.model.named_parameters():
conv_layer_substring = is_conv_layer(name)
if conv_layer_substring in conv_layers_to_unfreeze:
start_unfreezing = True
if name in ['fc.weight', 'fc.bias']:
# we always unfreeze these layers.
start_unfreezing = True
if start_unfreezing:
if self.where_to_add_klg in ['before_layer4', 'before_layer3', 'before_layer2']:
layer_to_modify = self.where_to_add_klg.replace('before_', '')
if name in ['%s.0.conv1.weight' % layer_to_modify, '%s.0.downsample.0.weight' % layer_to_modify]:
layers_modified_for_klg += 1
what_to_add = .1 * torch.randn(param.data.size()[0], 5, param.data.size()[2], param.data.size()[3])
param.data = torch.cat((param.data, what_to_add), 1)
print("Param %s is UNFROZEN" % (name), param.data.shape)
else:
print("Param %s is FROZEN" % (name), param.data.shape)
param.requires_grad = False
if self.where_to_add_klg in ['before_layer4', 'before_layer3', 'before_layer2']:
assert layers_modified_for_klg == 2 # make sure we unfroze the two layers we needed to unfreeze.
if self.where_to_add_klg == 'output':
self.model.klg_fc = nn.Linear(in_features=5, out_features=1)
self.model = self.model.cuda() # move model to GPU.
# https://github.com/pytorch/pytorch/issues/679
if optimizer_name == 'sgd':
self.optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), **optimizer_kwargs)
elif optimizer_name == 'adam':
self.optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), **optimizer_kwargs)
else:
raise Exception("Not a valid optimizer")
self.lr_scheduler_type = scheduler_kwargs['lr_scheduler_type']
if self.lr_scheduler_type == 'decay':
self.scheduler = lr_scheduler.StepLR(self.optimizer,
**scheduler_kwargs['additional_kwargs'])
elif self.lr_scheduler_type == 'plateau':
self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
**scheduler_kwargs['additional_kwargs'])
else:
raise Exception("invalid scheduler")
self.layer_magnitudes = {}
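    # A hedged construction sketch; hyperparameter values below are illustrative, not the
    # tuned configs (see generate_config_that_performs_well / train_one_model):
    #
    #   model = TransferLearningPytorchModel(
    #       pretrained_model_name='resnet18',
    #       binary_prediction=False,
    #       conv_layers_before_end_to_unfreeze=2,
    #       optimizer_name='adam',
    #       optimizer_kwargs={'lr': 1e-4},
    #       scheduler_kwargs={'lr_scheduler_type': 'plateau',
    #                         'additional_kwargs': {'factor': 0.5, 'patience': 2}},
    #       num_epochs=15,
    #       y_col='koos_pain_subscore',
    #       where_to_add_klg=None)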
def print_layer_magnitudes(self, epoch):
# small helper method so we can make sure the right layers are being trained.
        for name, param in self.model.named_parameters():
            magnitude = np.linalg.norm(param.data.cpu())
            # key by name rather than by the parameter tensor: hashing tensors relies on object identity and breaks dict lookups in newer pytorch.
            if name not in self.layer_magnitudes:
                self.layer_magnitudes[name] = magnitude
                print("The magnitude of layer %s at epoch %i is %2.5f" % (name, epoch, magnitude))
            else:
                old_magnitude = self.layer_magnitudes[name]
                delta_magnitude = magnitude - old_magnitude
                print("The magnitude of layer %s at epoch %i is %2.5f (delta %2.5f from last epoch)" % (name, epoch, magnitude, delta_magnitude))
                self.layer_magnitudes[name] = magnitude
def evaluate_on_dataset(self, dataloaders, dataset_sizes, phase, make_plot=False):
"""
Given a model, data, and a phase (train/val/test) runs the model on the data and, if phase=train, trains the model. Checked.
"""
print("Now we are evaluating on the %s dataset!" % phase)
assert phase in ['train', 'val', 'test']
use_gpu = torch.cuda.is_available()
if phase == 'train':
self.model.train(True) # Set model to training mode
else:
self.model.train(False) # Set model to evaluate mode
running_loss = 0.0
n_batches_loaded = 0
start_time_for_100_images = time.time()
# Iterate over data.
# keep track of all labels + outputs to compute the final metrics.
concatenated_labels = []
concatenated_outputs = []
concatenated_binarized_education_graduated_college = []
concatenated_binarized_income_at_least_50k = []
concatenated_numerical_klg = []
concatenated_site = []
if self.n_additional_image_features_to_predict > 0:
loss_additional_loss_ratios = [] # also keep track of how big the additional regularization loss is relative to the main loss.
for data in dataloaders[phase]:
#print("We reached the beginning of the loop with %i images" % n_batches_loaded)
n_batches_loaded += 1
if n_batches_loaded % 100 == 0:
print("Time taken to process 100 batches %2.3f seconds (total batches %i)" % (time.time() - start_time_for_100_images, len(dataloaders[phase])))
start_time_for_100_images = time.time()
# get the inputs
inputs = data['image']
labels = data['y']
additional_features_to_predict = data['additional_features_to_predict']
additional_features_are_not_nan = data['additional_features_are_not_nan']
one_hot_klg = data['klg'] # Note that this is a matrix (one-hot).
assert one_hot_klg.size()[1] == 5
numerical_klg = np.nonzero(np.array(one_hot_klg))[1]
assert len(numerical_klg) == len(one_hot_klg)
binarized_education_graduated_college = np.array(data['binarized_education_graduated_college'])
binarized_income_at_least_50k = np.array(data['binarized_income_at_least_50k'])
concatenated_site += list(np.array(data['site']))
concatenated_binarized_education_graduated_college += list(binarized_education_graduated_college)
concatenated_binarized_income_at_least_50k += list(binarized_income_at_least_50k)
concatenated_numerical_klg += list(numerical_klg)
# wrap them in Variable
if use_gpu:
inputs = Variable(inputs.float().cuda())
if self.n_additional_image_features_to_predict > 0:
additional_features_to_predict = Variable(additional_features_to_predict.float().cuda())
additional_features_are_not_nan = Variable(additional_features_are_not_nan.float().cuda())
one_hot_klg = Variable(one_hot_klg.float().cuda())
if self.binary_prediction:
labels = Variable(labels.long().cuda())
else:
labels = Variable(labels.float().cuda())
else:
raise Exception("Use a GPU, fool.")
# zero the parameter gradients
self.optimizer.zero_grad()
# forward
if self.where_to_add_klg in ['before_layer4', 'before_layer3', 'before_layer2']:
outputs = self.model(inputs,
additional_input_features=one_hot_klg,
where_to_add=self.where_to_add_klg)
else:
outputs = self.model(inputs)
if self.where_to_add_klg == 'output':
outputs = outputs + self.model.klg_fc(one_hot_klg)
if self.n_additional_image_features_to_predict > 0:
additional_feature_outputs = outputs[:, -self.n_additional_image_features_to_predict:]
outputs = outputs[:, :-self.n_additional_image_features_to_predict]
loss = self.loss_criterion(input=outputs, target=labels)
# basically, we only add to the additional feature loss if a feature is not NaN.
additional_feature_losses = ((additional_features_to_predict - additional_feature_outputs) ** 2) * additional_features_are_not_nan
additional_loss = additional_feature_losses.sum(dim=1).mean(dim=0)
original_loss_float = loss.data.cpu().numpy().flatten()
additional_loss_float = additional_loss.data.cpu().numpy().flatten()
loss_additional_loss_ratios.append(original_loss_float / (self.additional_loss_weighting * additional_loss_float))
loss = loss + additional_loss * self.additional_loss_weighting
else:
loss = self.loss_criterion(input=outputs, target=labels)
# keep track of everything for correlations
concatenated_labels += list(labels.data.cpu().numpy().flatten())
if self.binary_prediction:
# outputs are logits. Take softmax, class 1 prediction.
h_x = F.softmax(outputs, dim=1).data.squeeze()
concatenated_outputs += list(h_x[:, 1].cpu().numpy().flatten())
else:
concatenated_outputs += list(outputs.data.cpu().numpy().flatten())
# backward + optimize only if in training phase
if phase == 'train':
loss.backward()
self.optimizer.step()
# statistics
running_loss += loss.data[0] * inputs.size(0)
epoch_loss = running_loss / dataset_sizes[phase]
metrics_for_epoch = {}
if not self.binary_prediction:
concatenated_outputs = np.array(concatenated_outputs)
concatenated_labels = np.array(concatenated_labels)
if make_plot:
plt.figure()
plt.scatter(concatenated_outputs, concatenated_labels)
plt.xlim([0, 100])
plt.ylim([0, 100])
plt.xlabel("Yhat")
plt.ylabel("Y")
plt.show()
correlation_and_rmse = analysis.assess_performance(y=concatenated_labels, yhat=concatenated_outputs, binary_prediction=False)
assert len(concatenated_outputs) == dataset_sizes[phase]
print('%s epoch loss for %s: %2.6f; RMSE %2.6f; correlation %2.6f (n=%i)' %
(phase, self.y_col, epoch_loss, correlation_and_rmse['rmse'], correlation_and_rmse['r'], len(concatenated_labels)))
metrics_for_epoch['%s_loss' % phase] = epoch_loss
metrics_for_epoch['%s_rmse' % phase] = correlation_and_rmse['rmse']
metrics_for_epoch['%s_negative_rmse' % phase] = correlation_and_rmse['negative_rmse']
metrics_for_epoch['%s_r' % phase] = correlation_and_rmse['r']
print("Correlation between binarized_education_graduated_college and labels: %2.3f" % pearsonr(concatenated_binarized_education_graduated_college, concatenated_labels)[0])
print("Correlation between binarized_income_at_least_50k and labels: %2.3f" % pearsonr(concatenated_binarized_income_at_least_50k, concatenated_labels)[0])
# if Koos score, also compute AUC + AUPRC for the binarized versions.
if self.y_col == 'koos_pain_subscore':
assert np.allclose(concatenated_labels.max(), 100)
concatenated_binarized_labels = binarize_koos(concatenated_labels)
concatenated_scores = -concatenated_outputs # lower predictions = more likely to be positive class
binarized_auc_and_auprc = analysis.assess_performance(y=concatenated_binarized_labels, yhat=concatenated_scores, binary_prediction=True)
metrics_for_epoch['%s_binarized_auc' % phase] = binarized_auc_and_auprc['auc']
metrics_for_epoch['%s_binarized_auprc' % phase] = binarized_auc_and_auprc['auprc']
metrics_for_epoch['%s_ses_betas' % phase] = {'binarized_education_graduated_college_betas':None,
'binarized_income_at_least_50k_betas':None}
if phase == 'test':
# compute SES pain gaps for KLG >= 2.
education_pain_gaps = analysis.compare_pain_levels_for_people_geq_klg_2(yhat=np.array(concatenated_outputs),
y=np.array(concatenated_labels),
klg=np.array(concatenated_numerical_klg),
ses=np.array(concatenated_binarized_education_graduated_college),
y_col=self.y_col)
income_pain_gaps = analysis.compare_pain_levels_for_people_geq_klg_2(yhat=np.array(concatenated_outputs),
y=np.array(concatenated_labels),
klg=np.array(concatenated_numerical_klg),
ses=np.array(concatenated_binarized_income_at_least_50k),
y_col=self.y_col)
metrics_for_epoch['%s_pain_gaps_klg_geq_2' % phase] = {'binarized_education_graduated_college':education_pain_gaps,
'binarized_income_at_least_50k':income_pain_gaps}
if phase == 'test' or phase == 'val':
# Stratify test performance by KLG.
metrics_for_epoch['stratified_by_klg'] = {}
for klg_grade_to_use in range(5):
klg_idxs = np.array(concatenated_numerical_klg) == klg_grade_to_use
metrics_for_epoch['stratified_by_klg'][klg_grade_to_use] = analysis.assess_performance(
y=np.array(concatenated_labels)[klg_idxs], yhat=np.array(concatenated_outputs)[klg_idxs], binary_prediction=False)
# Stratify performance excluding one site at a time.
metrics_for_epoch['stratified_by_site'] = {}
concatenated_site = np.array(concatenated_site)
for site_val in sorted(list(set(concatenated_site))):
exclude_site_idxs = concatenated_site != site_val
metrics_for_epoch['stratified_by_site']['every_site_but_%s' % site_val] = analysis.assess_performance(
y=np.array(concatenated_labels)[exclude_site_idxs], yhat=np.array(concatenated_outputs)[exclude_site_idxs], binary_prediction=False)
else:
metrics_for_epoch['%s_binarized_auc' % phase] = None
metrics_for_epoch['%s_binarized_auprc' % phase] = None
if self.n_additional_image_features_to_predict == 0:
assert np.allclose(np.sqrt(epoch_loss), correlation_and_rmse['rmse'])
else:
concatenated_labels = np.array(concatenated_labels)
concatenated_outputs = np.array(concatenated_outputs)
auc_and_auprc = analysis.assess_performance(y=concatenated_labels, yhat=concatenated_outputs, binary_prediction=True)
metrics_for_epoch['%s_loss' % phase] = epoch_loss
metrics_for_epoch['%s_auc' % phase] = auc_and_auprc['auc']
metrics_for_epoch['%s_auprc' % phase] = auc_and_auprc['auprc']
print("%s AUC: %2.6f; AUPRC: %2.6f; loss: %2.6f" % (phase, auc_and_auprc['auc'], auc_and_auprc['auprc'], epoch_loss))
if self.n_additional_image_features_to_predict > 0:
print("Loss divided by additional loss is %2.3f (median ratio across batches)" % np.median(loss_additional_loss_ratios))
metrics_for_epoch['%s_yhat' % phase] = concatenated_outputs
metrics_for_epoch['%s_y' % phase] = concatenated_labels
return metrics_for_epoch
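    # For a continuous model evaluated on the val set, the dict returned above looks roughly like:
    #   {'val_loss': float, 'val_rmse': float, 'val_negative_rmse': float, 'val_r': float,
    #    'val_binarized_auc' / 'val_binarized_auprc': floats (koos_pain_subscore only),
    #    'stratified_by_klg': {0..4: metrics}, 'stratified_by_site': {...},
    #    'val_yhat': np.array, 'val_y': np.array}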
def train(self, dataloaders, dataset_sizes):
"""
trains the model. dataloaders + dataset sizes should have keys train, val, and test. Checked.
"""
since = time.time()
best_model_wts = copy.deepcopy(self.model.state_dict())
best_metric_val = -np.inf
all_metrics = {}
for epoch in range(self.num_epochs):
epoch_t0 = time.time()
print('Epoch {}/{}'.format(epoch, self.num_epochs - 1))
print('-' * 10)
metrics_for_epoch = {}
# Each epoch has a training and validation phase
for phase in ['train', 'val']:
metrics_for_phase = self.evaluate_on_dataset(dataloaders, dataset_sizes, phase)
# Change the learning rate.
if phase == 'val':
if self.lr_scheduler_type == 'decay':
self.scheduler.step()
elif self.lr_scheduler_type == 'plateau':
self.scheduler.step(
metrics_for_phase[self.metric_to_use_as_stopping_criterion])
else:
raise Exception("Not a valid scheduler type")
print("Current learning rate after epoch %i is" % epoch)
# https://github.com/pytorch/pytorch/issues/2829 get learning rate.
for param_group in self.optimizer.param_groups:
print(param_group['lr'])
# print(self.optimizer.state_dict())
metrics_for_epoch.update(metrics_for_phase)
# deep copy the model if the validation performance is better than what we've seen so far.
if phase == 'val' and metrics_for_phase[self.metric_to_use_as_stopping_criterion] > best_metric_val:
best_metric_val = metrics_for_phase[self.metric_to_use_as_stopping_criterion]
best_model_wts = copy.deepcopy(self.model.state_dict())
all_metrics[epoch] = metrics_for_epoch
print("\n\n***\nPrinting layer magnitudes")
self.print_layer_magnitudes(epoch)
if self.where_to_add_klg == 'output':
print("KLG weights are")
print(self.model.klg_fc.weight)
print("Total seconds taken for epoch: %2.3f" % (time.time() - epoch_t0))
all_metrics['final_results'] = metrics_for_epoch
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
time_elapsed // 60, time_elapsed % 60))
# load best model weights
self.model.load_state_dict(best_model_wts)
self.model.train(False) # Set model to evaluate mode
self.state_dict = best_model_wts
# evaluate on test set.
all_metrics['total_seconds_to_train'] = time_elapsed
all_metrics['test_set_results'] = self.evaluate_on_dataset(dataloaders, dataset_sizes, 'test')
return all_metrics
def get_fully_connected_layer(self, class_idx=None):
if not self.binary_prediction:
assert class_idx is None
for name, param in self.model.named_parameters():
if name == 'fc.weight':
return param.data[0, :].cpu().numpy().flatten() # we need to take zero-th index in case we have extra features.
raise Exception("No weight vector found")
else:
assert class_idx is not None
for name, param in self.model.named_parameters():
if name == 'fc.weight':
return param.data.cpu().numpy()[class_idx, :].flatten()
raise Exception("No weight vector found")
def stratify_results_by_ses(y, yhat, high_ses_idxs, binary_prediction):
"""
Report performance stratified by low and high SES.
"""
assert len(y) == len(yhat)
assert len(yhat) == len(high_ses_idxs)
low_ses_results = analysis.assess_performance(y=y[~high_ses_idxs],
yhat=yhat[~high_ses_idxs],
binary_prediction=binary_prediction)
high_ses_results = analysis.assess_performance(y=y[high_ses_idxs],
yhat=yhat[high_ses_idxs],
binary_prediction=binary_prediction)
combined_results = {}
for k in low_ses_results:
combined_results['low_ses_%s' % k] = low_ses_results[k]
for k in high_ses_results:
combined_results['high_ses_%s' % k] = high_ses_results[k]
return combined_results
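# Toy illustration (hypothetical numbers) of the stratified output:
#
#   y = np.array([50., 80., 60., 90.])
#   yhat = np.array([55., 75., 65., 85.])
#   high_ses = np.array([False, False, True, True])
#   stratify_results_by_ses(y, yhat, high_ses, binary_prediction=False)
#   # -> {'low_ses_rmse': ..., 'low_ses_r': ..., 'high_ses_rmse': ..., 'high_ses_r': ..., ...}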
def train_one_model(experiment_to_run):
"""
Main method used for training one model.
experiment_to_run specifies our experimental condition.
"""
timestring = str(datetime.datetime.now()).replace(' ', '_').replace(':', '_').replace('.', '_').replace('-', '_')
# load data.
if experiment_to_run == 'train_random_model':
dataset_kwargs, model_kwargs = generate_random_config()
elif experiment_to_run == 'predict_klg':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
# downweight additional loss by factor of roughly (std(koos_pain_score) / std(xrkl))^2
# This is approximately 200.
model_kwargs['additional_loss_weighting'] = model_kwargs['additional_loss_weighting'] / 200.
model_kwargs['y_col'] = 'xrkl'
dataset_kwargs['y_col'] = 'xrkl'
elif experiment_to_run == 'train_on_single_klg':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
dataset_kwargs['train_on_single_klg_kwargs'] = {'klg_to_use':random.choice([0, 1, 4]), 'make_train_set_smaller':False} # 1, 2, 3
#model_kwargs['num_epochs'] = random.choice([25, 35, 50]) # 15
#model_kwargs["scheduler_kwargs"]["additional_kwargs"]["factor"] = random.choice([0.5, 0.75, 0.9])
#model_kwargs["scheduler_kwargs"]["additional_kwargs"]['patience'] = random.choice([1, 2])
elif experiment_to_run == 'predict_residual':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
# don't downweight additional loss because residual is approximately the same scale.
model_kwargs['y_col'] = 'koos_pain_subscore_residual'
dataset_kwargs['y_col'] = 'koos_pain_subscore_residual'
elif experiment_to_run == 'train_best_model_binary':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('binarized_koos_pain_subscore')
elif experiment_to_run == 'train_best_model_continuous':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
elif experiment_to_run == 'increase_diversity':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
ses_col = random.choice(['binarized_income_at_least_50k', 'binarized_education_graduated_college'])
n_seeds_to_fit = 5
    if ses_col == 'race_black':  # not in the current choice list above, so minority_val is always 0 here.
        minority_val = 1
    else:
        minority_val = 0
if random.random() < 1./(n_seeds_to_fit + 1.):
exclude_minority_group = True
else:
exclude_minority_group = False
dataset_kwargs['increase_diversity_kwargs'] = {'ses_col':ses_col, 'minority_val':minority_val, 'exclude_minority_group':exclude_minority_group}
if not dataset_kwargs['increase_diversity_kwargs']['exclude_minority_group']:
dataset_kwargs['increase_diversity_kwargs']['majority_group_seed'] = random.choice(range(n_seeds_to_fit))
else:
dataset_kwargs['increase_diversity_kwargs']['majority_group_seed'] = None
elif experiment_to_run == 'change_ses_weighting':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
ses_col = 'race_black'
if ses_col == 'binarized_income_at_least_50k':
p = random.choice([0, 1])
elif ses_col == 'binarized_education_graduated_college':
p = random.choice([0, 1])
elif ses_col == 'race_black':
p = 0 # remove minority group; can't remove majority group because minority is too small.
else:
raise Exception("invalid ses col")
dataset_kwargs['weighted_ses_sampler_kwargs'] = {'ses_col':ses_col,
'covs':None,#DEMOGRAPHIC_CONTROLS + ['C(xrkl)'],
'p_high_ses':p,
'use_propensity_scores':False}
elif experiment_to_run == 'change_ses_weighting_with_propensity_matching':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
dataset_kwargs['weighted_ses_sampler_kwargs'] = {'ses_col':'binarized_income_at_least_50k',
'covs':AGE_RACE_SEX_SITE + ['C(xrkl)'],
'p_high_ses':random.choice([0., 0.5, 1.]),#random.choice([.1, .25, .5, .75, .9]),
'use_propensity_scores':True}
elif experiment_to_run == 'remove_correlation_between_pain_and_ses':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
dataset_kwargs['remove_correlation_between_pain_and_ses_kwargs'] = {'ses_col':random.choice(['binarized_income_at_least_50k',
'binarized_education_graduated_college', 'race_black']),
'pain_col':'koos_pain_subscore'}
elif experiment_to_run == 'train_on_both_knees':
dataset_kwargs, model_kwargs = generate_random_config()
dataset_kwargs['show_both_knees_in_each_image'] = True
elif experiment_to_run == 'alter_train_set_size':
dataset_kwargs, model_kwargs = generate_random_config()
dataset_kwargs['alter_train_set_size_sampler_kwargs'] = {'fraction_of_train_set_to_use':
random.choice([.1, .2, .5, .75, .9, 1])}
elif experiment_to_run == 'different_random_seeds':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
dataset_kwargs['seed_to_further_shuffle_train_test_val_sets'] = random.choice([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
elif experiment_to_run == 'blur_image':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
dataset_kwargs['blur_filter'] = random.choice([1/2., 1/4., 1/8., 1/16., 1/32., 1/64.])#random.choice([0, 1, 2, 5, 8, 10, 15, 20, 50])
elif experiment_to_run == 'hold_out_one_imaging_site':
dataset_kwargs, model_kwargs = generate_config_that_performs_well('koos_pain_subscore')
dataset_kwargs['hold_out_one_imaging_site_kwargs'] = {'site_to_remove':random.choice(['A', 'B', 'C', 'D', 'E'])}
else:
raise Exception("not a valid experiment")
assert model_kwargs['y_col'] == dataset_kwargs['y_col']
print('dataset kwargs', json.dumps(dataset_kwargs, indent=4))
print('model kwargs', json.dumps(model_kwargs, indent=4))
dataloaders, datasets, dataset_sizes = load_real_data_in_transfer_learning_format(**dataset_kwargs)
# actually train model.
pytorch_model = TransferLearningPytorchModel(**model_kwargs)
all_training_results = pytorch_model.train(dataloaders=dataloaders, dataset_sizes=dataset_sizes)
# stratify test performance by SES.
high_ses_idxs = copy.deepcopy(datasets['test'].non_image_data['binarized_income_at_least_50k'] == True).values
y = copy.deepcopy(all_training_results['test_set_results']['test_y'])
yhat = copy.deepcopy(all_training_results['test_set_results']['test_yhat'])
binary_prediction = model_kwargs['binary_prediction']
ses_stratified_results = stratify_results_by_ses(y=y,
yhat=yhat,
high_ses_idxs=high_ses_idxs,
binary_prediction=binary_prediction)
all_training_results['test_set_results'].update(ses_stratified_results)
# Stratify test performance by KLG.
all_training_results['test_set_results']['stratified_by_klg'] = {}
if experiment_to_run != 'predict_klg':
for klg_grade_to_use in range(5):
klg_idxs = copy.deepcopy(datasets['test'].non_image_data['xrkl'] == klg_grade_to_use).values
all_training_results['test_set_results']['stratified_by_klg'][klg_grade_to_use] = analysis.assess_performance(y=y[klg_idxs], yhat=yhat[klg_idxs], binary_prediction=False)
print("Test results for KLG=%i with %i points are" % (klg_grade_to_use, klg_idxs.sum()))
print(all_training_results['test_set_results']['stratified_by_klg'][klg_grade_to_use])
# save config.
print("Saving weights, config, and results at timestring %s" % timestring)
config = {'dataset_kwargs':dataset_kwargs, 'model_kwargs':model_kwargs, 'experiment_to_run':experiment_to_run}
config_path = os.path.join(FITTED_MODEL_DIR, 'configs', '%s_config.pkl' % timestring)
pickle.dump(config, open(config_path, 'wb'))
# save results
results_path = os.path.join(FITTED_MODEL_DIR, 'results', '%s_results.pkl' % timestring)
pickle.dump(all_training_results, open(results_path, 'wb'))
# save model weights.
weights_path = os.path.join(FITTED_MODEL_DIR, 'model_weights', '%s_model_weights.pth' % timestring)
torch.save(pytorch_model.model.state_dict(), weights_path)
def generate_random_config():
"""
Generate a random config that specifies the dataset + model configuration.
Checked.
"""
#print("Random state at the beginning is", random.getstate())
    y_col = random.choice(['koos_pain_subscore'])  # earlier runs also tried 'binarized_koos_pain_subscore', 'binarized_education_graduated_college', 'binarized_income_at_least_50k', 'koos_pain_subscore_residual'.
if y_col in ['binarized_koos_pain_subscore', 'binarized_education_graduated_college', 'binarized_income_at_least_50k']:
binary_prediction = True
elif y_col in ['koos_pain_subscore', 'xrkl', 'koos_pain_subscore_residual']:
binary_prediction = False
else:
raise Exception("Not a valid y column")
crop_to_just_the_knee = False
if not crop_to_just_the_knee:
        show_both_knees_in_each_image = random.choice([True])  # was random.choice([True, False]); True seems to perform slightly better and is also more interpretable.
if show_both_knees_in_each_image:
max_horizontal_translation = random.choice([0, .1, .2])
else:
max_horizontal_translation = random.choice([0, .25, .5, .75])
else:
show_both_knees_in_each_image = False
max_horizontal_translation = random.choice([0, .1, .2])
dataset_kwargs = {
'y_col':y_col,
'max_horizontal_translation': max_horizontal_translation,
'max_vertical_translation':random.choice([0, .1, .2]), # Tried 0.5, but seems like a little too much, and doesn't improve performance.
'use_very_very_small_subset':False,
'crop_to_just_the_knee':crop_to_just_the_knee,
'show_both_knees_in_each_image':show_both_knees_in_each_image,
        'downsample_factor_on_reload': random.choice([None]) if not crop_to_just_the_knee else random.choice([0.5, None]),  # earlier runs tried [None, 0.7, 0.5, 0.3]. Originally images were 512 x 512 and downsample factors were [None, 0.7]. Now images are 1024 by 1024.
'weighted_ses_sampler_kwargs':None,