Skip to content

Commit

Permalink
Merge pull request #76 from tristangdwl/feature/issue-63-fast-apply-w…
Browse files Browse the repository at this point in the history
…rapper

Feature/issue 63 fast apply wrapper
  • Loading branch information
askhamwhat authored May 28, 2024
2 parents 272b060 + 4970ac1 commit 459ddc4
Show file tree
Hide file tree
Showing 12 changed files with 1,210 additions and 201 deletions.
2 changes: 1 addition & 1 deletion chunkie/+chnk/+flam/proxyfunr.m
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@
srcinfo = []; srcinfo.r = pxy; srcinfo.d = ptau;
if (size(rslf,1) == 2)
srcinfo.n = chnk.perp(ptau);
targinfo.n = chnk.perp(dslf);
end

Kpxy = kern(srcinfo,targinfo);

Kpxy = Kpxy(islfuni2,:);
Expand Down
4 changes: 4 additions & 0 deletions chunkie/+chnk/+helm2d/fmm.m
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@

pg = 0;
pgt = min(nargout, 2);
switch lower(type)
case {'sprime', 'dprime', 'cprime'}
pgt = max(pgt, 2);
end
U = hfmm2d(eps, zk, srcuse, pg, targuse, pgt);

% Assign potentials
Expand Down
218 changes: 218 additions & 0 deletions chunkie/+chnk/chunkerkerneval_smooth.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
function fints = chunkerkerneval_smooth(chnkr,kern,opdims,dens, ...
targinfo,flag,opts)

if isa(kern,'kernel')
kerneval = kern.eval;
else
kerneval = kern;
end

flam = false;
accel = true;
forcefmm = false;

if nargin < 6
flag = [];
end
if nargin < 7
opts = [];
end
if isfield(opts,'flam'); flam = opts.flam; end
if isfield(opts,'accel'); accel = opts.accel; end
if isfield(opts,'forcefmm'); forcefmm = opts.forcefmm; end

k = chnkr.k;
nch = chnkr.nch;

assert(numel(dens) == opdims(2)*k*nch,'dens not of appropriate size')
dens = reshape(dens,opdims(2),k,nch);

[~,w] = lege.exps(k);
[~,nt] = size(targinfo.r);

fints = zeros(opdims(1)*nt,1);

% assume smooth weights are good enough

% Sequence of checks, first see ifflam is set as it supercedes
% everything, if not flam, then check to see if the fmm
% exists and whether it should be used
% The number of sources set to 200 is currently a hack,
% must be set based on opdims, accuracy, and kernel type
% considerations

imethod = 'direct';
if flam
imethod = 'flam';
elseif isa(kern,'kernel') && ~isempty(kern.fmm)
if forcefmm
imethod = 'fmm';
elseif accel
if nt > 200 || chnkr.npt > 200
imethod = 'fmm';
end
end
end

diamsrc = max(abs(chnkr.r(:)));
diamtarg = max(abs(targinfo.r(:)));
diam = max(diamsrc, diamtarg);

