diff --git a/src/utils/multi_echo.jl b/src/utils/multi_echo.jl index 474cb98..1aa7796 100644 --- a/src/utils/multi_echo.jl +++ b/src/utils/multi_echo.jl @@ -4,91 +4,95 @@ """ fit_echo_linear( - phas::AbstractArray{<:AbstractFloat, 4}, - W::AbstractArray{<:AbstractFloat, 4}, - TEs::NTuple{N, Real}; + phas::AbstractArray{<:AbstractFloat, N > 1}, + W::AbstractArray{<:AbstractFloat, N > 1}, + TEs::NTuple{NT > 1, Real}; phase_offset::Bool = true - ) -> Tuple{typeof(similar(phas)){3}, [typeof(similar(phas)){3}]} + ) -> Tuple{typeof(similar(phas)){N-1}, [typeof(similar(phas)){N-1}]} Weighted least squares for multi-echo data. ### Arguments -- `phas::AbstractArray{<:AbstractFloat, 4}`: unwrapped multi-echo phase -- `W::AbstractArray{<:AbstractFloat, 4}`: reciprocal of error variance of voxel -- `TEs::NTuple{N, Real}`: echo times +- `phas::AbstractArray{<:AbstractFloat, N > 1}`: unwrapped multi-echo phase +- `W::AbstractArray{<:AbstractFloat, N > 1}`: reciprocal of error variance of voxel +- `TEs::NTuple{NT > 1, Real}`: echo times ### Keywords - `phase_offset::Bool = true`: model phase offset (`true`) ### Returns -- `typeof(similar(phas)){3}`: weighted least-squares estimate for phase -- [`typeof(similar(phas)){3}`]: weighted least-squares estimate for phase offset +- `typeof(similar(phas)){N-1}`: weighted least-squares estimate for phase +- [`typeof(similar(phas)){N-1}`]: weighted least-squares estimate for phase offset if `phase_offset = true` """ function fit_echo_linear( - phas::AbstractArray{<:AbstractFloat, 4}, - W::AbstractArray{<:AbstractFloat, 4}, - TEs::NTuple{N, Real}; + phas::AbstractArray{<:AbstractFloat, N}, + W::AbstractArray{<:AbstractFloat, N}, + TEs::NTuple{NT, Real}; phase_offset::Bool = true -) where {N} - p = tzero(phas, size(phas)[1:3]) +) where {N, NT} + N > 1 || throw(ArgumentError("array must contain echoes in last dimension")) + NT > 1 || throw(ArgumentError("data must be multi-echo")) + + p = similar(phas, size(phas)[1:N-1]) if !phase_offset return fit_echo_linear!(p, phas, W, TEs) else - return fit_echo_linear!(p, tzero(p), phas, W, TEs) + return fit_echo_linear!(p, similar(p), phas, W, TEs) end end """ fit_echo_linear!( - p::AbstractArray{<:AbstractFloat, 3}, - phas::AbstractArray{<:AbstractFloat, 4}, - W::AbstractArray{<:AbstractFloat, 4}, - TEs::NTuple{N, Real} + p::AbstractArray{<:AbstractFloat, N}, + phas::AbstractArray{<:AbstractFloat, M > 1}, + W::AbstractArray{<:AbstractFloat, M > 1}, + TEs::NTuple{NT > 1, Real} ) -> p Weighted least squares for multi-echo data (phase offset = 0). ### Arguments -- `p::AbstractArray{<:AbstractFloat, 3}`: weighted least-squares estimate for phase -- `phas::AbstractArray{<:AbstractFloat, 4}`: unwrapped multi-echo phase -- `W::AbstractArray{<:AbstractFloat, 4}`: reciprocal of error variance of voxel -- `TEs::NTuple{N, Real}`: echo times +- `p::AbstractArray{<:AbstractFloat, N}`: weighted least-squares estimate for phase +- `phas::AbstractArray{<:AbstractFloat, M > 1}`: unwrapped multi-echo phase +- `W::AbstractArray{<:AbstractFloat, M > 1}`: reciprocal of error variance of voxel +- `TEs::NTuple{NT > 1, Real}`: echo times ### Returns - `p`: weighted least-squares estimate for phase """ function fit_echo_linear!( - p::AbstractArray{Tp, 3}, - phas::AbstractArray{Tphas, 4}, - W::AbstractArray{TW, 4}, + p::AbstractArray{Tp, N}, + phas::AbstractArray{Tphas, M}, + W::AbstractArray{TW, M}, TEs::NTuple{NT, Real} -) where {Tp<:AbstractFloat, Tphas<:AbstractFloat, TW<:Real, NT} - nx, ny, nz, nt = size(phas) +) where {Tp<:AbstractFloat, Tphas<:AbstractFloat, TW<:Real, N, M, NT} + M > 1 || throw(ArgumentError("array must contain echoes in last dimension")) + NT > 1 || throw(ArgumentError("data must be multi-echo")) + + size(phas, M) == NT || throw(DimensionMismatch()) + size(W) == size(phas) || throw(DimensionMismatch()) + length(p) == length(phas) ÷ NT || throw(DimensionMismatch()) - nt == NT || throw(DimensionMismatch()) - size(p) == (nx, ny, nz) || throw(DimensionMismatch()) - size(W) == size(phas) || throw(DimensionMismatch()) + vphas = reshape(phas, :, NT) + vW = reshape(W, :, NT) + vp = vec(p) T = promote_type(Tp, Tphas, TW) tes = convert.(T, TEs) - _zeroT = zero(T) _zeroTp = zero(Tp) - @threads for k in 1:nz - @inbounds for j in 1:ny - for i in 1:nx - num = _zeroT - den = _zeroT - for t in Base.OneTo(NT) - w = W[i,j,k,t] * W[i,j,k,t] * tes[t] - num = muladd(w, phas[i,j,k,t], num) - den = muladd(w, tes[t], den) - end - p[i,j,k] = iszero(den) ? _zeroTp : num * inv(den) - end + @inbounds @batch for I in eachindex(vp) + w = vW[I,1] * vW[I,1] * tes[1] + den = w * tes[1] + num = w * vphas[I,1] + for t in 2:NT + w = vW[I,t] * vW[I,t] * tes[t] + den = muladd(w, tes[t], den) + num = muladd(w, vphas[I,t], num) end + vp[I] = iszero(den) ? _zeroTp : num * inv(den) end return p @@ -96,75 +100,91 @@ end """ fit_echo_linear!( - p::AbstractArray{<:AbstractFloat, 3}, - p0::AbstractArray{<:AbstractFloat, 3}, - phas::AbstractArray{<:AbstractFloat, 4}, - W::AbstractArray{<:AbstractFloat, 4}, - TEs::NTuple{N, Real} + p::AbstractArray{<:AbstractFloat, N}, + p0::AbstractArray{<:AbstractFloat, N}, + phas::AbstractArray{<:AbstractFloat, M > 1}, + W::AbstractArray{<:AbstractFloat, M > 1}, + TEs::NTuple{NT > 1, Real} ) -> (p, p0) Weighted least squares for multi-echo data (estimate phase offset). ### Arguments -- `p::AbstractArray{<:AbstractFloat, 3}`: weighted least-squares estimate for phase -- `p0::AbstractArray{<:AbstractFloat, 3}`: weighted least-squares estimate for phase offset -- `phas::AbstractArray{<:AbstractFloat, 4}`: unwrapped multi-echo phase -- `W::AbstractArray{<:AbstractFloat, 4}`: reciprocal of error variance of voxel -- `TEs::NTuple{N, Real}`: echo times +- `p::AbstractArray{<:AbstractFloat, N}`: weighted least-squares estimate for phase +- `p0::AbstractArray{<:AbstractFloat, N}`: weighted least-squares estimate for phase offset +- `phas::AbstractArray{<:AbstractFloat, M > 1}`: unwrapped multi-echo phase +- `W::AbstractArray{<:AbstractFloat, M > 1}`: reciprocal of error variance of voxel +- `TEs::NTuple{NT > 1, Real}`: echo times ### Returns - `p`: weighted least-squares estimate for phase - `p0`: weighted least-squares estimate for phase offset """ function fit_echo_linear!( - p::AbstractArray{Tp, 3}, - p0::AbstractArray{Tp0, 3}, - phas::AbstractArray{Tphas, 4}, - W::AbstractArray{TW, 4}, + p::AbstractArray{Tp, N}, + p0::AbstractArray{Tp0, N}, + phas::AbstractArray{Tphas, M}, + W::AbstractArray{TW, M}, TEs::NTuple{NT, Real} -) where {Tp<:AbstractFloat, Tp0<:AbstractFloat, Tphas<:AbstractFloat, TW<:Real, NT} - nx, ny, nz, nt = size(phas) +) where {Tp<:AbstractFloat, Tp0<:AbstractFloat, Tphas<:AbstractFloat, TW<:Real, N, M, NT} + M > 1 || throw(ArgumentError("array must contain echoes in last dimension")) + NT > 1 || throw(ArgumentError("data must be multi-echo")) - nt == NT || throw(DimensionMismatch()) - size(p) == (nx, ny, nz) || throw(DimensionMismatch()) - size(p0) == (nx, ny, nz) || throw(DimensionMismatch()) - size(W) == size(phas) || throw(DimensionMismatch()) + size(phas, M) == NT || throw(DimensionMismatch()) + size(W) == size(phas) || throw(DimensionMismatch()) + length(p) == length(phas) ÷ NT || throw(DimensionMismatch()) + length(p0) == length(phas) ÷ NT || throw(DimensionMismatch()) + + vphas = reshape(phas, :, NT) + vW = reshape(W, :, NT) + vp0 = vec(p0) + vp = vec(p) T = promote_type(Tp, Tp0, Tphas, TW) tes = convert.(T, TEs) - _zeroT = zero(T) _zeroTp = zero(Tp) + _zeroTp0 = zero(Tp0) + + @inbounds @batch for I in eachindex(vp) + w = vW[I,1] + x = w * tes[1] + y = w * vphas[I,1] + for t in 2:NT + w += vW[I,t] + x = muladd(vW[I,t], tes[t], x) + y = muladd(vW[I,t], vphas[I,t], y) + end + + if iszero(w) + vp[I] = _zeroTp + vp0[I] = _zeroTp0 + continue + end + + w = inv(w) + x *= w + y *= w + + xx = tes[1] - x + yy = vphas[I,1] - y + ww = vW[I,1] * xx + num = ww * yy + den = ww * xx + for t in 2:NT + xx = tes[t] - x + yy = vphas[I,t] - y + ww = vW[I,t] * xx + num = muladd(ww, yy, num) + den = muladd(ww, xx, den) + end - @threads for k in 1:nz - @inbounds for j in 1:ny - for i in 1:nx - x = _zeroT - y = _zeroT - w = _zeroT - for t in Base.OneTo(NT) - w += W[i,j,k,t] - x = muladd(W[i,j,k,t], tes[t], x) - y = muladd(W[i,j,k,t], phas[i,j,k,t], y) - end - - w = inv(w) - x *= w - y *= w - - num = _zeroT - den = _zeroT - for t in Base.OneTo(NT) - xx = tes[t] - x - yy = phas[i,j,k,t] - y - ww = W[i,j,k,t] * xx - num = muladd(ww, yy, num) - den = muladd(ww, xx, den) - end - - p[i,j,k] = iszero(den) ? _zeroTp : num * inv(den) - p0[i,j,k] = y - p[i,j,k] * x - end + if !iszero(den) + vp[I] = num * inv(den) + vp0[I] = y - vp[I] * x + else + vp[I] = _zeroTp + vp0[I] = _zeroTp0 end end