Skip to content

Commit

Permalink
Add general Nd support
Browse files Browse the repository at this point in the history
  • Loading branch information
kamesy committed Apr 14, 2022
1 parent ac8e978 commit 80dd2b2
Showing 1 changed file with 114 additions and 94 deletions.
208 changes: 114 additions & 94 deletions src/utils/multi_echo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,167 +4,187 @@

"""
fit_echo_linear(
phas::AbstractArray{<:AbstractFloat, 4},
W::AbstractArray{<:AbstractFloat, 4},
TEs::NTuple{N, Real};
phas::AbstractArray{<:AbstractFloat, N > 1},
W::AbstractArray{<:AbstractFloat, N > 1},
TEs::NTuple{NT > 1, Real};
phase_offset::Bool = true
) -> Tuple{typeof(similar(phas)){3}, [typeof(similar(phas)){3}]}
) -> Tuple{typeof(similar(phas)){N-1}, [typeof(similar(phas)){N-1}]}
Weighted least squares for multi-echo data.
### Arguments
- `phas::AbstractArray{<:AbstractFloat, 4}`: unwrapped multi-echo phase
- `W::AbstractArray{<:AbstractFloat, 4}`: reciprocal of error variance of voxel
- `TEs::NTuple{N, Real}`: echo times
- `phas::AbstractArray{<:AbstractFloat, N > 1}`: unwrapped multi-echo phase
- `W::AbstractArray{<:AbstractFloat, N > 1}`: reciprocal of error variance of voxel
- `TEs::NTuple{NT > 1, Real}`: echo times
### Keywords
- `phase_offset::Bool = true`: model phase offset (`true`)
### Returns
- `typeof(similar(phas)){3}`: weighted least-squares estimate for phase
- [`typeof(similar(phas)){3}`]: weighted least-squares estimate for phase offset
- `typeof(similar(phas)){N-1}`: weighted least-squares estimate for phase
- [`typeof(similar(phas)){N-1}`]: weighted least-squares estimate for phase offset
if `phase_offset = true`
"""
function fit_echo_linear(
phas::AbstractArray{<:AbstractFloat, 4},
W::AbstractArray{<:AbstractFloat, 4},
TEs::NTuple{N, Real};
phas::AbstractArray{<:AbstractFloat, N},
W::AbstractArray{<:AbstractFloat, N},
TEs::NTuple{NT, Real};
phase_offset::Bool = true
) where {N}
p = tzero(phas, size(phas)[1:3])
) where {N, NT}
N > 1 || throw(ArgumentError("array must contain echoes in last dimension"))
NT > 1 || throw(ArgumentError("data must be multi-echo"))

p = similar(phas, size(phas)[1:N-1])
if !phase_offset
return fit_echo_linear!(p, phas, W, TEs)
else
return fit_echo_linear!(p, tzero(p), phas, W, TEs)
return fit_echo_linear!(p, similar(p), phas, W, TEs)
end
end

"""
fit_echo_linear!(
p::AbstractArray{<:AbstractFloat, 3},
phas::AbstractArray{<:AbstractFloat, 4},
W::AbstractArray{<:AbstractFloat, 4},
TEs::NTuple{N, Real}
p::AbstractArray{<:AbstractFloat, N},
phas::AbstractArray{<:AbstractFloat, M > 1},
W::AbstractArray{<:AbstractFloat, M > 1},
TEs::NTuple{NT > 1, Real}
) -> p
Weighted least squares for multi-echo data (phase offset = 0).
### Arguments
- `p::AbstractArray{<:AbstractFloat, 3}`: weighted least-squares estimate for phase
- `phas::AbstractArray{<:AbstractFloat, 4}`: unwrapped multi-echo phase
- `W::AbstractArray{<:AbstractFloat, 4}`: reciprocal of error variance of voxel
- `TEs::NTuple{N, Real}`: echo times
- `p::AbstractArray{<:AbstractFloat, N}`: weighted least-squares estimate for phase
- `phas::AbstractArray{<:AbstractFloat, M > 1}`: unwrapped multi-echo phase
- `W::AbstractArray{<:AbstractFloat, M > 1}`: reciprocal of error variance of voxel
- `TEs::NTuple{NT > 1, Real}`: echo times
### Returns
- `p`: weighted least-squares estimate for phase
"""
function fit_echo_linear!(
p::AbstractArray{Tp, 3},
phas::AbstractArray{Tphas, 4},
W::AbstractArray{TW, 4},
p::AbstractArray{Tp, N},
phas::AbstractArray{Tphas, M},
W::AbstractArray{TW, M},
TEs::NTuple{NT, Real}
) where {Tp<:AbstractFloat, Tphas<:AbstractFloat, TW<:Real, NT}
nx, ny, nz, nt = size(phas)
) where {Tp<:AbstractFloat, Tphas<:AbstractFloat, TW<:Real, N, M, NT}
M > 1 || throw(ArgumentError("array must contain echoes in last dimension"))
NT > 1 || throw(ArgumentError("data must be multi-echo"))

