Skip to content

Commit

Permalink
small changes to sync with Python version
Browse files Browse the repository at this point in the history
  • Loading branch information
latimerk committed Aug 2, 2022
1 parent 67297b5 commit bed23ec
Show file tree
Hide file tree
Showing 24 changed files with 512 additions and 612 deletions.
30 changes: 13 additions & 17 deletions +kgmlm/+fittingTools/HMCstep_diag.m
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
nlpost = inf; % divergent trajectory
break;
end

%% move positions
[w, errs] = paramStep(w, p, M, HMC_state);
if(errs)
Expand All @@ -48,14 +49,14 @@

[nlpost, ndW, ~, results] = nlpostFunction(w);


%% move momentums
[p, errs] = momentumStep(p, -ndW, HMC_state);
if(errs)% divergent trajectory
nlpost = inf; % divergent trajectory
break;
end

%% check current divergence
lp_momentum = logProbMomentum(p, M);
H_s = -nlpost + lp_momentum;

Expand All @@ -73,21 +74,16 @@
error('HMC accept probability is nan!');
end
catch ee %#ok<NASGU>
%p_accept = 1e-14;

log_p_accept = nan;%log(p_accept);
w_new = w_init;
results = results_init;
accepted = false;
divergent = true;

% msgText = getReport(ee,'extended');
% fprintf('HMC reaching inf/nan values with step size %.4f: %s\n\tAuto-rejecting sample and setting p_accept = %e.\n\tError Message: %s\n',ees,errorMessageStr,p_accept,msgText);
% fprintf('>>end error message<<\n');

% fprintf('\t\t>>>HMC sampler reaching numerically unstable values (infinite/nan): rejecting sample early<<<\n');

log_p_accept = nan;
w_new = w_init;
results = results_init;
accepted = false;
divergent = true;

%msgText = getReport(ee,'extended');
%fprintf('HMC reaching inf/nan values with step size %.4f: %s\n\tAuto-rejecting sample and setting p_accept = %e.\n\tError Message: %s\n',ees,errorMessageStr,p_accept,msgText);
%fprintf('>>end error message<<\n');
%fprintf('\t\t>>>HMC sampler reaching numerically unstable values (infinite/nan): rejecting sample early<<<\n');
return;
end

Expand All @@ -108,16 +104,16 @@
end
end


%% gets initial momentum
function [vv] = generateMomentum(M)
    % GENERATEMOMENTUM Draw a random momentum vector scaled by sqrt(M).
    %
    %   vv = generateMomentum(M) samples numel(M) independent standard-normal
    %   values as a column vector and multiplies them elementwise by sqrt(M).
    %
    %   Input:
    %     M  - numeric array of scale (variance) terms.
    %          NOTE(review): presumably the diagonal of the HMC mass matrix,
    %          stored as a column vector - confirm against the caller.
    %
    %   Output:
    %     vv - the scaled random draw (numel(M)-by-1 when M is a column).
    n_entries = numel(M);
    scale     = sqrt(M);

    % one standard-normal draw per entry of M, always as a column
    unit_draw = randn(n_entries, 1);

    vv = unit_draw .* scale;
end

%% gets the probability of a momentum term
function [lp] = logProbMomentum(mm,M)
    % LOGPROBMOMENTUM Quadratic log-probability term for a momentum vector.
    %
    %   lp = logProbMomentum(mm, M) returns -1/2 * sum(mm.^2 ./ M), i.e. each
    %   squared momentum entry is divided elementwise by the matching entry
    %   of M before summing. No normalizing constant is included.
    %
    %   Inputs:
    %     mm - momentum values.
    %     M  - elementwise scale terms, same size as (or broadcastable to) mm.
    %          NOTE(review): presumably the diagonal mass matrix used when
    %          the momentum was drawn - confirm against the caller.
    %
    %   Output:
    %     lp - the summed term (a scalar when mm and M are vectors).
    weighted_sq = (mm.^2) ./ M;   % same as M .\ mm.^2
    lp = -0.5 * sum(weighted_sq);
end


%% complete parameter step
function [w,errs] = paramStep(w, p, M, HMC_state)
w(:) = w + HMC_state.stepSize.e*(M.\p(:));
Expand Down
12 changes: 8 additions & 4 deletions +kgmlm/+fittingTools/adjustHMCstepSize.m
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ps.delta = ps.delta(min(numel(ps.delta), ww));

