Skip to content

Commit

Permalink
Optimize matrix operations in optimizeMu()
Browse files Browse the repository at this point in the history
Remove some unnecessary copy operations to speed up runtime when D is large.
  • Loading branch information
kqshan committed May 1, 2017
1 parent a07823a commit 1bab0e3
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions @MoDT/optimizeMu.m
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
% some structural information that will be useful for our call to sparse()
% later. We don't need this for the MEX version, which uses this banded storage
% format directly.
nSuperDiag = (size(B_Q,1) - 1) / 2;
if ~self.use_mex
nSuperDiag = (size(B_Q,1) - 1) / 2;
H_i = bsxfun(@plus, (-nSuperDiag:nSuperDiag)', 1:D*T);
H_j = repmat(1:D*T, 2*nSuperDiag+1, 1);
H_mask = (H_i > 0) & (H_i <= D*T);
Expand All @@ -49,11 +49,9 @@
% Make H = Qinv_matrix + Cinv_matrix
% Get the observation information matrix
C_k = C(:,:,k);
B_C = make_Cinv_matrix( C_k, sum_wzu(:,k) );
B_C = make_Cinv_matrix( C_k, sum_wzu(:,k), nSuperDiag );
% Add the two matrices together to get H
B = B_Q;
offset = (size(B_Q,1) - size(B_C,1)) / 2;
B(offset+1:end-offset,:) = B(offset+1:end-offset,:) + B_C;
B = B_Q + B_C;

% Calculate b = C_k \ wzuY(:,:,k)
b = C_k \ wzuY(:,:,k);
Expand Down Expand Up @@ -215,15 +213,16 @@
end


function B = make_Cinv_matrix( C, wzu )
function B = make_Cinv_matrix( C, wzu, kd )
% Construct the banded representation of the observation information matrix
% B = make_Cinv_matrix( C, wzu )
% B = make_Cinv_matrix( C, wzu, kd )
%
% Returns:
% B [2*D-1 x D*T] column-major banded storage of a [D*T x D*T] matrix
% B [2*kd+1 x D*T] column-major banded storage of a [D*T x D*T] matrix
% Required arguments:
% C [D x D] Observation covariance matrix
% wzu [T x 1] weight for each time frame
% kd Number of super-diagonals for the output matrix
%
% This represents the [D*T x D*t] observation matrix:
% [ wzu(1)*C^-1 ]
Expand All @@ -241,9 +240,9 @@
Cinv = (Cinv + Cinv')/2; % Enforce symmetry
% Convert the [D x D] full format into a [2*D-1 x D] banded format
D = size(C,1);
B1 = zeros(2*D-1, D);
idx = bsxfun(@plus, (1:D)', (D-1:-1:0) + (0:D-1)*size(B1,1));
B1 = zeros(2*kd+1, D);
idx = bsxfun(@plus, (1:D)', kd + 1-(1:D) + (0:D-1)*size(B1,1));
B1(idx) = Cinv(:);
% Repeat and scale by wzu
B = kron(wzu', B1);
B = reshape(B1(:)*wzu', [2*kd+1, D*length(wzu)]);
end

0 comments on commit 1bab0e3

Please sign in to comment.