size(phas, M) == NT || throw(DimensionMismatch())
size(W) == size(phas) || throw(DimensionMismatch())
length(p) == length(phas) ÷ NT || throw(DimensionMismatch())

nt == NT || throw(DimensionMismatch())
size(p) == (nx, ny, nz) || throw(DimensionMismatch())
size(W) == size(phas) || throw(DimensionMismatch())
vphas = reshape(phas, :, NT)
vW = reshape(W, :, NT)
vp = vec(p)

T = promote_type(Tp, Tphas, TW)
tes = convert.(T, TEs)

_zeroT = zero(T)
_zeroTp = zero(Tp)

@threads for k in 1:nz
@inbounds for j in 1:ny
for i in 1:nx
num = _zeroT
den = _zeroT
for t in Base.OneTo(NT)
w = W[i,j,k,t] * W[i,j,k,t] * tes[t]
num = muladd(w, phas[i,j,k,t], num)
den = muladd(w, tes[t], den)
end
p[i,j,k] = iszero(den) ? _zeroTp : num * inv(den)
end
@inbounds @batch for I in eachindex(vp)
w = vW[I,1] * vW[I,1] * tes[1]
den = w * tes[1]
num = w * vphas[I,1]
for t in 2:NT
w = vW[I,t] * vW[I,t] * tes[t]
den = muladd(w, tes[t], den)
num = muladd(w, vphas[I,t], num)
end
vp[I] = iszero(den) ? _zeroTp : num * inv(den)
end

return p
end

"""
fit_echo_linear!(
p::AbstractArray{<:AbstractFloat, 3},
p0::AbstractArray{<:AbstractFloat, 3},
phas::AbstractArray{<:AbstractFloat, 4},
W::AbstractArray{<:AbstractFloat, 4},
TEs::NTuple{N, Real}
p::AbstractArray{<:AbstractFloat, N},
p0::AbstractArray{<:AbstractFloat, N},
phas::AbstractArray{<:AbstractFloat, M > 1},
W::AbstractArray{<:AbstractFloat, M > 1},
TEs::NTuple{NT > 1, Real}
) -> (p, p0)
Weighted least squares for multi-echo data (estimate phase offset).
### Arguments
- `p::AbstractArray{<:AbstractFloat, 3}`: weighted least-squares estimate for phase
- `p0::AbstractArray{<:AbstractFloat, 3}`: weighted least-squares estimate for phase offset
- `phas::AbstractArray{<:AbstractFloat, 4}`: unwrapped multi-echo phase
- `W::AbstractArray{<:AbstractFloat, 4}`: reciprocal of error variance of voxel
- `TEs::NTuple{N, Real}`: echo times
- `p::AbstractArray{<:AbstractFloat, N}`: weighted least-squares estimate for phase
- `p0::AbstractArray{<:AbstractFloat, N}`: weighted least-squares estimate for phase offset
- `phas::AbstractArray{<:AbstractFloat, M > 1}`: unwrapped multi-echo phase
- `W::AbstractArray{<:AbstractFloat, M > 1}`: reciprocal of error variance of voxel
- `TEs::NTuple{NT > 1, Real}`: echo times
### Returns
- `p`: weighted least-squares estimate for phase
- `p0`: weighted least-squares estimate for phase offset
"""
function fit_echo_linear!(
p::AbstractArray{Tp, 3},
p0::AbstractArray{Tp0, 3},
phas::AbstractArray{Tphas, 4},
W::AbstractArray{TW, 4},
p::AbstractArray{Tp, N},
p0::AbstractArray{Tp0, N},
phas::AbstractArray{Tphas, M},
W::AbstractArray{TW, M},
TEs::NTuple{NT, Real}
) where {Tp<:AbstractFloat, Tp0<:AbstractFloat, Tphas<:AbstractFloat, TW<:Real, NT}
nx, ny, nz, nt = size(phas)
) where {Tp<:AbstractFloat, Tp0<:AbstractFloat, Tphas<:AbstractFloat, TW<:Real, N, M, NT}
M > 1 || throw(ArgumentError("array must contain echoes in last dimension"))
NT > 1 || throw(ArgumentError("data must be multi-echo"))