tt = ss - sample_1 + 1;
[stepSizeState.x_t,stepSizeState.x_bar_t,stepSizeState.H_sum] = dualAverageStepSizeUpdate_internal(ps, log_p_accept_new, stepSizeState.H_sum, stepSizeState.x_bar_t, tt);
[stepSizeState.x_t, stepSizeState.x_bar_t, stepSizeState.H_sum] = dualAverageStepSizeUpdate_internal(ps, log_p_accept_new, stepSizeState.H_sum, stepSizeState.x_bar_t, tt, log(stepSizeSettings.max_step_size));
stepSizeState.e_bar = exp(stepSizeState.x_bar_t);

if(ss == stepSizeSettings.schedule(ww,2))
Expand All @@ -32,8 +32,8 @@
end


stepSizeState.e = min(stepSizeSettings.max_step_size, stepSizeState.e );
stepSizeState.e_bar = min(stepSizeSettings.max_step_size, stepSizeState.e_bar );
% stepSizeState.e = min(stepSizeSettings.max_step_size, stepSizeState.e );
% stepSizeState.e_bar = min(stepSizeSettings.max_step_size, stepSizeState.e_bar );

HMC_state.stepSize = stepSizeState;

Expand All @@ -50,7 +50,7 @@
end

%%
function [x_t, x_bar_t, H_sum] = dualAverageStepSizeUpdate_internal(ps, log_h, H_sum, x_bar_t, tt)
function [x_t, x_bar_t, H_sum] = dualAverageStepSizeUpdate_internal(ps, log_h, H_sum, x_bar_t, tt, max_x)
if(tt == 1 || isnan(x_bar_t) || isinf(x_bar_t))
%reset estimation
x_t = x_bar_t;
Expand All @@ -72,6 +72,10 @@
H_tt = ps.delta - a_tt;
H_sum = aa_H * H_tt + (1 - aa_H) * H_sum;
x_t = ps.mu - sqrt(tt)/ps.gamma * H_sum;
if(nargin >= 6)
x_t = min(max_x, x_t);
end

aa_x =tt^(-ps.kappa);
x_bar_t = aa_x * x_t + (1 - aa_x)*x_bar_t;
end
7 changes: 2 additions & 5 deletions +kgmlm/+utils/logMeanExp.m
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
end


if(strcmpi(dim, 'all'))
NE = numel(log_x);
else
NE = size(log_x, dim);
end
NE = sum(~isnan(log_x), dim);
NE(NE == 0) = nan;

log_m = -log(NE) + kgmlm.utils.logSumExp(log_x,dim);
12 changes: 8 additions & 4 deletions +kgmlm/+utils/logSumExp.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,27 @@
dim = 2;
end

if(size(log_x, dim) == 1)
log_m = log_x;
return;
end

cs = max(log_x, [], dim);
cs = max(log_x, [], dim, 'omitnan');
log_x = log_x - cs;

if(isnumeric(dim) && dim == 1 && isa(log_x,'single'))
%converts to double for numerical safety without converting the entire matrix of log_x - just one column at a time
log_m = cs;
for ii = 1:numel(log_m)
log_m(ii) = log_m(ii) + log(sum(exp(double(log_x(:,ii)))));
log_m(ii) = log_m(ii) + log(sum(exp(double(log_x(:,ii))), 'omitnan'));
end
elseif(isnumeric(dim) && dim == 2 && isa(log_x,'single'))
log_m = cs;
for ii = 1:numel(log_m)
log_m(ii) = log_m(ii) + log(sum(exp(double(log_x(ii,:)))));
log_m(ii) = log_m(ii) + log(sum(exp(double(log_x(ii,:))), 'omitnan'));
end
else
log_m = cs +log(sum(exp(log_x),dim));
log_m = cs +log( sum(exp(log_x),dim, 'omitnan'));
end


Expand Down
11 changes: 8 additions & 3 deletions +kgmlm/+utils/truncatedPoiss.m
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
function [ll] = truncatedPoiss(rr, Y)
function [ll, p0, log_p0] = truncatedPoiss(rr, Y)

ll = zeros(size(Y));

Y1 = Y > 0;
Y0 = Y == 0;

rl = rr < -30;
ll(Y1 & rl) = rr(Y1 & rl);
ll(Y1 & rl) = rr(Y1 & rl);
ll(Y1 & ~rl) = log(1-exp(-exp(rr(Y1 & ~rl))));

ll(Y0) = -exp(rr(Y0));
ll(Y0) = -exp(rr(Y0));

