Skip to content

Commit

Permalink
Add multi-threaded map
Browse files Browse the repository at this point in the history
  • Loading branch information
kamesy committed Aug 26, 2022
1 parent 93399a7 commit eb0291c
Show file tree
Hide file tree
Showing 14 changed files with 91 additions and 85 deletions.
4 changes: 2 additions & 2 deletions src/bgremove/ismv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ function _ismv!(
= one(eltype(s)) - sqrt(eps(eltype(s)))

# erode mask
s = _tcopyto!(s, m) # in-place type conversion, reuse smv var
m0 = _tcopyto!(m0, m)
s = tcopyto!(s, m) # in-place type conversion, reuse smv var
m0 = tcopyto!(m0, m)

= mul!(F̂, P, s)
@inbounds @batch for I in eachindex(F̂)
Expand Down
10 changes: 5 additions & 5 deletions src/bgremove/lbv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,16 @@ function _lbv!(
flc = fl
else
mc = similar(mask, szc)
mc = _tcopyto!(mc, @view(mask[Rc]))
mc = tcopyto!(mc, @view(mask[Rc]))

if N == 3
fc = @view(f[Rc])
flc = similar(fl, szc)
flc = _tcopyto!(flc, @view(fl[Rc]))
flc = tcopyto!(flc, @view(fl[Rc]))
else
fc = @view(f[Rc,:])
flc = similar(fl, (szc..., size(fl, 4)))
flc = _tcopyto!(flc, @view(fl[Rc,:]))
flc = tcopyto!(flc, @view(fl[Rc,:]))
end
end

Expand Down Expand Up @@ -147,9 +147,9 @@ function _lbv!(

if _crop
if N == 3
_tcopyto!(@view(fl[Rc]), flc)
tcopyto!(@view(fl[Rc]), flc)
else
_tcopyto!(@view(fl[Rc,:]), flc)
tcopyto!(@view(fl[Rc,:]), flc)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/bgremove/pdf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ function _pdf!(
# pre-compute mask*weights
if W === nothing
# no weights
MW = _tcopyto!(MW, m)
MW = tcopyto!(MW, m)

elseif ndims(W) == 3
# same weights for all echoes
Expand Down
6 changes: 3 additions & 3 deletions src/bgremove/sharp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function _sharp!(
= one(eltype(s)) - sqrt(eps(eltype(s)))

# erode mask
s = _tcopyto!(s, m) # in-place type conversion, reuse smv var
s = tcopyto!(s, m) # in-place type conversion, reuse smv var

= mul!(F̂, P, s)
@inbounds @batch for I in eachindex(F̂)
Expand Down Expand Up @@ -232,7 +232,7 @@ function _sharp!(
= one(eltype(s)) - sqrt(eps(eltype(s)))

# fft of original mask
s = _tcopyto!(s, mr) # in-place type conversion
s = tcopyto!(s, mr) # in-place type conversion
= mul!(M̂, P, s)

@inbounds for (i, r) in enumerate(rs)
Expand Down Expand Up @@ -344,7 +344,7 @@ function _sharp!(
= one(eltype(s)) - sqrt(eps(eltype(s)))

# fft of original mask
s = _tcopyto!(s, mr) # in-place type conversion
s = tcopyto!(s, mr) # in-place type conversion
= mul!(M̂, P, s)

@inbounds for (i, r) in enumerate(rs)
Expand Down
2 changes: 1 addition & 1 deletion src/inversion/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ function _kdiv_ikernel!(

elseif reg == :laplacian
Γ = _laplace_kernel!(Γ, F, f, vsz, P)
Γ = _tcopyto!(abs2, Γ, Γ)
Γ = tmap!(abs2, Γ)
end

@inbounds @batch for I in eachindex(D)
Expand Down
2 changes: 1 addition & 1 deletion src/inversion/nltv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ function _nltv!(
end

if W !== nothing
= _tcopyto!(F̂, X̂) # real ifft overwrites input
= tcopyto!(F̂, X̂) # real ifft overwrites input
end

xp = mul!(xp, iP, X̂)
Expand Down
2 changes: 1 addition & 1 deletion src/inversion/tv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ function _tv!(
end

if W !== nothing
= _tcopyto!(F̂, X̂) # real ifft overwrites input
= tcopyto!(F̂, X̂) # real ifft overwrites input
end

xp = mul!(xp, iP, X̂)
Expand Down
14 changes: 7 additions & 7 deletions src/utils/kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ function _dipole_kernel!(
d = _dipole_kernel!(d, sz, vsz, bdir, :i, shift=true)

= mul!(D̂, P, d)
D = _tcopyto!(real, D, D̂)
D = tmap!(real, D, D̂)

return D
end
Expand All @@ -152,7 +152,7 @@ function _dipole_kernel!(
= _dipole_kernel!(D̂, sz, vsz, bdir, :i, shift=true)

= P*
D = _tcopyto!(real, D, D̂)
D = tmap!(real, D, D̂)

return D
end
Expand Down Expand Up @@ -321,7 +321,7 @@ function _smv_kernel!(

# normalizing
a = inv(sum(s))
s = _tcopyto!(x -> a*x, s, s)
s = tmap!(x -> a*x, s)

return s
end
Expand All @@ -341,7 +341,7 @@ function _smv_kernel!(

# fft, discard imaginary (even function -> imag = 0), and normalize
= mul!(Ŝ, P, s)
S = _tcopyto!(x -> a*real(x), S, Ŝ)
S = tmap!(x -> a*real(x), S, Ŝ)

return S
end
Expand All @@ -360,7 +360,7 @@ function _smv_kernel!(

# fft, discard imaginary (even function -> imag = 0), and normalize
= P*
S = _tcopyto!(x -> a*real(x), S, Ŝ)
S = tmap!(x -> a*real(x), S, Ŝ)

return S
end
Expand Down Expand Up @@ -498,7 +498,7 @@ function _laplace_kernel!(
Δ = _laplace_kernel!(Δ, vsz, negative=negative, shift=true)

= mul!(L̂, P, Δ)
L = _tcopyto!(real, L, L̂)
L = tmap!(real, L, L̂)

return L
end
Expand All @@ -513,7 +513,7 @@ function _laplace_kernel!(
= _laplace_kernel!(L̂, vsz, negative=negative, shift=true)

= P*
L = _tcopyto!(real, L, L̂)
L = tmap!(real, L, L̂)

return L
end
Expand Down
22 changes: 4 additions & 18 deletions src/utils/lsmr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end
function LSMRWorkspace(x::AbstractArray, A, b::AbstractArray)
T = typeof(one(eltype(b)) / one(eltype(A)))

u = isa(b, Array) ? tcopy(b) : copy(b)
u = tcopy(b)

v = similar(x, T)
h = similar(x, T)
Expand Down Expand Up @@ -163,23 +163,9 @@ function lsmr!(
cbar = one(Tr)
sbar = zero(Tr)

if isa(x, Array)
tfill!(x, 0)
else
fill!(x, 0)
end

if isa(h, Array) && isa(v, Array)
_tcopyto!(h, v)
else
copyto!(h, v)
end

if isa(hbar, Array)
tfill!(hbar, 0)
else
fill!(hbar, 0)
end
tfill!(x, 0)
tfill!(hbar, 0)
tcopyto!(h, v)

# Initialize variables for estimation of ||r||.
βdd = β
Expand Down
10 changes: 5 additions & 5 deletions src/utils/poisson_solver/mgpcg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ function mgpcg!(
maxlevel = maxlevel
)

q = tzero(b, size(A))
p = similar(b, size(A))
q = similar(b, size(A))

M.workspace.x[1] = q
M.workspace.x[1] = tfill!(q, zero(T))

if N == 3
M.workspace.b[1] = tcopy(b)
Expand All @@ -76,12 +76,12 @@ function mgpcg!(
M.workspace.b[1] = similar(b, size(A))

for t in axes(b, 4)
_tcopyto!(xt, @view(x[:,:,:,t]))
_tcopyto!(M.workspace.b[1], @view(b[:,:,:,t]))
tcopyto!(xt, @view(x[:,:,:,t]))
tcopyto!(M.workspace.b[1], @view(b[:,:,:,t]))

xt = _mgpcg!(xt, M, p, cycle, ncycles, atol, rtol, maxit, verbose)

_tcopyto!(@view(x[:,:,:,t]), xt)
tcopyto!(@view(x[:,:,:,t]), xt)
end
end

Expand Down
3 changes: 2 additions & 1 deletion src/utils/poisson_solver/multigrid/transfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ end
function restrict(interior::AbstractArray{Bool, 3})
sz = size(interior)
szc = restrict_size(sz)
return restrict!(tzero(interior, szc), interior)
interiorc = tfill!(similar(interior, szc), zero(Bool))
return restrict!(interiorc, interior)
end

function restrict!(mc::AbstractArray{Bool, 3}, m::AbstractArray{Bool, 3})
Expand Down
14 changes: 7 additions & 7 deletions src/utils/poisson_solver/poisson_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ function solve_poisson_mgpcg!(
uc = u
else
mc = similar(mask, szc)
mc = _tcopyto!(mc, @view(mask[Rc]))
mc = tcopyto!(mc, @view(mask[Rc]))

if N == 3
d2uc = @view(d2u[Rc])
uc = similar(u, szc)
uc = _tcopyto!(uc, @view(u[Rc]))
uc = tcopyto!(uc, @view(u[Rc]))
else
d2uc = @view(d2u[Rc,:])
uc = similar(u, (szc..., size(u, 4)))
uc = _tcopyto!(uc, @view(u[Rc,:]))
uc = tcopyto!(uc, @view(u[Rc,:]))
end
end

Expand All @@ -66,9 +66,9 @@ function solve_poisson_mgpcg!(

if _crop
if N == 3
_tcopyto!(@view(u[Rc]), uc)
tcopyto!(@view(u[Rc]), uc)
else
_tcopyto!(@view(u[Rc,:]), uc)
tcopyto!(@view(u[Rc,:]), uc)
end
end

Expand Down Expand Up @@ -160,7 +160,7 @@ function solve_poisson_fft!(
iP = inv(P)
d2û = P*d2u
else
d2û = _tcopyto!(similar(d2u, complex(eltype(d2u))), d2u)
d2û = tcopyto!(similar(d2u, complex(eltype(d2u))), d2u)
P = plan_fft!(d2û, 1:3)
iP = inv(P)
d2û = P*d2û
Expand Down Expand Up @@ -188,7 +188,7 @@ function solve_poisson_fft!(
u = mul!(u, iP, d2û)
else
d2û = iP*d2û
u = _tcopyto!(real, u, d2û)
u = tmap!(real, u, d2û)
end

return u
Expand Down
18 changes: 10 additions & 8 deletions src/utils/r2star.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,17 @@ function r2star_ll(
size(mag, N) == NT || throw(DimensionMismatch())
mask === nothing || length(mask) == length(mag) ÷ NT || throw(DimensionMismatch())

r2s = tzero(mag, size(mag)[1:N-1])
r2s = similar(mag, size(mag)[1:N-1])
r2s = tfill!(r2s, zero(T))

vmag = reshape(mag, :, NT)
vr2s = vec(r2s)

A = qr(Matrix{T}([-[TEs...] ones(NT)]))

if mask === nothing
b = tcopy(log, transpose(vmag))

b = tmap(log, transpose(vmag))
= ldiv!(A, b)

@inbounds @batch for I in eachindex(vr2s)
vr2s[I] = x̂[1,I]
end
Expand Down Expand Up @@ -110,7 +109,8 @@ function r2star_arlo(
all(()(TEs[2]-TEs[1]), TEs[2:end].-TEs[1:end-1]) ||
throw(DomainError("ARLO requires equidistant echoes"))

r2s = tzero(mag, size(mag)[1:N-1])
r2s = similar(mag, size(mag)[1:N-1])
r2s = tfill!(r2s, zero(T))

vmag = reshape(mag, :, NT)
vr2s = vec(r2s)
Expand Down Expand Up @@ -238,8 +238,9 @@ function r2star_crsi(
sigma === nothing || NR == N-1 || throw(DimensionMismatch())
M > 0 || throw(ArgumentError("interpolation factor M must be greater than 0"))

P = tcopy(x -> x*x, mag)
r2s = tzero(mag, size(mag)[1:N-1])
P = tmap(x -> x*x, mag)
r2s = similar(mag, size(mag)[1:N-1])
r2s = tfill!(r2s, zero(T))

vP = reshape(P, :, NT)
vr2s = vec(r2s)
Expand Down Expand Up @@ -376,7 +377,8 @@ function r2star_numart2s(
size(mag, N) == NT || throw(DimensionMismatch())
mask === nothing || length(mask) == length(mag) ÷ NT || throw(DimensionMismatch())

r2s = tzero(mag, size(mag)[1:N-1])
r2s = similar(mag, size(mag)[1:N-1])
r2s = tfill!(r2s, zero(T))

vmag = reshape(mag, :, NT)
vr2s = vec(r2s)
Expand Down
Loading

0 comments on commit eb0291c

Please sign in to comment.