-
Notifications
You must be signed in to change notification settings - Fork 0
/
cvmultnet.m
121 lines (100 loc) · 2.9 KB
/
cvmultnet.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
function result = cvmultnet(object, lambda, x, y, weights, offset, foldid, ...
    type, grouped, keep)
% CVMULTNET  Cross-validated error curve for a multinomial glmnet fit.
%   result = cvmultnet(object, lambda, x, y, weights, offset, foldid, type,
%   grouped, keep) predicts every fold's observations with that fold's fit,
%   scores them with the requested error measure and summarizes the error
%   per value of lambda.
%
%   object  - cell array of fits; object{i} predicts the observations with
%             foldid == i (presumably fitted with fold i held out) and must
%             have a .lambda field.
%   lambda  - lambda sequence that defines the columns of the output.
%   x       - predictor matrix, one row per observation.
%   y       - either a column of class labels (any values; mapped to classes
%             through UNIQUE) or an n-by-K matrix whose rows are class
%             counts/indicators.
%   weights - per-observation weights (assumed n-by-1 or scalar; TODO confirm).
%   offset  - [] for none, otherwise rows of offsets passed to glmnetPredict.
%   foldid  - fold number (1..nfolds) of each observation.
%   type    - 'default' (= 'deviance'), 'deviance', 'class', 'mse' or 'mae';
%             anything else warns and falls back to 'deviance'.
%   grouped - if true, per-observation errors are aggregated by cvcompute.
%   keep    - optional, default false; if true, result.fit_preval holds the
%             unclipped out-of-fold probabilities.
%
%   result.cvm        - wtmean of the error matrix, one entry per lambda.
%   result.cvsd       - standard-error estimate of cvm.
%   result.name       - descriptive name of the error measure.
%   result.fit_preval - (only when keep) n-by-K-by-length(lambda) array;
%                       NaN where a fold's lambda path was shorter.
if nargin < 10 || isempty(keep)
    keep = false;
end
typenames = struct('mse','Mean-Squared Error','mae','Mean Absolute Error',...
    'deviance','Multinomial Deviance','class','Misclassification Error');
if strcmp(type,'default')
    type = 'deviance';
end
if ~any(strcmp(type,{'mse','mae','deviance','class'}))
    warning('Only ''deviance'', ''class'', ''mse'' or ''mae'' available for multinomial models; ''deviance'' used');
    type = 'deviance';
end
% Probabilities are clipped to [prob_min, prob_max] before taking logs so the
% deviance stays finite.
prob_min = 1e-5; prob_max = 1 - prob_min;

if size(y,2) == 1
    % Label vector -> n-by-nc indicator matrix.  UNIQUE's third output maps
    % arbitrary labels (0/1/2, 3/7/9, ...) to 1..nc; for labels already
    % equal to 1..nc this is the identity mapping.
    [classes, ~, yidx] = unique(y);
    nc = numel(classes);
    indexes = eye(nc);
    y = indexes(yidx,:);
else
    nc = size(y,2);
end

nobs = size(y,1);
nlambda = length(lambda);
is_offset = ~isempty(offset);
% Out-of-fold probabilities; entries stay NaN for lambdas beyond the end of
% a fold's (possibly shorter) path.
predmat = NaN(nobs,nc,nlambda);
nfolds = max(foldid);
nlams = zeros(nfolds,1);
for i = 1:nfolds
    which = foldid==i;
    fitobj = object{i};
    if (is_offset)
        off_sub = offset(which,:);
    else
        off_sub = [];
    end
    preds = glmnetPredict(fitobj,x(which,:),[],'response',[],off_sub);
    nlami = length(fitobj.lambda);
    predmat(which,:,1:nlami) = preds;
    nlams(i) = nlami;
end

% Turn rows of y into proportions; the row totals become extra weights.
ywt = sum(y, 2);
y = y ./ repmat(ywt,1,nc);
weights = weights .* ywt;
% Observations with a prediction at each lambda, as a 1-by-nlambda row so it
% conforms with the per-lambda row cvm used with bsxfun below.
N = nobs - reshape(sum(isnan(predmat(:,1,:)),1), 1, nlambda);
bigY = repmat(y, [1,1,nlambda]);
switch type
    case 'mse'
        cvraw = reshape(sum((bigY - predmat).^2, 2), nobs, nlambda);
    case 'mae'
        cvraw = reshape(sum(abs(bigY - predmat), 2), nobs, nlambda);
    case 'deviance'
        % Clip a copy so fit_preval keeps the raw probabilities.  MAX/MIN
        % ignore NaN, so missing predictions are restored to NaN explicitly;
        % otherwise they would score as finite (huge) deviances.
        pclip = min(max(predmat,prob_min),prob_max);
        pclip(isnan(predmat)) = NaN;
        lp = bigY .* log(pclip);
        ly = bigY .* log(bigY);
        ly(bigY == 0) = 0;   % 0*log(0) evaluates to NaN; its limit is 0
        cvraw = reshape(sum(2 * (ly - lp), 2), nobs, nlambda);
    case 'class'
        % Rows without a prediction at a lambda stay NaN instead of being fed
        % to sub2ind as NaN subscripts.
        cvraw = NaN(nobs,nlambda);
        for j = 1:nlambda
            p = predmat(:,:,j);
            ok = ~any(isnan(p),2);
            if any(ok)
                cls = glmnet_softmax(p(ok,:));
                hit = y(sub2ind([nobs,nc], find(ok), cls(:)));
                cvraw(ok,j) = 1 - hit;
            end
        end
end
if (grouped)
    cvob = cvcompute(cvraw, weights, foldid, nlams);
    cvraw = cvob.cvraw;
    weights = cvob.weights;
    N = cvob.N;
end
cvm = wtmean(cvraw,weights);
sqccv = (bsxfun(@minus,cvraw,cvm)).^2;
cvsd = sqrt(wtmean(sqccv,weights)./(N-1));
result.cvm = cvm; result.cvsd = cvsd; result.name = typenames.(type);
if (keep)
    result.fit_preval = predmat;
end
function result = glmnet_softmax(x)
% GLMNET_SOFTMAX  Row-wise arg-max (predicted class) of a score matrix.
%   result = glmnet_softmax(x) returns a size(x,1)-by-1 column whose i-th
%   entry is the column index of the largest value in x(i,:).  Ties go to
%   the lowest column index.  Rows containing any NaN yield NaN, including
%   the case where every row contains NaN (previously an error because
%   result was never assigned).
d = size(x);
nas = any(isnan(x),2);
if any(nas)
    % NaN rows get NaN; the remaining rows are classified recursively.
    result = NaN(d(1),1);
    if ~all(nas)
        result(~nas) = glmnet_softmax(x(~nas,:));
    end
else
    maxdist = x(:,1);
    result = ones(d(1),1);
    for i = 2:d(2)
        l = x(:,i) > maxdist;   % strict '>' keeps the earliest column on ties
        result(l) = i;
        maxdist(l) = x(l,i);
    end
end