if(nargout > 1)
log_p0 = -exp(rr);
p0 = exp(log_p0);
end
100 changes: 53 additions & 47 deletions +kgmlm/@GMLM/GMLM.m
Original file line number Diff line number Diff line change
Expand Up @@ -1193,14 +1193,14 @@ function delete(obj)

J = obj.dim_J;
params_0 = params;
if(isfield(obj.GMLMstructure, "scaleParams") && ~isempty(obj.scaleParams))
params = obj.GMLMstructure.scaleParams(params_0);
end
for jj = 1:J
if(isfield(obj.GMLMstructure.Groups(jj), "scaleParams") && ~isempty(obj.GMLMstructure.Groups(jj).scaleParams))
params.Groups(jj) = obj.GMLMstructure.Groups(jj).scaleParams(params_0.Groups(jj));
end
end
if(isfield(obj.GMLMstructure, "scaleParams") && ~isempty(obj.GMLMstructure.scaleParams))
params = obj.GMLMstructure.scaleParams(params);
end

for jj = 1:obj.dim_J
shared_regressors(jj).F = obj.getF(params, jj);
Expand Down Expand Up @@ -1313,6 +1313,7 @@ function delete(obj)

for pp = 1:size(obj.trials(mm).Y, 2)
rr = log_like_per_trial(mm).log_rate(:, pp) + log(obj.bin_size);


log_like_per_trial(mm).log_like_0(:, pp) = kgmlm.utils.truncatedPoiss(rr, obj.trials(mm).Y(:, pp) );

Expand Down Expand Up @@ -1889,37 +1890,20 @@ function delete(obj)
useAsync = true;
params_0 = params;

scaled_WB = isfield(obj.GMLMstructure, "scaleParams") && ~isempty(obj.scaleParams);
if(scaled_WB)
params = obj.GMLMstructure.scaleParams(params_0);

if(isfield(opts, "dH"))
opts.dB = opts.dB | opts.dH;
opts.dW = opts.dW | opts.dH;

if(nargin > 3)
if(opts.dW && isempty(results.dW))
error("Invalid results struct.");
end
if(opts.dB && isempty(results.dB) && obj.dim_B > 0)
error("Invalid results struct.");
end
end
end
end
J = obj.dim_J;
scaled_VT = false(J,1);
scaleP = cell(J,1);
for jj = 1:J
if(isfield(obj.GMLMstructure.Groups(jj), "scaleParams") && ~isempty(obj.GMLMstructure.Groups(jj).scaleParams))

scaled_VT(jj) = true;
params.Groups(jj) = obj.GMLMstructure.Groups(jj).scaleParams(params_0.Groups(jj));
[params.Groups(jj), scaleP{jj}] = obj.GMLMstructure.Groups(jj).scaleParams(params_0.Groups(jj));

if(isfield(opts.Groups(jj), "dH"))
opts.Groups(jj).dV = opts.Groups(jj).dV | opts.Groups(jj).dH;
opts.Groups(jj).dT(:) = opts.Groups(jj).dT(:) | opts.Groups(jj).dH;

if(nargin > 3)
if(nargin > 3 && ~isempty(results))
if(opts.Groups(jj).dV && isempty(results.Groups(jj).dV))
error("Invalid results struct.");
end
Expand All @@ -1932,6 +1916,24 @@ function delete(obj)
end
end
end
scaled_WB = isfield(obj.GMLMstructure, "scaleParams") && ~isempty(obj.GMLMstructure.scaleParams);
if(scaled_WB)
[params, scaleP_WB] = obj.GMLMstructure.scaleParams(params);

if(isfield(opts, "dH"))
opts.dB = opts.dB | opts.dH;
opts.dW = opts.dW | opts.dH;

if(nargin > 3 && ~isempty(results))
if(opts.dW && isempty(results.dW))
error("Invalid results struct.");
end
if(opts.dB && isempty(results.dB) && obj.dim_B > 0)
error("Invalid results struct.");
end
end
end
end

if(runHost)
if(nargin < 4 || isempty(results))
Expand Down Expand Up @@ -1963,11 +1965,11 @@ function delete(obj)


if(scaled_WB)
results = obj.Groups(jj).scaleDerivatives(results, params_0, true);
results = obj.GMLMstructure.scaleDerivatives(results, params_0, true, scaleP_WB);
end
for jj = 1:J
if(scaled_VT(jj))
results.Groups(jj) = obj.GMLMstructure.Groups(jj).scaleDerivatives(results.Groups(jj), params_0.Groups(jj), false);
results.Groups(jj) = obj.GMLMstructure.Groups(jj).scaleDerivatives(results.Groups(jj), params_0.Groups(jj), false, scaleP{jj});
end
end
results.log_likelihood = sum(results.trialLL, 'all');
Expand All @@ -1984,38 +1986,21 @@ function delete(obj)
useAsync = true;
params_0 = params;

