function [BrainAGE, BrainAGE_unsorted, BrainAGE_all, D, age] = BA_gpr_ui(D)
% [BrainAGE, BrainAGE_unsorted, BrainAGE_all, D, age] = BA_gpr_ui(D)
% User interface for BrainAGE estimation in BA_gpr.m
%
% D.data - test sample for BrainAGE estimation
% D.seg_array - segmentations
% {'rp1'} use GM
% {'rp2'} use WM
% {'rp1','rp2'} use GM or WM
% {'rp1+rp2'} use both GM+WM
% D.res_array - spatial resolution of data as char or cell (can be a cell array to try different values: {'4','8'})
% D.smooth_array - smoothing size as char or cell (can be a cell array to try different values: {'s4','s8'})
% D.train_array - cell array name(s) of training samples
% Healthy adults:
% IXI547 - IXI
% OASIS316 - OASIS
% OASIS3_1752 - OASIS3 (all time points of 549 subjects)
% OASIS3_549 - OASIS3 (only last time point)
% CamCan651 - CamCan
% SALD494 - SALD
% NKIe629 - NKIe (minimum age 6y)
% NKIe516 - NKIe (minimum age 18y)
% ADNI231Normal - ADNI Normal-sc-1.5T
%
% Children data:
% NIH394 - NIH objective 1
% NIH879 - NIH release 4.0
% NIH755 - NIH release 5.0
%
% Children + adults:
% fCONN772 - fcon-1000 (8-85 years)
%
% D.relnumber - CAT12 release (e.g. '_r1840')
% D.age_range - age range of training data
% [0 Inf] use all data (default)
% [50 80] use age range of 50..80
% if not defined use min/max of age of test data
% D.ind_adjust - define indices for adjusting data according to trend defined with D.trend_degree
% usually this is the control group that is used in order to adjust the data
% D.site_adjust - If data are acquired at different sites (e.g. using different scanners or sequences), the trend
% correction should be estimated and applied for each site separately. A vector coding the sites/scanners is required.
% If this parameter is empty, no site-specific adjustment is made, even for multiple
% training samples that are combined with a "+".
% D.comcat - If data are acquired at different sites (e.g. using different scanners or sequences), the data can be
% harmonized using ComCat. A vector coding the sites/scanners is required. If ComCat should only correct between
% the different training samples and the test sample, a single value can also be used and the vector defining
% the samples will be built automatically (EXPERIMENTAL!).
% D.ind_groups - define indices of groups; if D.ind_adjust is not given, the first group of this index will
% be used for adjusting data according to the trend defined with D.trend_degree
% D.ind_train - define indices of subjects used for training (e.g. limit the training to male subjects only)
% D.trend_degree - estimate trend with defined order using healthy controls and apply it to all data (set to -1 for skipping trend correction)
% D.trend_method - use different methods for estimating trend:
% 0 skip trend correction (set trend_degree to -1)
% 1 use BrainAGE for trend correction (default)
% 2 use predicted age for trend correction (as used in Cole et al. 2018)
% D.trend_ensemble - apply trend correction for each ensemble separately before bagging/stacking
% 0 skip trend correction for each ensemble (default)
% 1 apply trend correction for each ensemble
% D.hyperparam - GPR hyperparameters (.mean and .lik)
% D.RVR - use old RVR method
% D.PCA - apply PCA as feature reduction (default=1), values > 1 define number of PCA components
% D.PCA_method - method for PCA
% 'eig' Eigenvalue Decomposition of the covariance matrix (faster but less accurate, for compatibility)
% 'svd' Singular Value Decomposition of X (the default)
% D.k_fold - k-fold validation if training and test samples are the same or only one is defined (10-fold as default)
% The common approach for k-fold is to divide the sample into k parts and to use the
% larger part (n-n/k) for training and the remaining part (n/k) for testing.
% D.k_fold_TPs - definition of time points for k-fold validation to ensure that multiple time points of one subject are not mixed
% between test and training data (only necessary to define for longitudinal data and k-fold validation)
% D.k_fold_reps - Number of repeated k-fold cross-validation
% D.k_fold_rand - By default the age values of the sample are sorted and every k-th subject is assigned to a fold to minimize age
% differences between training and test data. With k_fold_rand you can set the seed of the random number generator to use random fold assignment instead.
% D.p_dropout - Dropout probability to randomly exclude voxels/data points to implement an uncertainty-aware approach using a
% Monte-Carlo Dropout during inference. That means that during testing, voxels are randomly dropped out according
% to the dropout probabilities. This process is repeated multiple times, and each time, the model produces
% a different output. By averaging these outputs, we can obtain a more robust prediction and estimate the model's
% uncertainty in its predictions. A meaningful dropout probability is 0.1, which means that 10% of the data points
% are excluded. The default is 0.
% D.ensemble - ensemble method to combine different models
% 0 - Majority voting: use model with lowest MAE
% 1 - Weighted GLM average: use GLM estimation to estimate model weights to minimize MAE
% 2 - Average: use mean to weight different models
% 3 - GLM: use GLM estimation to maximize variance to a group or a regression parameter (EXPERIMENTAL!)
% 4 - Stacking: use GPR to combine models (EXPERIMENTAL!, only works with k_fold validation)
% 5 - Weighted Average: (average models with weighting w.r.t. squared MAE) (default)
% 6 - GLM: use GLM estimation for different tissues (i.e. GM/WM) to maximize variance to a group or a regression parameter (EXPERIMENTAL!)
% In contrast to ensemble model 3, we here only use the mean tissue values and not all models to estimate weights
% D.contrast - define contrast to maximize group differences (use only if D.ensemble is 3 or 6) (e.g. [1 -1])
% D.contrast can also be a vector which is used to maximize variance between BrainAGE and this parameter.
% D.dir - directory for databases and code
% D.verbose - verbose level
% 0 - suppress long outputs
% 1 - print meaningful outputs (default)
% 2 - print long outputs
% D.threshold_std - all data with a standard deviation > D.threshold_std of mean covariance are excluded (after covarying out effects of age)
% meaningful values are 1,2 or Inf
% D.eqdist - options for age and sex equalization between test and train
% D.eqdist.weight - vector of size 2 that allows weighting the cost function for age and sex equalization
% D.eqdist.range - matrix 2 x 2 which defines the age range and sex range for equalization
% D.eqdist.tol - vector of size 2 that defines tolerance between mean value of
% age_test and age_train and male_test and male_train
% D.eqdist.debug - print debug info if set (default 0)
% D.corr - additionally define parameter that can be correlated to BrainAGE if only one group is given
% D.define_cov - optionally define a continuous parameter that should be used instead of age for more general use
% of GPR not only limited to BrainAGE
% D.style - plot-style: 1: old style with vertical violin-plot; 2: new style with horizontal density plot
% D.groupcolor - matrix with (group)-bar-color(s), use jet(numel(data)) or other color functions (nejm by default)
% D.normalize_BA - normalize BA values w.r.t. MAE to make BA less dependent on the training sample (i.e. its size) and scale
% it to MAE of 5
% D.nuisance - additionally define nuisance parameter for covarying out (e.g. gender)
% D.parcellation - use parcellation into lobes to additionally estimate local BrainAGE values:
% https://figshare.com/articles/dataset/Brain_Lobes_Atlas/971058
% 0 - estimate global BrainAGE
% 1 - estimate local BrainAGE for different lobes for both hemispheres
% D.spiderplot.func - show spider (radar) plot either with mean or median values (only valid if D.parcellation is used):
% 'median' - use median values
% 'mean' - use mean values (default)
% D.spiderplot.range- range for spiderplot (default automatically find range)
%
% Parameter search
% ---------------
% Some selected parameters can also be defined as ranges to try different parameter settings.
% Examples:
% D.trend_degree = 1;
% D.threshold_std = [Inf];
% D.age_range = [20 50];
% D.res_array = {'4','8'};
% D.smooth_array = {'s4','s8'};
% D.train_array = {'IXI547','OASIS316+ADNI231Normal'};
% D.ensemble = 1; % minimize MAE
% D.data = 'Your_Sample';
%
% Output
% -------
% BrainAGE - BrainAGE values sorted by group definitions in D.ind_groups
% BrainAGE_unsorted - unsorted (originally ordered) BrainAGE values
% BrainAGE_all - array of BrainAGE values for all models sorted by group definitions in D.ind_groups
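%
% Example
% -------
% A minimal sketch of a k-fold cross-validation call. The sample name and
% file-related settings below are placeholders and have to match your local
% data files (with these values the file s8rp1_8mm_IXI547_r1840.mat would be
% loaded):
%   D = struct();
%   D.data         = 'IXI547';     % sample that is also used for training
%   D.seg_array    = {'rp1'};      % use GM segmentations
%   D.res_array    = {'8'};        % 8 mm spatial resolution
%   D.smooth_array = {'s8'};       % 8 mm smoothing
%   D.relnumber    = '_r1840';     % CAT12 release suffix of the data files
%   D.k_fold       = 10;           % 10-fold cross-validation
%   [BrainAGE, BrainAGE_unsorted] = BA_gpr_ui(D);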
% ______________________________________________________________________
%
% Christian Gaser
% Structural Brain Mapping Group (https://neuro-jena.github.io)
% Departments of Neurology and Psychiatry
% Jena University Hospital
% ______________________________________________________________________
% $Id$
%#ok<*AGROW>
global min_hyperparam %#ok<GVMIS>
% add cat12 path if not already done
if ~exist('cat_stat_polynomial','file')
addpath(fullfile(spm('dir'),'toolbox','cat12'));
end
% BA normalization is disabled by default; if enabled (D.normalize_BA = 1), BA is scaled to an MAE of 5
if ~isfield(D,'normalize_BA')
D.normalize_BA = 0;
else
if D.normalize_BA == 1
D.normalize_BA = 5;
end
end
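% One plausible reading of this normalization (applied later, once the MAE is
% known): BA = BA * (D.normalize_BA / MAE), i.e. BrainAGE values are rescaled
% so that the sample has an MAE of 5.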
if ~isfield(D,'spiderplot') || (isfield(D,'spiderplot') && ~isfield(D.spiderplot,'func'))
D.spiderplot.func = 'mean';
end
if ~isfield(D,'trend_degree')
D.trend_degree = 1;
end
if ~isfield(D,'trend_method')
D.trend_method = 1;
end
if ~isfield(D,'trend_ensemble')
D.trend_ensemble = 0;
end
if ~isfield(D,'ensemble')
D.ensemble = 5;
end
if D.ensemble < 0 && ~exist('fmincon','file')
fprintf('In order to use non-linear optimization you need the Optimization Toolbox.\n');
return
end
if ~isfield(D,'age_range')
D.age_range = [0 Inf];
end
if ~isfield(D,'hyperparam')
D.hyperparam = struct('mean', 100, 'lik', -1);
end
if ~isfield(D,'style')
style = 2;
else
style = D.style;
end
if ~isfield(D,'threshold_std')
D.threshold_std = Inf;
end
if ~isfield(D,'PCA')
D.PCA = 1;
end
if ~isfield(D,'RVR')
D.RVR = 0;
end
if ~isfield(D,'PCA_method')
D.PCA_method = 'svd';
end
if ~isfield(D,'parcellation')
D.parcellation = 0;
end
if D.trend_method > 1 && D.trend_degree > 1
D.trend_degree = 1;
fprintf('Only linear trend correction is supported for the method that uses predicted age for trend correction.\n');
end
if ~D.trend_method
if D.trend_degree > -1, fprintf('Disable trend correction.\n'); end
D.trend_degree = -1;
end
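% The idea of trend correction (D.trend_method = 1) is to fit a polynomial of
% BrainAGE on age within the adjustment group and to remove that fit from all
% data. A rough sketch using core MATLAB functions (the actual correction is
% done later in apply_trend_correction and is additionally applied per site):
%   p  = polyfit(age(D.ind_adjust), BA(D.ind_adjust), D.trend_degree);
%   BA = BA - polyval(p, age);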
% this is just for compatibility with older scripts
if isfield(D,'seg') && ~isfield(D,'seg_array')
D.seg_array = D.seg;
D = rmfield(D,'seg');
end
% this is just for compatibility with older scripts
if isfield(D,'training_sample') && ~isfield(D,'train_array')
if numel(D.training_sample) > 1
error('Please use new syntax and use D.train_array instead.');
end
D.train_array = D.training_sample;
D = rmfield(D,'training_sample');
end
% convert to cell if necessary
if ~iscell(D.seg_array)
D.seg_array = cellstr(D.seg_array);
end
% array with different smoothing sizes
if ~iscell(D.smooth_array)
D.smooth_array = cellstr(D.smooth_array);
end
% array with different spatial resolutions
if ~iscell(D.res_array)
D.res_array = cellstr(D.res_array);
end
% verbose level
if ~isfield(D,'verbose')
D.verbose = 1;
end
% fill the missing field if necessary
if ~isfield(D,'data')
D.data = D.train_array{1};
end
% set default for k_fold_reps
if ~isfield(D,'k_fold_reps')
D.k_fold_reps = 1;
end
if isfield(D,'k_fold_rand') && D.k_fold_reps > 1
error('D.k_fold_rand cannot be used together with D.k_fold_reps because repeated k-fold would always use the same random numbers without variations.');
end
% set default for dropout probability
if ~isfield(D,'p_dropout')
D.p_dropout = 0;
end
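% The idea behind D.p_dropout (Monte-Carlo Dropout, see the parameter
% description above) is to repeat the prediction while randomly dropping a
% fraction of the voxels and to average the resulting outputs. A simplified
% sketch with a hypothetical predictor predict_model (the dropout itself is
% applied during the actual prediction, not here):
%   n_rep = 10; pred = zeros(n_test,n_rep);
%   for r = 1:n_rep
%     keep = rand(1,n_voxels) >= D.p_dropout;    % randomly keep ~90% of voxels for p_dropout = 0.1
%     pred(:,r) = predict_model(Y_test(:,keep)); % hypothetical call
%   end
%   pred_mean = mean(pred,2);    % more robust prediction
%   pred_std  = std(pred,0,2);   % per-subject uncertainty estimate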
if iscell(D.data)
D.data = char(D.data);
end
% consider old syntax and name
if isfield(D,'n_fold') && ~isfield(D,'k_fold')
D.k_fold = D.n_fold;
end
% if comcat is defined and set to 0 then remove this field
if isfield(D,'comcat') && isscalar(D.comcat) && D.comcat == 0
D = rmfield(D,'comcat');
end
region_names = {'R Frontal','R Parietal','R Occipital','R Temporal','R Subcortical/Cerebellum',...
'L Frontal','L Parietal','L Occipital','L Temporal','L Subcortical/Cerebellum'};
ensemble_str = {'Majority Voting (model with lowest MAE)',...
'Weighted GLM Average (GLM for weighting models to minimize MAE)',...
'Average of all models',...
'GLM for weighting all models to maximize variance w.r.t. contrast vector',...
'Stacking: GPR for combining models',...
'Weighted Average: (average models with weighting w.r.t. squared MAE)',...
'GLM for weighting tissue models (i.e. GM/WM) to maximize variance w.r.t. contrast vector'};
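% As a rough illustration of the default ensemble (D.ensemble = 5): one
% plausible reading of "weighting w.r.t. squared MAE" is an inverse
% squared-MAE weighted average, assuming BA_models is an
% n_subjects x n_models matrix of BrainAGE estimates (the actual combination
% of all ensemble methods is implemented in ensemble_models):
%   mae = mean(abs(BA_models),1);    % MAE of each model
%   w   = 1./mae.^2; w = w/sum(w);   % normalized weights
%   BA_ensemble = BA_models * w';    % weighted average across models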
if isfield(D,'define_cov')
if ~isempty(strfind(D.data,'+'))
error('D.define_cov cannot be used for multiple training data.');
end
if isfield(D,'n_data')
if numel(D.define_cov) ~= D.n_data
error('D.define_cov has different size (%d) than data (%d).',numel(D.define_cov),D.n_data);
end
else
% assume that D.define_cov is correctly defined and we can obtain
% the data size from that
D.n_data = numel(D.define_cov);
age = D.define_cov;
end
if D.trend_degree > -1
D.trend_degree = -1;
fprintf('Disable trend correction because it cannot be used for non-BrainAGE parameters.\n');
end
end
% load first data set to get data size
if ~isfield(D,'n_data')
ind_plus = strfind(D.seg_array{1},'+');
if ~isempty(ind_plus)
seg_array = D.seg_array{1}(1:ind_plus-1);
else
seg_array = D.seg_array{1};
end
% find potential "+" indicating to combine training sample
ind_plus = strfind(D.data,'+');
if ~isempty(ind_plus)
% D.comcat can also be defined as a single value indicating that
% the comcat vector that codes the different samples will be
% built automatically
if isfield(D,'comcat') && isscalar(D.comcat) && D.comcat == 1
D_comcat = [];
end
age0 = [];
ind_plus = [0 ind_plus length(D.data)+1];
n_train = numel(ind_plus)-1;
for l = 1:n_train
load([D.smooth_array{1} seg_array '_' D.res_array{1} 'mm_' D.data(ind_plus(l)+1:ind_plus(l+1)-1) D.relnumber],'age');
age0 = [age0; age];
if isfield(D,'comcat') && isscalar(D.comcat) && D.comcat == 1
D_comcat = [D_comcat; l*ones(size(age))];
end
end
if isfield(D,'comcat') && isscalar(D.comcat) && D.comcat == 1
D.comcat = D_comcat;
end
age = age0;
D.n_data = numel(age);
else
load([D.smooth_array{1} seg_array '_' D.res_array{1} 'mm_' D.data D.relnumber],'age');
D.n_data = numel(age);
if isfield(D,'comcat') && isscalar(D.comcat) && D.comcat == 1 && strcmp(D.train_array{1},D.data)
D.comcat = ones(size(age));
end
end
end
if ~isfield(D,'ind_groups')
D.ind_groups{1} = 1:D.n_data;
end
if ~isfield(D,'site_adjust')
D.site_adjust = ones(D.n_data,1);
end
if ~isfield(D,'ind_adjust')
D.ind_adjust = D.ind_groups{1};
end
if isfield(D,'run_kfold') && ~isfield(D,'train_array')
D.train_array = {D.data};
end
if isfield(D,'weighting') && ~isfield(D,'ensemble')
D.ensemble = D.weighting;
fprintf('This option is deprecated. Use the option ''ensemble'' instead.\n');
return
end
% check whether contrast was defined for ensemble=3
if isfield(D,'ensemble')
if numel(D.ensemble) > 1 && ~isfield(D,'k_fold')
error('Multiple ensembles are only allowed for k-fold validation.');
end
if (abs(D.ensemble(1)) == 3 || abs(D.ensemble(1)) == 6) && ~isfield(D,'contrast')
error('D.contrast has to be defined.');
end
end
% print some parameters
if ~isfield(D,'run_kfold')
res = []; seg = []; smo = [];
for i = 1:numel(D.res_array)
res = [res D.res_array{i} ' '];
end
for i = 1:numel(D.smooth_array)
smo = [smo D.smooth_array{i} ' '];
end
for i = 1:numel(D.seg_array)
seg = [seg D.seg_array{i} ' '];
end
fprintf('--------------------------------------------------------------\n');
fprintf('Data: \t%s\nResolution: \t%s\nSmoothing: \t%s\nSegmentation: \t%s\nThreshold-Std:\t%d\n',...
[D.data D.relnumber],res,smo,seg,D.threshold_std);
if isfield(D,'train_array')
tra = [];
for i = 1:numel(D.train_array)
tra = [tra D.train_array{i} ' '];
end
fprintf('Training-Data:\t%s\n',tra);
end
if isfield(D,'ensemble')
fprintf('Model-Weight: \t%d\n',D.ensemble);
end
if isfield(D,'k_fold')
fprintf('k-Fold: \t%d\n',D.k_fold);
end
if D.p_dropout
fprintf('Prob-Dropout: \t%d\n',D.p_dropout);
end
if D.RVR
fprintf('RVR: \t%d\n',D.RVR);
end
fprintf('PCA: \t%d (method: %s)\n',D.PCA,D.PCA_method);
fprintf('Trend method: \t%d\n',D.trend_method);
fprintf('Age-Range: \t%g-%g\n',D.age_range(1),D.age_range(2));
fprintf('--------------------------------------------------------------\n');
if isfield(D,'parcellation') && D.parcellation
fprintf('Estimate local BrainAGE with parcellation into lobes.\n');
end
end
% run k-fold validation if no data field is given or validation with k_fold is defined
if ((~isfield(D,'data') || ~isfield(D,'train_array')) || isfield(D,'k_fold')) && ~isfield(D,'run_kfold')
if isfield(D,'ensemble') && (abs(D.ensemble(1)) == 3 || abs(D.ensemble(1)) == 6) && numel(D.ind_groups) < 1
error('Ensemble model 3 or 6 cannot be used within k-fold validation with more than one group.');
end
ind_adjust = D.ind_adjust;
% use 10-fold as default
if ~isfield(D,'k_fold')
D.k_fold = 10;
end
ind_all = [];
age_all = [];
BA_all = [];
% ensure that this field is always defined and set to ones by default
if ~isfield(D,'k_fold_TPs')
D.k_fold_TPs = ones(D.n_data,1);
end
% number of time points
n_TPs = max(D.k_fold_TPs);
% for longitudinal data only
if n_TPs > 1
% find order of time point definition
% offset_TPs = 1 -> alternating order (e.g. 1 2 1 2 1 2)
% offset_TPs > 1 -> consecutive order (e.g. 1 1 1 2 2 2)
offset_TPs = find(diff(D.k_fold_TPs), 1 );
ns = [];
for i = 1:n_TPs
ns = [ns sum(D.k_fold_TPs==i)];
end
if any(diff(ns))
error('Time points must all have the same number of subjects if you apply k-fold to longitudinal data.');
end
if offset_TPs == 1
fprintf('Longitudinal data with %d time points with alternating order found.\n',n_TPs);
else
fprintf('Longitudinal data with %d time points with consecutive order found.\n',n_TPs);
end
% sort only data for TP1
[~, ind_age] = sort(age(D.k_fold_TPs==1));
if offset_TPs == 1
ind_age = ind_age * n_TPs - (n_TPs-1);
end
for i = 1:n_TPs-1
ind_age = [ind_age; ind_age+i*offset_TPs];
end
else
[~, ind_age] = sort(age);
end
min_hyperparam = cell(numel(D.res_array),numel(D.smooth_array),numel(D.seg_array),numel(D.train_array));
for rep = 1:D.k_fold_reps
% control random number generation to always get the same seed
if isfield(D,'k_fold_rand')
if exist('rng','file') == 2
rng('default')
rng(D.k_fold_rand)
else
rand('state',D.k_fold_rand);
end
ind_age = randperm(numel(age))';
end
% use random indexing of age for repeated k-fold cross-validation
if D.k_fold_reps > 1 && ~isfield(D,'k_fold_rand')
if exist('rng','file') == 2
rng('default')
rng(rep)
else
rand('state',rep);
end
ind_age = randperm(numel(age))';
end
for j = 1:D.k_fold
% indicate that k-fold validation is already running so it is not started again in the nested call
D.run_kfold = j;
D.run_repetition = rep;
ind_fold0 = j:D.k_fold:D.n_data;
% try to use similar age distribution between folds
% by using sorted age index
ind_test = ind_age(ind_fold0)';
% build training sample using remaining subjects
ind_train = ind_age';
ind_train(ind_fold0) = [];
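% For example, with D.k_fold = 10 fold j takes every 10th subject of the
% age-sorted list (positions j, j+10, j+20, ...) as test set and the
% remaining subjects as training set, so each fold approximately covers the
% whole age range.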
% this should never happen, but to be absolutely sure we check
% whether there is any overlap between training and test data
n_overlaps = sum(ismember(ind_train,ind_test));
if n_overlaps
if exist('cat_io_cprintf','file')
cat_io_cprintf('warn',sprintf('WARNING: There is an overlap of %d subjects between training and test data.\n',n_overlaps));
else
fprintf('WARNING: There is an overlap of %d subjects between training and test data.\n',n_overlaps);
end
end
% build indices for training and test
ind_train_array{j} = ind_train;
ind_test_array{j} = ind_test;
% collect age and indices in the order w.r.t. the folding
age_all = [age_all; age(ind_test)];
ind_all = [ind_all ind_test];
% prepare BA_gpr_ui parameters
D.ind_groups = {ind_test};
D.ind_adjust = ind_test;
if isfield(D,'fraction')
D.ind_train = ind_train(1:D.fraction:end);
else
D.ind_train = ind_train;
end
% call BA_gpr_ui recursively for this fold
[BA_fold_all, ~, ~, D] = BA_gpr_ui(D);
if j == 1 && rep == 1
BA_all = zeros(D.n_data,D.k_fold,D.k_fold_reps,D.n_regions,size(BA_fold_all,2)/D.n_regions);
end
% we keep entries for each loop because there might be some overlapping
% entries
BA_all(ind_test,j,rep,:,:) = reshape(BA_fold_all,size(BA_fold_all,1),D.n_regions,size(BA_fold_all,2)/D.n_regions);
end
end
% we have to estimate the mean by using the sum and dividing by the
% actual number of entries (~=0)
n_entries = sum(BA_all(:,:,1)~=0,2);
BA_all0 = squeeze(sum(BA_all,2))./n_entries;
BA_all = zeros(D.n_data,size(BA_fold_all,2)/D.n_regions,D.n_regions);
for k = 1:D.n_data
BA_all(k,:,:) = squeeze(BA_all0(k,:,:))';
end
if any(n_entries>1)
if min(n_entries) == max(n_entries)
fprintf('There are %d overlapping entries that were averaged.\n',max(n_entries));
else
fprintf('There are %d-%d overlapping entries that were averaged.\n',min(n_entries),max(n_entries));
end
end
% we use the mean over all repetitions
if D.k_fold_reps > 1
BA_all = squeeze(mean(BA_all,2));
end
D.ind_adjust = ind_adjust; % restore original ind_adjust
% go through different ensembles if defined
D0 = D;
D.MAE = [];
if isfield(D,'ensemble')
BA_unsorted_weighted = [];
for m = 1:numel(D.ensemble)
D0.ensemble = D.ensemble(m);
% for GPR stacking we have to initially apply trend correction (to
% all single models)
if D.trend_degree >= 0 && D0.ensemble == 4
BA_all1 = BA_all;
for i = 1:size(BA_all,2)
BA_all1(:,i) = apply_trend_correction(BA_all(:,i),age,D,0);
end
[~, PredictedAge_unsorted_weighted] = ensemble_models(BA_all1,age,D0,ind_test_array,ind_train_array);
else
[~, PredictedAge_unsorted_weighted] = ensemble_models(BA_all,age,D0,ind_test_array,ind_train_array);
end
BA_unsorted_weighted0 = PredictedAge_unsorted_weighted-age;
if D.verbose > 0 && D.trend_degree >= 0 && ~isfield(D,'define_cov')
fprintf('\n===========================================================\n');
fprintf(ensemble_str{D0.ensemble+1}); fprintf('\n');
str_trend = {'No age correction','Age correction using BA','Age correction using PredictedAge (Cole)'};
co = 0:2;
co(co == D.trend_method) = [];
for i=co
D1 = D;
D1.trend_method = i;
BA_unsorted_weighted1 = apply_trend_correction(BA_unsorted_weighted0,age,D1);
fprintf('\n%s:\n',str_trend{i+1});
MAE_weighted = mean(abs(BA_unsorted_weighted1));
cc = corrcoef([BA_unsorted_weighted1+age age]);
fprintf('Overall weighted MAE for %d-fold = ',D.k_fold);
fprintf('%g ',MAE_weighted); fprintf('\n');
fprintf('Overall weighted correlation for %d-fold = ',D.k_fold);
fprintf('%g ',cc(end,1:end-1)); fprintf('\n');
end
fprintf('\n===========================================================\n');
end
% apply final trend correction to weighted model
if D.trend_degree >= 0
[BA_unsorted_weighted0,~,Adjustment] = apply_trend_correction(BA_unsorted_weighted0,age,D);
end
MAE_weighted = mean(abs(BA_unsorted_weighted0));
D.MAE = [D.MAE MAE_weighted];
if isfield(D,'define_cov')
cc = corrcoef([BA_unsorted_weighted0 age]);
else
cc = corrcoef([BA_unsorted_weighted0+age age]);
end
fprintf('\n===========================================================\n');
fprintf('%s (ensemble=%d)\n',ensemble_str{D0.ensemble+1},D0.ensemble);
if ~isfield(D,'define_cov')
fprintf('Overall weighted MAE for %d-fold = ',D.k_fold);
fprintf('%g ',MAE_weighted); fprintf('\n');
end
fprintf('Overall weighted correlation for %d-fold = ',D.k_fold);
fprintf('%g ',cc(end,1:end-1));
fprintf('\n============================================================\n\n');
BA_unsorted_weighted = [BA_unsorted_weighted BA_unsorted_weighted0];
end
else
BA_unsorted_weighted = BA_all;
% apply trend correction
if D.trend_degree >= 0
[BA_unsorted_weighted,~,Adjustment] = apply_trend_correction(BA_unsorted_weighted,age,D);
end
% only print performance for single model
if size(BA_unsorted_weighted,2) == 1
D.MAE = mean(abs(BA_unsorted_weighted));
if isfield(D,'define_cov')
ind = ~isnan(age);
cc = corrcoef(BA_unsorted_weighted(ind),age(ind));
else
cc = corrcoef(BA_unsorted_weighted+age,age);
end
fprintf('\n===========================================================\n');
if ~isfield(D,'define_cov')
fprintf('Overall MAE for %d-fold = %g\n',D.k_fold,D.MAE);
end
fprintf('Overall correlation for %d-fold = %g\n',D.k_fold,cc(1,2));
fprintf('============================================================\n\n');
end
end
BrainAGE_unsorted = BA_unsorted_weighted;
BrainAGE = BrainAGE_unsorted;
if nargout > 2
% if site_adjust is empty we don't apply site adjustment
if isempty(D.site_adjust)
site_adjust = ones(D.n_data,1);
else
site_adjust = D.site_adjust;
end
% apply trend correction only for non-weighted data
if D.trend_degree >= 0
% apply trend correction for each site separately
for i = 1:max(site_adjust)
ind_site = find(site_adjust == i);
for j = 1:size(BA_all,2)
BA_all(ind_site,j,:) = squeeze(BA_all(ind_site,j,:)) - Adjustment{i};
end
end
end
BrainAGE_all = BA_all;
end
return
end
% check whether additional fields for weighted BA are available
multiple_BA = numel(D.smooth_array) > 1 || numel(D.seg_array) > 1 || numel(D.res_array) > 1;
% prepare output
BA = [];
BA_unsorted = [];
PredictedAge_unsorted = [];
EA = [];
if ~exist('min_hyperparam','var')
min_hyperparam = cell(numel(D.res_array),numel(D.smooth_array),numel(D.seg_array),numel(D.train_array));
end
% go through all resolutions, smoothing sizes, segmentations and training samples
for i = 1:numel(D.res_array)
for j = 1:numel(D.smooth_array)
for k = 1:numel(D.seg_array)
for q = 1:numel(D.train_array)
% select current data
D.res = D.res_array{i};
D.smooth = D.smooth_array{j};
seg = D.seg_array{k};
training_sample = D.train_array{q};
% remove old D.training_sample field if it exists
if isfield(D,'training_sample')
D = rmfield(D,'training_sample');
end
% remove old D.seg field if it exists
if isfield(D,'seg')
D = rmfield(D,'seg');
end
% find potential "+" indicating to combine training sample
ind_plus = strfind(training_sample,'+');
if ~isempty(ind_plus)
ind_plus = [0 ind_plus length(training_sample)+1];
n_train = numel(ind_plus)-1;
for l = 1:n_train
D.training_sample{l} = training_sample(ind_plus(l)+1:ind_plus(l+1)-1);
end
else
D.training_sample{1} = training_sample;
end
% find potential "+" indicating to combine segmentations
ind_plus = strfind(seg,'+');
if ~isempty(ind_plus)
ind_plus = [0 ind_plus length(seg)+1];
n_seg = numel(ind_plus)-1;
for l = 1:n_seg
D.seg{l} = seg(ind_plus(l)+1:ind_plus(l+1)-1);
end
else
n_seg = 1;
D.seg{1} = seg;
end
% clear old male parameter
if exist('male','var')
clear male
end
% find potential "+" indicating to combine data
ind_plus = strfind(D.data,'+');
if ~isempty(ind_plus)
Y0 = [];
age0 = [];
male0 = [];
name0 = [];
site_adjust = [];
ind_plus = [0 ind_plus length(D.data)+1];
n_train = numel(ind_plus)-1;
for l = 1:n_train
name = [D.smooth D.seg{1} '_' D.res 'mm_' D.data(ind_plus(l)+1:ind_plus(l+1)-1) D.relnumber];
if D.verbose > 1, fprintf('BA_gpr_ui: load %s\n',name); end
load(name);
name0 = [name0 '+' name];
age0 = [age0; age];
male0 = [male0; male];
Y0 = [Y0; single(Y)]; clear Y
site_adjust = [site_adjust; l*ones(size(age))];
end
age = age0;
male = male0;
name = name0(2:end); % remove leading '+'
Y = Y0; clear Y0
% create D.site_adjust if not already defined for more than one site
if max(D.site_adjust) == 1
D.site_adjust = site_adjust;
elseif ~isempty(D.site_adjust)
if i==1 && j==1 && k==1 && q==1
fprintf('\n-----------------------------------------------------------------\n');
fprintf('Please ensure that site-specific adjustment is also correctly defined for each training sample!\n');
fprintf('\n-----------------------------------------------------------------\n');
end
end
else
name = [D.smooth D.seg{1} '_' D.res 'mm_' D.data D.relnumber];
if D.verbose > 1, fprintf('BA_gpr_ui: load %s\n',name); end
load(name);
end
n_data = size(Y,1);
if D.n_data ~= n_data
fprintf('\n-----------------------------------------------------------------\n');
fprintf('Data size differs for %s (%d vs. %d)',name,D.n_data,n_data);
fprintf('\n-----------------------------------------------------------------\n');
return
end
if D.verbose > 1, fprintf('\n-----------------------------------------------------------------\n%s\n',name); end
D.Y_test = single(Y); clear Y
if isfield(D,'define_cov')
age = D.define_cov;
male = ones(size(age));
end
D.age_test = age;
if exist('male','var')
D.male_test = male;
end
if ~isfield(D,'age_range')
D.age_range = [min(D.age_test) max(D.age_test)];
end
% use additional segmentation if defined
for l = 2:n_seg
% find potential "+" indicating to combine data
ind_plus = strfind(D.data,'+');
if ~isempty(ind_plus)
Y0 = [];
ind_plus = [0 ind_plus length(D.data)+1];
n_train = numel(ind_plus)-1;
for m = 1:n_train
name = [D.smooth D.seg{l} '_' D.res 'mm_' D.data(ind_plus(m)+1:ind_plus(m+1)-1) D.relnumber];
if D.verbose > 1, fprintf('BA_gpr_ui: load %s\n',name); end
load(name);
Y0 = [Y0; single(Y)]; clear Y
end
D.Y_test = [D.Y_test Y0]; clear Y0
else
name = [D.smooth D.seg{l} '_' D.res 'mm_' D.data D.relnumber];
if D.verbose > 1, fprintf('BA_gpr_ui: load %s\n',name); end
load(name);
D.Y_test = [D.Y_test single(Y)]; clear Y
end
end
if D.verbose > 1, fprintf('\n'); end
% apply comcat harmonization while preserving age effects
% do this here only if training and test data are the same (i.e. k-fold validation)
% otherwise apply comcat in BA_gpr.m
if isfield(D,'comcat') && strcmp(D.train_array{1},D.data)
if length(D.comcat) ~= length(D.age_test)
error('Size of site definition in D.comcat (n=%d) differs from sample size (n=%d)\n',...
length(D.comcat),length(D.age_test));
end
fprintf('Apply ComCat for %d site(s)\n',numel(unique([D.comcat])));
D.Y_test = cat_stat_comcat(D.Y_test, D.comcat, [], D.age_test, 0, 3, 0, 1);
end
if ~isfield(D,'ind_groups')
D.ind_groups = {1:length(D.age_test)};
end
n_groups = numel(D.ind_groups);
if ~isfield(D,'groupcolor')
try
groupcolor = cat_io_colormaps('trafficlight',n_groups);
catch
groupcolor = cat_io_colormaps('nejm',n_groups);
end
else
groupcolor = D.groupcolor;
end
if ~isfield(D,'nuisance')
D.nuisance = [];
end
% build index for test data
ind_test = [];
for o = 1:n_groups
if size(D.ind_groups{o},1) < size(D.ind_groups{o},2)
D.ind_groups{o} = D.ind_groups{o}';
end
ind_test = [ind_test; D.ind_groups{o}];
end
if ~isfield(D,'age_range')
D.age_range = [min(D.age_test(ind_test)) max(D.age_test(ind_test))];
else
if numel(D.age_range) ~=2
error('Age range has to be defined by two values (min/max)');
end
end
% add spatial resolution to atlas name
if isfield(D,'parcellation') && D.parcellation
atlas_name = ['Brain_Lobes_' D.res 'mm.mat'];
load(atlas_name)
if ~exist('atlas','var')
error('Atlas file must contain a variable named ''atlas''.');
end
regions = unique(atlas(atlas > 0));
D.n_regions = numel(regions);
BrainAGE = [];
for r = 1:D.n_regions
D.mask = atlas == regions(r);
[tmp, ~, D] = BA_gpr(D);
BrainAGE = [BrainAGE tmp];
end
else
[BrainAGE, ~, D] = BA_gpr(D);
D.n_regions = 1;
end
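% With D.parcellation = 1 the columns of BrainAGE hold the local estimates for
% the lobe regions listed in region_names (in that order); with
% D.parcellation = 0 only global BrainAGE is estimated (D.n_regions = 1).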
% print information about the training sample only once if D.threshold_std == Inf, otherwise always
if D.verbose && (i==1 && j==1 && k==1) || isfinite(D.threshold_std)
fprintf('\n%d subjects used for training (age %3.1f..%3.1f years)\n',length(D.age_train),min(D.age_train),max(D.age_train));
fprintf('Mean age\t%g (SD %g) years\nMales/Females\t%d/%d\n',mean(D.age_train),std(D.age_train),sum(D.male_train),length(D.age_train)-sum(D.male_train));
if ~isfinite(D.threshold_std)
fprintf('\n');
end
fprintf('\n%d subjects used for prediction (age %3.1f..%3.1f years)\n',length(age),min(age),max(age));
if exist('male','var')
fprintf('Mean age\t%g (SD %g) years\nMales/Females\t%d/%d\n',mean(age),std(age),sum(male),length(age)-sum(male));
else
fprintf('Mean age\t%g (SD %g) years\n',mean(age),std(age));
end
if ~isfinite(D.threshold_std)
fprintf('\n');
end
end
% move on if training failed for global BrainAGE
if size(BrainAGE,2) == 1 && (all(isnan(BrainAGE)) || std(BrainAGE)==0)
BrainAGE_all = BrainAGE;
BA_unsorted = [BA_unsorted, BrainAGE];
if isfield(D,'define_cov')
PredictedAge_unsorted = [PredictedAge_unsorted, BrainAGE];
else