-
Notifications
You must be signed in to change notification settings - Fork 0
/
cvelnet.m
70 lines (57 loc) · 1.63 KB
/
cvelnet.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
function result = cvelnet(object, lambda, x, y, weights, offset, foldid, ...
    type, grouped, keep)
% CVELNET Cross-validation error for Gaussian (elastic-net) glmnet fits.
%   Internal glmnet function. See also cvglmnet.
%
%   RESULT = CVELNET(OBJECT, LAMBDA, X, Y, WEIGHTS, OFFSET, FOLDID, TYPE,
%   GROUPED, KEEP) predicts the held-out rows of each fold with the fit
%   stored for that fold. It then summarizes the prediction error at each
%   value of LAMBDA.
%
%   Inputs:
%     OBJECT  - cell array. OBJECT{i} is the fit used to predict the rows
%               with FOLDID == i. Each fit must have a .lambda field.
%     LAMBDA  - lambda sequence; the error matrix has one column per entry.
%     X       - predictor matrix, one row per observation.
%     Y       - response vector (row or column; treated as a column).
%     WEIGHTS - observation weights, forwarded to cvcompute / wtmean.
%     OFFSET  - optional offset ([] for none). It is subtracted from Y, and
%               the fold fits predict with their offset disabled.
%     FOLDID  - fold label of each observation; folds are 1..max(FOLDID).
%     TYPE    - 'default', 'mse', 'deviance' or 'mae'. 'default' maps to
%               'mse'. Any other value warns and falls back to 'mse'.
%               'deviance' is computed as squared error.
%     GROUPED - if true, raw errors are aggregated through cvcompute. It is
%               forced to false (with a warning) when numel(Y)/nfolds < 3.
%     KEEP    - if true, also return the out-of-fold predictions
%               (default false).
%
%   Output RESULT is a struct with fields:
%     cvm        - wtmean of the raw errors, one value per lambda.
%     cvsd       - sqrt(wtmean((err - cvm).^2) ./ (N - 1)), one per lambda.
%     name       - display name of the error measure.
%     fit_preval - (only when KEEP is true) out-of-fold prediction matrix.
%                  Entries for lambdas beyond a fold's own path are NaN.
if nargin < 10 || isempty(keep)
    keep = false;
end

typenames = struct('deviance','Mean-Squared Error', ...
    'mse','Mean-Squared Error', ...
    'mae','Mean Absolute Error');
if strcmp(type, 'default')
    type = 'mse';
end
if ~any(strcmp(type, {'mse','mae','deviance'}))
    warning('Only ''mse'', ''deviance'' or ''mae'' available for Gaussian models; ''mse'' used');
    type = 'mse';
end

% Normalize to column vectors. Otherwise a row-shaped Y would break the
% size/length bookkeeping below, and a row-shaped OFFSET would silently
% broadcast (implicit expansion) y - offset into an n-by-n matrix.
y = y(:);
if ~isempty(offset)
    % The fold fits predict without their offset, so remove it from Y.
    y = y - offset(:);
end

nobs = length(y);
nlambda = length(lambda);
nfolds = max(foldid);

% Row k holds the prediction for observation k from the fit of its fold.
% A fold's path may be shorter than LAMBDA; the remaining columns stay NaN.
predmat = NaN(nobs, nlambda);
nlams = zeros(1, nfolds);   % number of lambdas actually fitted per fold
for i = 1:nfolds
    infold = foldid == i;   % named to avoid shadowing the builtin WHICH
    fitobj = object{i};
    fitobj.offset = false;
    nlami = length(fitobj.lambda);
    predmat(infold, 1:nlami) = glmnetPredict(fitobj, x(infold,:));
    nlams(i) = nlami;
end

% Number of non-missing predictions available for each lambda.
N = nobs - sum(isnan(predmat), 1);
resid = bsxfun(@minus, y, predmat);
switch type
    case {'mse', 'deviance'}
        % For the Gaussian family the deviance equals the squared error.
        cvraw = resid.^2;
    case 'mae'
        cvraw = abs(resid);
end

if (nobs/nfolds < 3) && grouped
    warning('Option grouped=false enforced in cv.glmnet, since < 3 observations per fold');
    grouped = false;
end
if grouped
    cvob = cvcompute(cvraw, weights, foldid, nlams);
    cvraw = cvob.cvraw;
    weights = cvob.weights;
    N = cvob.N;
end

% wtmean is presumably a column-wise weighted mean (defined elsewhere).
cvm = wtmean(cvraw, weights);
sqccv = bsxfun(@minus, cvraw, cvm).^2;
cvsd = sqrt(wtmean(sqccv, weights) ./ (N - 1));

result.cvm = cvm;
result.cvsd = cvsd;
result.name = typenames.(type);
if keep
    result.fit_preval = predmat;
end
end