scaled_WB = isfield(obj.GMLMstructure, "scaleParams") && ~isempty(obj.scaleParams);
if(scaled_WB)
params = obj.GMLMstructure.scaleParams(params_0);

if(isfield(opts, "dH"))
opts.dB = opts.dB | opts.dH;
opts.dW = opts.dW | opts.dH;

if(nargin > 3)
if(opts.dW && isempty(results.dW))
error("Invalid results struct.");
end
if(opts.dB && isempty(results.dB) && obj.dim_B > 0)
error("Invalid results struct.");
end
end
end
end

J = obj.dim_J;
scaled_VT = false(J,1);
scaleP = cell(J,1);
for jj = 1:J
if(isfield(obj.GMLMstructure.Groups(jj), "scaleParams") && ~isempty(obj.GMLMstructure.Groups(jj).scaleParams))

scaled_VT(jj) = true;
params.Groups(jj) = obj.GMLMstructure.Groups(jj).scaleParams(params_0.Groups(jj));
[params.Groups(jj), scaleP{jj}] = obj.GMLMstructure.Groups(jj).scaleParams(params_0.Groups(jj));

if(isfield(opts.Groups(jj), "dH"))
opts.Groups(jj).dV = opts.Groups(jj).dV | opts.Groups(jj).dH;
opts.Groups(jj).dT(:) = opts.Groups(jj).dT(:) | opts.Groups(jj).dH;

if(nargin > 3)
if(nargin > 3 && ~isempty(results))
if(opts.Groups(jj).dV && isempty(results.Groups(jj).dV))
error("Invalid results struct.");
end
Expand All @@ -2028,6 +2013,24 @@ function delete(obj)
end
end
end
scaled_WB = isfield(obj.GMLMstructure, "scaleParams") && ~isempty(obj.GMLMstructure.scaleParams);
if(scaled_WB)
[params, scaleP_WB] = obj.GMLMstructure.scaleParams(params);

if(isfield(opts, "dH"))
opts.dB = opts.dB | opts.dH;
opts.dW = opts.dW | opts.dH;

if(nargin > 3 && ~isempty(results))
if(opts.dW && isempty(results.dW))
error("Invalid results struct.");
end
if(opts.dB && isempty(results.dB) && obj.dim_B > 0)
error("Invalid results struct.");
end
end
end
end

if(runHost)
if(nargin < 4 || isempty(results))
Expand Down Expand Up @@ -2063,11 +2066,11 @@ function delete(obj)
end

if(scaled_WB)
results = obj.Groups(jj).scaleDerivatives(results, params_0, true);
results = obj.GMLMstructure.scaleDerivatives(results, params_0, true, scaleP_WB);
end
for jj = 1:J
if(scaled_VT(jj))
results.Groups(jj) = obj.GMLMstructure.Groups(jj).scaleDerivatives(results.Groups(jj), params_0.Groups(jj), true);
results.Groups(jj) = obj.GMLMstructure.Groups(jj).scaleDerivatives(results.Groups(jj), params_0.Groups(jj), true, scaleP{jj});
end
end

Expand Down Expand Up @@ -2173,10 +2176,13 @@ function delete(obj)

[samples, summary, HMC_settings, paramStruct, M] = runHMC_simple(obj, params_init, settings, varargin);
[samples, samples_file_format, summary, HMC_settings, paramStruct, M] = runHMC_simpleLowerRAM(obj, params_init, settings, varargin);
[paramStruct2] = saveSampleToFile(obj, samples_file, paramStruct, sample_idx, scaled_WB, scaled_VT, save_H, saveUnscaled);
[samples_file_format, totalParams] = getSampleFileFormat(obj, TotalSamples, dataType_samples, paramStruct, scaled_WB, scaled_VT, saveUnscaled)
[HMC_settings] = setupHMCparams(obj, nWarmup, nSamples, debugSettings);


[results] = computeLogLikelihood_host_v2(obj, params, opts, results);
[log_rate, xx, R] = computeLogRate_host_v2(obj, params);
[] = setupComputeStructuresHost(obj, reset, order);
end

Expand Down
Loading

0 comments on commit bed23ec

Please sign in to comment.