% flag targets that are within 1e-14 of sources
flagslf = chnk.flagself(targinfo.r, chnkr.r, 1e-14*diam);
if isempty(flagslf)
selfzero = sparse(opdims(1)*size(targinfo.r(:,:),2), ...
opdims(2)*chnkr.npt);
else
tmp = repmat((1:opdims(1))',opdims(2),1) + opdims(1)*(flagslf(1,:)-1);
flagslftarg = tmp(:);
tmp = repmat((1:opdims(2)),opdims(1),1);
tmp = tmp(:) + opdims(2)*(flagslf(2,:)-1);
flagslfsrc = tmp(:);

selfzero = sparse(flagslftarg,flagslfsrc, 1e-300, ...
opdims(1)*size(targinfo.r(:,:),2), opdims(2)*chnkr.npt);
end

if strcmpi(imethod,'direct')
% do dense version
if isempty(flag)
% nothing to ignore
for i = 1:nch
densvals = dens(:,:,i); densvals = densvals(:);
dsdtdt = sqrt(sum(abs(chnkr.d(:,:,i)).^2,1));
dsdtdt = dsdtdt(:).*w(:);
dsdtdt = repmat( (dsdtdt(:)).',opdims(2),1);
densvals = densvals.*(dsdtdt(:));
srcinfo = []; srcinfo.r = chnkr.r(:,:,i);
srcinfo.n = chnkr.n(:,:,i);
srcinfo.d = chnkr.d(:,:,i); srcinfo.d2 = chnkr.d2(:,:,i);
kernmat = kerneval(srcinfo,targinfo);

selfzeroch = selfzero(:, opdims(2)*k*(i-1) + (1:opdims(2)*k));
[isp,jsp,~] = find(selfzeroch);
linsp = isp + (jsp-1)*size(selfzeroch,1);
kernmat(linsp) = 0;

fints = fints + kernmat*densvals;
% sum(fints)
end
else
% ignore interactions in flag array
for i = 1:nch
densvals = dens(:,:,i); densvals = densvals(:);
dsdtdt = sqrt(sum(abs(chnkr.d(:,:,i)).^2,1));
dsdtdt = dsdtdt(:).*w(:);
dsdtdt = repmat( (dsdtdt(:)).',opdims(2),1);
densvals = densvals.*(dsdtdt(:));
srcinfo = []; srcinfo.r = chnkr.r(:,:,i);
srcinfo.n = chnkr.n(:,:,i);
srcinfo.d = chnkr.d(:,:,i); srcinfo.d2 = chnkr.d2(:,:,i);
kernmat = kerneval(srcinfo,targinfo);

rowkill = find(flag(:,i));
rowkill = (opdims(1)*(rowkill(:)-1)).' + (1:opdims(1)).';
kernmat(rowkill,:) = 0;

selfzeroch = selfzero(:, opdims(2)*k*(i-1) + (1:opdims(2)*k));
[isp,jsp,~] = find(selfzeroch);
linsp = isp + (jsp-1)*size(selfzeroch,1);
kernmat(linsp) = 0;

fints = fints + kernmat*densvals;
end
end
else

wts = chnkr.wts;
wts = wts(:);

if strcmpi(imethod,'flam')
xflam1 = chnkr.r(:,:);
xflam1 = repmat(xflam1,opdims(2),1);
xflam1 = reshape(xflam1,chnkr.dim,numel(xflam1)/chnkr.dim);

targinfo_flam = [];
targinfo_flam.r = repelem(targinfo.r(:,:),1,opdims(1));
if isfield(targinfo, 'd')
targinfo_flam.d = repelem(targinfo.d(:,:),1,opdims(1));
end

if isfield(targinfo, 'd2')
targinfo_flam.d2 = repelem(targinfo.d2(:,:),1,opdims(1));
end

if isfield(targinfo, 'n')
targinfo_flam.n = repelem(targinfo.n(:,:),1,opdims(1));
end

% TODO: Pull through data?

matfun = @(i,j) chnk.flam.kernbyindexr(i, j, targinfo_flam, ...,
chnkr, kerneval, opdims, selfzero);


width = max(abs(max(chnkr)-min(chnkr)))/3;
tmax = max(targinfo.r(:,:),[],2); tmin = min(targinfo.r(:,:),[],2);
wmax = max(abs(tmax-tmin));
width = max(width,wmax/3);
npxy = chnk.flam.nproxy_square(kerneval,width);
[pr,ptau,pw,pin] = chnk.flam.proxy_square_pts(npxy);

pxyfun = @(rc,rx,cx,slf,nbr,l,ctr) chnk.flam.proxyfunr(rc,rx,slf,nbr,l, ...
ctr,chnkr,kerneval,opdims,pr,ptau,pw,pin);

optsifmm=[]; optsifmm.Tmax=Inf;
F = ifmm(matfun,targinfo_flam.r,xflam1,200,1e-14,pxyfun,optsifmm);
fints = ifmm_mv(F,dens(:),matfun);
else
wts2 = repmat(wts(:).', opdims(2), 1);
sigma = wts2(:).*dens(:);
fints = kern.fmm(1e-14, chnkr, targinfo.r(:,:), sigma);
end
% delete interactions in flag array (possibly unstable approach)


if ~isempty(flag)
for i = 1:nch
densvals = dens(:,:,i); densvals = densvals(:);
dsdtdt = sqrt(sum(abs(chnkr.d(:,:,i)).^2,1));
dsdtdt = dsdtdt(:).*w(:);
dsdtdt = repmat( (dsdtdt(:)).',opdims(2),1);
densvals = densvals.*(dsdtdt(:));
srcinfo = []; srcinfo.r = chnkr.r(:,:,i);
srcinfo.n = chnkr.n(:,:,i);
srcinfo.d = chnkr.d(:,:,i); srcinfo.d2 = chnkr.d2(:,:,i);

delsmooth = find(flag(:,i));
delsmoothrow = (opdims(1)*(delsmooth(:)-1)).' + (1:opdims(1)).';
delsmoothrow = delsmoothrow(:);

targinfo_use = [];
targinfo_use.r = targinfo.r(:,delsmooth);

if isfield(targinfo, 'd')
targinfo_use.d = targinfo.d(:,delsmooth);
end

if isfield(targinfo, 'd2')
targinfo_use.d2 = targinfo.d2(:,delsmooth);
end

if isfield(targinfo, 'n')
targinfo_use.n = targinfo.n(:,delsmooth);
end

kernmat = kerneval(srcinfo,targinfo_use);

selfzeroch = selfzero(:, opdims(2)*k*(i-1) + (1:opdims(2)*k));
[isp,jsp,~] = find(selfzeroch);
linsp = isp + (jsp-1)*length(i(:));
kernmat(linsp) = 0;

fints(delsmoothrow) = fints(delsmoothrow) - kernmat*densvals;
end
end
end
% sum(fints)
end
58 changes: 58 additions & 0 deletions chunkie/+chnk/flagself.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
function flagslf = flagself(srcs, targs, tol)
% identify sources and targets pairs that are within tol (1e-14);

% todo: make this more robust by searching for all pairs that are close
% not just the closest
if nargin < 3
tol = 1e-14;
end

flagslf = [];

randangle = 2*pi*rand();
randrot = [cos(randangle), -sin(randangle); ...
sin(randangle), cos(randangle)];

srcrot = randrot*srcs(:,:);
targrot = randrot*targs(:,:);

[srcsortx, jds] = sort(srcrot(1,:));
[targsortx, ids] = sort(targrot(1,:));

binids = cell(length(srcsortx),1);

idcheck = [1;2];
for j = 1:length(srcsortx)
while (idcheck(2) < length(targsortx) && ...
targsortx(idcheck(2)) < srcsortx(j))
idcheck = idcheck+1;
end
[d, kd] = min(abs(targsortx(idcheck) - srcsortx(j)));
idclose = idcheck(kd);
if d < tol
binids{jds(j)} = [binids{jds(j)}, ids(idclose)];
end
end

[srcsorty, jds] = sort(srcrot(2,:));
[targsorty, ids] = sort(targrot(2,:));

idcheck = [1;2];
for j = 1:length(srcsorty)
while (idcheck(2) < length(targsorty) && ...
targsorty(idcheck(2)) < srcsorty(j))
idcheck = idcheck+1;
end
[d, kd] = min(abs(targsorty(idcheck) - srcsorty(j)));
idclose = idcheck(kd);
if d < tol
if ismember(ids(idclose), binids{jds(j)})
flagslf = [flagslf, [jds(j);ids(idclose)]];
end
end
end
if ~isempty(flagslf)
[~, ids] = sort(flagslf(1,:));
flagslf = flagslf(:,ids);
end
end
Loading

0 comments on commit 459ddc4

Please sign in to comment.