%--------------------------------------------------------------------------
% varbvspredict.m: Make predictions from a model fitted by varbvs.
%--------------------------------------------------------------------------
%
% DESCRIPTION:
% Predict outcomes (Y) given the observed variables (X) and observed
% covariates (Z), and a model fitted by varbvs.
%
% USAGE:
% varbvspredict(fit, X, Z)
%
% INPUT ARGUMENTS:
% fit Output of function varbvs.
%
% X n x p input matrix, in which p is the number of variables, and n
% is the number of samples for which predictions will be made using
% the fitted model. X cannot be sparse.
%
% Z n x m covariate data matrix, where m is the number of covariates. Do
% not supply an intercept as a covariate (i.e., a column of ones),
% because an intercept is automatically included in the regression
% model. For no covariates, set Z to the empty matrix [].
%
% OUTPUT: Vector containing the predicted outcomes for all samples. For
% family = 'binomial', all vector entries are 0 or 1.
%
% DETAILS:
% For the logistic regression model, we do not provide classification
% probabilities Pr(Y = 1 | X, Z) because these probabilities are not
% necessarily calibrated under the variational approximation.
%
% The predictions are computed by averaging over the hyperparameter
% settings, treating fit.logw as (unnormalized) log-marginal
% probabilities. See varbvs for more details about correctly using
% fit.logw for approximate numerical integration over the
% hyperparameters, for example by treating these as importance
% weights.
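%
%   As an illustrative sketch (not code from this file), normalized
%   weights can be recovered from fit.logw in a numerically stable way
%   by subtracting the maximum before exponentiating:
%
%     w = exp(fit.logw - max(fit.logw));
%     w = w / sum(w);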
%
% LICENSE: GPL v3
%
% DATE: February 19, 2016
%
% AUTHORS:
% Algorithm was designed by Peter Carbonetto and Matthew Stephens.
% R, MATLAB and C code was written by Peter Carbonetto.
% Depts. of Statistics and Human Genetics, University of Chicago,
% Chicago, IL, USA, and AncestryDNA, San Francisco, CA, USA
%
% REFERENCES:
% P. Carbonetto, M. Stephens (2012). Scalable variational inference
% for Bayesian variable selection in regression, and its accuracy in
% genetic association studies. Bayesian Analysis 7: 73-108.
%
% SEE ALSO:
% varbvs
%
% EXAMPLES:
% See demo_qtl.m and demo_cc.m for examples.
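%
%   A minimal usage sketch (Xtrain, Xtest and y are hypothetical
%   variables, not defined in this file):
%
%     fit  = varbvs(Xtrain,[],y);          % Fit model to training data.
%     yhat = varbvspredict(fit,Xtest,[]);  % Predict for new samples.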
%
function y = varbvspredict (fit, X, Z)
% Part of the varbvs package, https://github.com/pcarbo/varbvs
%
% Copyright (C) 2012-2017, Peter Carbonetto
%
% This program is free software: you can redistribute it under the
% terms of the GNU General Public License; either version 3 of the
% License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful, but
% WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
% General Public License for more details.
%
% Get the number of samples (n), variables (p) and hyperparameter
% settings (ns).
[n, p] = size(X);
ns     = numel(fit.logw);

% Input X must be single precision, and cannot be sparse.
if issparse(X)
  error('Input X cannot be sparse');
end
if ~isa(X,'single')
  X = single(X);
end
if (numel(fit.labels) ~= p)
  error('Inputs X and fit are not compatible');
end

% If input Z is not empty, it must be double precision, and must have as
% many rows as X. Add an intercept to Z, and check the number of
% covariates.
if ~isempty(Z)
  if (size(Z,1) ~= n)
    error('Inputs X and Z do not match.');
  end
  Z = double(full(Z));
end
Z = [ones(n,1) Z];
if (size(Z,2) ~= size(fit.mu_cov,1))
  error('Inputs Z and fit are not compatible');
end

% Get the normalized (approximate) posterior probabilities of the
% hyperparameter settings.
w = fit.w;

% For each hyperparameter setting, and for each sample, compute the
% posterior mean estimate of Y, then average these estimates over the
% hyperparameter settings. For the logistic regression model, the final
% "averaged" estimate is obtained by collecting the "votes" from each
% hyperparameter setting, weighting the votes by the marginal
% probabilities, and outputting the estimate that wins by majority. The
% averaged estimate is computed this way because the estimates
% conditioned on each hyperparameter setting are not necessarily
% calibrated in the same way.
Y = Z*fit.mu_cov + X*(fit.alpha.*fit.mu);
if strcmp(fit.family,'gaussian')
  y = Y*w(:);
elseif strcmp(fit.family,'binomial')
  y = round(round(sigmoid(Y))*w(:));
else
  error('Invalid setting for fit.family');
end