nt == NT || throw(DimensionMismatch())
size(p) == (nx, ny, nz) || throw(DimensionMismatch())
size(p0) == (nx, ny, nz) || throw(DimensionMismatch())
size(W) == size(phas) || throw(DimensionMismatch())
size(phas, M) == NT || throw(DimensionMismatch())
size(W) == size(phas) || throw(DimensionMismatch())
length(p) == length(phas) ÷ NT || throw(DimensionMismatch())
length(p0) == length(phas) ÷ NT || throw(DimensionMismatch())

vphas = reshape(phas, :, NT)
vW = reshape(W, :, NT)
vp0 = vec(p0)
vp = vec(p)

T = promote_type(Tp, Tp0, Tphas, TW)
tes = convert.(T, TEs)

_zeroT = zero(T)
_zeroTp = zero(Tp)
_zeroTp0 = zero(Tp0)

@inbounds @batch for I in eachindex(vp)
w = vW[I,1]
x = w * tes[1]
y = w * vphas[I,1]
for t in 2:NT
w += vW[I,t]
x = muladd(vW[I,t], tes[t], x)
y = muladd(vW[I,t], vphas[I,t], y)
end

if iszero(w)
vp[I] = _zeroTp
vp0[I] = _zeroTp0
continue
end

w = inv(w)
x *= w
y *= w

xx = tes[1] - x
yy = vphas[I,1] - y
ww = vW[I,1] * xx
num = ww * yy
den = ww * xx
for t in 2:NT
xx = tes[t] - x
yy = vphas[I,t] - y
ww = vW[I,t] * xx
num = muladd(ww, yy, num)
den = muladd(ww, xx, den)
end

@threads for k in 1:nz
@inbounds for j in 1:ny
for i in 1:nx
x = _zeroT
y = _zeroT
w = _zeroT
for t in Base.OneTo(NT)
w += W[i,j,k,t]
x = muladd(W[i,j,k,t], tes[t], x)
y = muladd(W[i,j,k,t], phas[i,j,k,t], y)
end

w = inv(w)
x *= w
y *= w

num = _zeroT
den = _zeroT
for t in Base.OneTo(NT)
xx = tes[t] - x
yy = phas[i,j,k,t] - y
ww = W[i,j,k,t] * xx
num = muladd(ww, yy, num)
den = muladd(ww, xx, den)
end

p[i,j,k] = iszero(den) ? _zeroTp : num * inv(den)
p0[i,j,k] = y - p[i,j,k] * x
end
if !iszero(den)
vp[I] = num * inv(den)
vp0[I] = y - vp[I] * x
else
vp[I] = _zeroTp
vp0[I] = _zeroTp0
end
end

Expand Down

0 comments on commit 80dd2b2

Please sign in to comment.