-
Notifications
You must be signed in to change notification settings - Fork 3
/
gd_matlab.m
130 lines (106 loc) · 3.26 KB
/
gd_matlab.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
function [x, f] = gd_matlab(funObj, funPred, x0, valid, options, varargin)
%GD_MATLAB Stochastic gradient descent with weight decay; matlab implementation.
% We assume the objective is being minimized.
%
% In:
%   funObj - objective function handle;
%     it returns cost and gradient in this order; precisely
%       [cost, gradient] = funObj(x0, dataIndices, varargin{:})
%     where x0 is a parameter, dataIndices is a set of indices chosen to
%     calculate the (noisy) gradient, and varargin are extra arguments
%   funPred - prediction function;
%       predLabels = funPred(x0, train.examples, varargin{:})
%   x0 - starting point; in parameter's space
%   valid - validation set; we use the validation set to choose the best
%     parameters found so far with respect to the validation set
%       .examples \in R[d,n] where d is dimensionality of
%         datum and n is the number of data points
%       .labels \in Z[n]
%   options - additional optimization settings;
%     required fields: nEpochs, numData, eta0, lambda, isVerbose
%     optional fields: MaxIter (caps total gradient updates),
%       trainX/trainY (training data for per-epoch accuracy reports)
%   varargin - additional arguments to the objective function [optional]
%
% Out:
%   x - minimum value found
%   f - function value at the minimum found
%
% Mateusz Malinowski
%
if ~isempty(valid)
  validX = valid.examples;
  validY = valid.labels;
  isValidation = true;
  bestValidAcc = 0;
  % Initialize bestX so it is always defined even if validation accuracy
  % never exceeds 0 (otherwise `x = bestX` below would error).
  bestX = x0;
else
  isValidation = false;
end

% different options
nEpochs = options.nEpochs;
numData = options.numData;
if isfield(options, 'MaxIter')
  nIterations = min(options.MaxIter, nEpochs * numData);
else
  nIterations = nEpochs * numData;
end
eta0 = options.eta0;
lambda = options.lambda;
isVerbose = options.isVerbose;

trainX = [];
if isfield(options, 'trainX')
  if ~isempty(options.trainX)
    trainX = options.trainX;
    trainY = options.trainY;
  end
end

x = x0;
it = 1;
% isDone propagates the budget-exhausted signal out of the nested loops;
% the original code only broke the innermost loop, so every remaining
% epoch still performed one extra gradient step before re-breaking.
isDone = false;
while ~isDone && it <= nIterations
  for epochNo = 1:nEpochs
    % in every epoch we re-shuffle data
    dataIndices = randperm(numData);
    % decaying learning rate and the corresponding weight-decay factor
    eta = eta0 / (1 + lambda * eta0 * epochNo);
    fw = 1 - eta * lambda;

    % we pass over all data points
    for k = 1:numData
      if isVerbose
        fprintf('Epoch %d; %d which is %d out of %d data points\n', ...
          epochNo, dataIndices(k), k, numData);
      end

      % computes objective (only the noisy gradient is used here)
      [~, grad] = funObj(x, dataIndices(k), varargin{:});
      x = fw * x - eta*grad;

      it = it + 1;
      if it > nIterations
        isDone = true;
        break;
      end
    end

    if isVerbose
      if ~isempty(funPred) && ~isempty(trainX)
        trainPred = funPred(x, trainX, varargin{:});
        trainAcc = sum(trainPred == trainY) / length(trainY);
        % after each epoch we report current results
        fprintf('-- Accuracy on training set is: %f\n', trainAcc);
      end
    end

    if isValidation
      % pass varargin for consistency with the training prediction call
      validPred = funPred(x, validX, varargin{:});
      validAcc = sum(validPred == validY) / length(validY);
      if validAcc > bestValidAcc
        bestValidAcc = validAcc;
        bestX = x;
      end
      if isVerbose
        fprintf('-- Accuracy on validation set is: %f\n', validAcc);
        fprintf('-- Best accuracy on validation set: %f\n', bestValidAcc);
      end
    end

    if isDone
      break;
    end
  end
end

if isValidation
  x = bestX;
end

if nargout == 2
  [~, f] = funObj(x, 1:numData, varargin{:});
end

end