Skip to content

Commit

Permalink
speed optimization for svd and clustering (mostly for GPU users)
Browse files Browse the repository at this point in the history
- large matrix multiplications are now done on GPU (usually x2
improvement in speed for 'single' arrays)
- slightly better vectorization and memory management
- saving with '-v6' when possible (not everywhere yet)
  • Loading branch information
mkrumin committed Nov 2, 2016
1 parent 871aa94 commit 6885ff7
Show file tree
Hide file tree
Showing 10 changed files with 300 additions and 66 deletions.
48 changes: 29 additions & 19 deletions cellDetection/fastClustNeuropilCoef.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function [ops, stat, res] = fastClustNeuropilCoef(ops, U, Sv)
%
%
U = reshape(U, [], size(U,ndims(U)));
iplane = ops.iplane;
U = bsxfun(@times, U, Sv'.^.5)';
Expand All @@ -18,21 +18,21 @@

xs = repmat(round(linspace(1, nsqrt, Lx)), Ly, 1);
ys = repmat(round(linspace(1, nsqrt, Ly))', 1, Lx);
iclust = xs + (ys-1) * nsqrt;
iclust = xs + (ys-1) * nsqrt;

clear xs ys

niter = ops.niterclustering;

% xs = repmat(1:Lx, Ly, 1);
% ys = repmat((1:Ly)', 1, Lx);
%
%
% randx = rand(1, Nk) * Lx;
% randy = rand(1, Nk) * Ly;
%
%
% dx = repmat(xs(:), 1, Nk) - repmat(randx, numel(xs(:)), 1);
% dy = repmat(ys(:), 1, Nk) - repmat(randy, numel(ys(:)), 1);
%
%
% dxy = dx.^2 + dy.^2;
% [~, iclust] = min(dxy, [], 2);
%%
Expand Down Expand Up @@ -61,7 +61,7 @@

ison = true(Nk,1);
TileFactor = getOr(ops, {'TileFactor'}, 1); % this option can be overwritten by the user
nTiles = ceil(TileFactor * (Ly+Lx)/2 / (10 * ops.diameter)); % neuropil is modelled as nTiles by nTiles
nTiles = ceil(TileFactor * (Ly+Lx)/2 / (10 * ops.diameter)); % neuropil is modelled as nTiles by nTiles

xc = linspace(1, Lx, nTiles);
yc = linspace(1, Ly, nTiles);
Expand Down Expand Up @@ -106,7 +106,7 @@
StU = Sm' * Uneu';
Lam = (StS + 1e-4 * eye(nBasis)) \ StU;

% recompute neuropil pixel contribution
% recompute neuropil pixel contribution
neuropil = Lam' * S';
PixL = mean(bsxfun(@times, neuropil, Uneu), 1);
PixL = bsxfun(@rdivide, PixL, mean(neuropil.^2,1));
Expand All @@ -124,18 +124,28 @@
end
end
%
% Ff = Fs' * vs;
% vs = Finv * max(0, Ff);
% [dcell, Ffr] = run_deconvolution2(Ff, f0, kernel);
% vs = Finv * Ffr;
% Ff = Fs' * vs;
% vs = Finv * max(0, Ff);
% [dcell, Ffr] = run_deconvolution2(Ff, f0, kernel);
% vs = Finv * Ffr;

vs = bsxfun(@rdivide, vs, sum(vs.^2,1).^.5 + 1e-8);% normalize activity vectors
vs = single(vs);

% recompute pixels' assignments
xs = vs' * Ucell;
if ops.useGPU
xs = gpuBlockSmallXtY(vs, Ucell);
else
xs = vs' * Ucell;
end

[M, iclust] = max(xs,[],1);
% Uneu = U - bsxfun(@times, M, vs(:,iclust)); %what's left over for neuropil model
Uneu = U - bsxfun(@times, M, vs(:,iclust)); %what's left over for neuropil model
err(k) = sum(sum((Uneu-neuropil).^2)).^.5;
% vs = double(vs);

% err(k) = sum(sum((Uneu-neuropil).^2)).^.5;
err(k) = norm(Uneu(:)-neuropil(:));

if 1
%---------------------------------------------%
Expand Down Expand Up @@ -177,20 +187,20 @@


if (rem(k,10)==1 || k==niter) && ops.ShowCellMap
% imagesc(reshape(PixL, Ly, Lx), [0 2])
% drawnow
%
% imagesc(reshape(PixL, Ly, Lx), [0 2])
% drawnow
%
lam = M;
for i = 1:Nk
ix = find(iclust==i);
nT0 = numel(ix);
if nT0>0
vM = lam(ix);
% vM = vM/sum(vM.^2)^.5;
% vM = vM/sum(vM.^2)^.5;
lam(ix) = vM;
end
end
% V = max(0, min(10 * reshape(lam, Ly, Lx), 1));
% V = max(0, min(10 * reshape(lam, Ly, Lx), 1));
V = max(0, min(.5 * reshape(lam, Ly, Lx)/mean(lam(:)), 1));
H = reshape(r(iclust), Ly, Lx);
rgb_image = hsv2rgb(cat(3, H, Sat, V));
Expand All @@ -199,7 +209,7 @@
drawnow
fprintf('residual variance is %2.6f time %2.2f \n', err(k), toc)
end

end

lam = M;
Expand Down
14 changes: 8 additions & 6 deletions registration/reg2P.m
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,19 @@
end

if ops.showTargetRegistration
figure('position', [900 50 900 900])
ax = ceil(sqrt(numel(ops1)/2));
nRows = floor(sqrt(numel(ops1)));
nColumns = ceil(numel(ops1)/nRows);
figure('Name', 'Registration Target Frames', ...
'Position', [50 50 nColumns*500 nRows*500])
i0 = 0;
for i = 1:numPlanes
for j = 1:size(xFOVs,2)
i0 = i0+1;
subplot(ax,2*ax,i0)
imagesc(ops1{i,j}.mimg)
subplot(nRows, nColumns, i0);
imagesc(ops1{i,j}.mimg);
colormap('gray')
title(sprintf('Registration for plane %d, mouse %s, date %s', ...
i, ops.mouse_name, ops.date))
axis equal tight
title(sprintf('Plane %d, %s_%s', i, ops.mouse_name, ops.date))
end
end

Expand Down
26 changes: 17 additions & 9 deletions signalExtraction/get_regions.m
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
for k = 1:Nk
% needs stat, res, neigh
pixall = stat(k).ipix;

% minV = clustrules.parent.minPixRelVar * mean(res.M(pixall));
% pixall(res.M(pixall)< minV) = [];
% minV = clustrules.parent.minPixRelVar * mean(res.M(pixall));
% pixall(res.M(pixall)< minV) = [];

whclust = 0;
region = [];
Expand All @@ -47,15 +47,23 @@

x0 = xs(pixi); y0 = ys(pixi);

rs = ((x0 - mean(x0)).^2 + (y0 - mean(y0)).^2).^.5;
region(whclust).mrs = mean(rs);
meanX0 = mean(x0);
meanY0 = mean(y0);
rs = ((x0 - meanX0).^2 + (y0 - meanY0).^2).^.5;
region(whclust).npix = numel(pixi);
region(whclust).mrs0 = mean(rgridsort(1:region(whclust).npix));
region(whclust).med = [mean(y0) mean(x0)];
if region(whclust).npix>1
region(whclust).mrs = mean(rs);
region(whclust).mrs0 = mean(rgridsort(1:region(whclust).npix));
region(whclust).V = sum(res.M(pixi));
else
region(whclust).mrs = rs;
region(whclust).mrs0 = rgridsort(1);
region(whclust).V = res.M(pixi);
end
region(whclust).med = [meanY0 meanX0];
region(whclust).ipix = pixi;
region(whclust).lambda = res.lambda(pixi);
region(whclust).V = sum(res.M(pixi));
region(whclust).parent = k;
region(whclust).parent = k;
end
stat(k).region = region;

Expand Down
47 changes: 33 additions & 14 deletions svd/get_svdForROI.m
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ix = 0;
fid = fopen(ops.RegFile, 'r');

mov = zeros(Ly, Lx, ops.NavgFramesSVD, 'single');
mov = zeros(numel(ops.yrange), numel(ops.xrange), ops.NavgFramesSVD, 'single');

while 1
data = fread(fid, Ly*Lx*nimgbatch, '*int16');
Expand All @@ -30,21 +30,21 @@
data = bsxfun(@minus, data, mean(data,3));
% data = bsxfun(@minus, data, ops.mimg1);

irange = 1:nt0*floor(size(data,3)/nt0);
data = data(:,:, irange);
nSlices = nt0*floor(size(data,3)/nt0);
if nSlices~=size(data,3)
data = data(:,:, 1:nSlices);
end

data = reshape(data, Ly, Lx, nt0, []);
davg = single(squeeze(mean(data,3)));
davg = squeeze(mean(data,3));

mov(:,:,ix + (1:size(davg,3))) = davg;
mov(:,:,ix + (1:size(davg,3))) = davg(ops.yrange, ops.xrange, :);

ix = ix + size(davg,3);
end
fclose(fid);
%%
mov(:, :, (ix+1):end) = [];

mov = mov(ops.yrange, ops.xrange, :);
mov = mov(:, :, 1:ix);
% mov = mov - repmat(mean(mov,3), 1, 1, size(mov,3));
%% SVD options

Expand All @@ -60,8 +60,13 @@

mov = reshape(mov, [], size(mov,3));
sdmov = mean(mov.^2,2).^.5;
mov = mov./repmat(sdmov, 1, size(mov,2));
COV = mov' * mov/size(mov,1);
% mov = mov./repmat(sdmov, 1, size(mov,2));
mov = bsxfun(@rdivide, mov, sdmov);
if ops.useGPU
COV = gpuBlockXtX(mov)/size(mov,1);
else
COV = mov' * mov/size(mov,1);
end

ops.nSVDforROI = min(size(COV,1)-2, ops.nSVDforROI);

Expand All @@ -78,8 +83,13 @@
Sv = single(diag(Sv));
end

U = normc(mov * V);
if ops.useGPU
U = normc(gpuBlockXY(mov, V));
else
U = normc(mov * V);
end
U = single(U);

%%
fid = fopen(ops.RegFile, 'r');

Expand All @@ -97,7 +107,11 @@
data = bsxfun(@minus, data, mean(data,3));
% data = bsxfun(@minus, data, ops.mimg1);
data = data(ops.yrange, ops.xrange, :);
Fs(:, ix + (1:size(data,3))) = U' * reshape(data, [], size(data,3));
if ops.useGPU
Fs(:, ix + (1:size(data,3))) = gpuBlockXtY(U, reshape(data, [], size(data,3)));
else
Fs(:, ix + (1:size(data,3))) = U' * reshape(data, [], size(data,3));
end

ix = ix + size(data,3);
end
Expand All @@ -109,6 +123,11 @@
mkdir(ops.ResultsSavePath);
end
if getOr(ops, {'writeSVDroi'}, 0)
save(sprintf('%s/SVDroi_%s_%s_plane%d.mat', ops.ResultsSavePath, ...
ops.mouse_name, ops.date, ops.iplane), 'U', 'Sv', 'V', 'ops');
try
save(sprintf('%s/SVDroi_%s_%s_plane%d.mat', ops.ResultsSavePath, ...
ops.mouse_name, ops.date, ops.iplane), 'U', 'Sv', 'V', 'ops', '-v6');
catch
save(sprintf('%s/SVDroi_%s_%s_plane%d.mat', ops.ResultsSavePath, ...
ops.mouse_name, ops.date, ops.iplane), 'U', 'Sv', 'V', 'ops');
end
end
Loading

0 comments on commit 6885ff7

Please sign in to comment.