Skip to content

Commit

Permalink
Merge pull request #38 from MagneticParticleImaging/nh/rowNormScaling
Browse files Browse the repository at this point in the history
Improve Weighting and add row norm weighting
  • Loading branch information
nHackel authored Jul 30, 2024
2 parents 9a40108 + 5b813e9 commit 7ff3231
Show file tree
Hide file tree
Showing 12 changed files with 205 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module MPIRecoKernelAbstractionsExt

using MPIReco, MPIReco.Adapt, MPIReco.LinearAlgebra, MPIReco.RegularizedLeastSquares
using MPIReco, MPIReco.Adapt, MPIReco.LinearAlgebra, MPIReco.RegularizedLeastSquares, MPIReco.LinearOperatorCollection
using KernelAbstractions, GPUArrays
using KernelAbstractions.Extras: @unroll
using Atomix
Expand Down
80 changes: 77 additions & 3 deletions ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ function Adapt.adapt_structure(::Type{arrT}, op::MultiPatchOperator) where {arrT

# Ideally we create a DenseMultiPatchOperator on the GPU
if validSMs && validXCC && validXSS
S = adapt(arrT, stack(op.S))
S = stack(adapt.(arrT, op.S))
# We want to use Int32 for better GPU performance
xcc = Int32.(adapt(arrT, stack(op.xcc)))
xss = Int32.(adapt(arrT, stack(op.xss)))
xcc = Int32.(stack(adapt.(arrT, op.xcc)))
xss = Int32.(stack(adapt.(arrT, op.xss)))
sign = Int32.(adapt(arrT, op.sign))
RowToPatch = Int32.(adapt(arrT, op.RowToPatch))
patchToSMIdx = Int32.(adapt(arrT, op.patchToSMIdx))
Expand Down Expand Up @@ -140,4 +140,78 @@ function RegularizedLeastSquares.kaczmarz_update!(op::DenseMultiPatchOperator{T,
kernel(x, op.S, row, beta, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = size(op.xss, 1))
synchronize(backend)
return x
end

function RegularizedLeastSquares.normalize(::SystemMatrixBasedNormalization, op::OP, x) where {T, V <: AbstractGPUArray{T}, OP <: DenseMultiPatchOperator{T, V}}
weights = one(real(eltype(op)))
energy = normalize_dense_op(op, weights)
return norm(energy)^2/size(op, 2)
end

function RegularizedLeastSquares.normalize(::SystemMatrixBasedNormalization, prod::ProdOp{T, <:WeightingOp, OP}, x) where {T, V <: AbstractGPUArray{T}, OP <: DenseMultiPatchOperator{T, V}}
op = prod.B
weights = prod.A.weights
energy = normalize_dense_op(op, weights)
return norm(energy)^2/size(prod, 2)
end

function normalize_dense_op(op::DenseMultiPatchOperator{T, V}, weights) where {T, V <: AbstractGPUArray{T}}
backend = get_backend(op.S)
kernel = normalize_kernel!(backend, 256)
energy = KernelAbstractions.zeros(backend, real(eltype(op)), size(op, 1))
kernel(energy, weights, op.S, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (256, size(op, 1)))
synchronize(backend)
return energy
end

# The normalization kernels are structured the same as the mul!-kernel. The multiplication with x is replaced by abs2 for the rownorm²
@kernel cpu = false inbounds = true function normalize_kernel!(energy, weights, @Const(S), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
# Each group/block handles a single row of the operator
operator_row = @index(Group, Linear) # k
patch = RowToPatch[operator_row] # p
patch_row = mod1(operator_row, M) # j
smIdx = patchToSMIdx[patch]
sign = eltype(energy)(signs[patch_row, smIdx])
grid_stride = prod(@groupsize())
N = Int32(size(xss, 1))

localIdx = @index(Local, Linear)
shared = @localmem eltype(energy) grid_stride
shared[localIdx] = zero(eltype(energy))

@unroll for i = localIdx:grid_stride:N
shared[localIdx] = shared[localIdx] + abs2(sign * S[patch_row, xss[i, patch], smIdx])
end
@synchronize

@private s = div(min(grid_stride, N), Int32(2))
while s > Int32(0)
if localIdx <= s
shared[localIdx] = shared[localIdx] + shared[localIdx + s]
end
s >>= 1
@synchronize
end

if localIdx == 1
energy[operator_row] = sqrt(get_kernel_weights(weights, operator_row)^2 * shared[localIdx])
end
end

@inline get_kernel_weights(weights::AbstractArray, operator_row) = weights[operator_row]
@inline get_kernel_weights(weights::Number, operator_row) = weights

function Base.hash(op::DenseMultiPatchOperator{T, V}, h::UInt64) where {T, V <: AbstractGPUArray{T}}
@warn "Hashing of GPU DenseMultiPatchOperator is inefficient"
h = hash(typeof(op), h)
h = @allowscalar hash(op.S, h)
h = hash(op.grid, h)
h = hash(op.N, h)
h = hash(op.M, h)
h = @allowscalar hash(op.RowToPatch, h)
h = @allowscalar hash(op.xcc, h)
h = @allowscalar hash(op.xss, h)
h = @allowscalar hash(op.sign, h)
h = hash(op.nPatches, h)
h = @allowscalar hash(op.patchToSMIdx, h)
end
35 changes: 25 additions & 10 deletions src/Algorithms/MultiPatchAlgorithms/MultiPatchAlgorithm.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export MultiPatchReconstructionAlgorithm, MultiPatchReconstructionParameter
Base.@kwdef struct MultiPatchReconstructionParameter{matT <: AbstractArray,F<:AbstractFrequencyFilterParameter,O<:AbstractMultiPatchOperatorParameter, S<:AbstractSolverParameters, FF<:AbstractFocusFieldPositions, FFSF<:AbstractFocusFieldPositions, R <: AbstractRegularization} <: AbstractMultiPatchReconstructionParameters
arrayType::Type{matT} = Array
Base.@kwdef struct MultiPatchReconstructionParameter{arrT <: AbstractArray,F<:AbstractFrequencyFilterParameter,O<:AbstractMultiPatchOperatorParameter, S<:AbstractSolverParameters, FF<:AbstractFocusFieldPositions, FFSF<:AbstractFocusFieldPositions, R <: AbstractRegularization, W<:AbstractWeightingParameters} <: AbstractMultiPatchReconstructionParameters
arrayType::Type{arrT} = Array
# File
sf::MultiMPIFile
freqFilter::F
Expand All @@ -9,15 +9,16 @@ Base.@kwdef struct MultiPatchReconstructionParameter{matT <: AbstractArray,F<:Ab
ffPosSF::FFSF = DefaultFocusFieldPositions()
solverParams::S
reg::Vector{R} = AbstractRegularization[]
# weightingType::WeightingType = WeightingType.None
weightingParams::Union{W, ProcessResultCache{W}} = NoWeightingParameters()
end

Base.@kwdef mutable struct MultiPatchReconstructionAlgorithm{P, matT <: AbstractArray} <: AbstractMultiPatchReconstructionAlgorithm where {P<:AbstractMultiPatchAlgorithmParameters}
Base.@kwdef mutable struct MultiPatchReconstructionAlgorithm{P, arrT <: AbstractArray, vecT <: AbstractArray} <: AbstractMultiPatchReconstructionAlgorithm where {P<:AbstractMultiPatchAlgorithmParameters}
params::P
# Could also do reconstruction progress meter here
opParams::Union{AbstractMultiPatchOperatorParameter, ProcessResultCache{<:AbstractMultiPatchOperatorParameter},Nothing} = nothing
sf::MultiMPIFile
arrayType::Type{matT}
weights::Union{Nothing, vecT} = nothing
arrayType::Type{arrT}
ffOp::Union{Nothing, AbstractMultiPatchOperator}
ffPos::Union{Nothing,AbstractArray}
ffPosSF::Union{Nothing,AbstractArray}
Expand All @@ -40,7 +41,7 @@ function MultiPatchReconstructionAlgorithm(params::MultiPatchParameters{<:Abstra
ffPosSF = [vec(ffPos(SF))[l] for l=1:L, SF in reco.sf]
end

return MultiPatchReconstructionAlgorithm(params, reco.opParams, reco.sf, reco.arrayType, nothing, ffPos_, ffPosSF, freqs, Channel{Any}(Inf))
return MultiPatchReconstructionAlgorithm{typeof(params), reco.arrayType, typeof(reco.arrayType{Float32}(undef, 0))}(params, reco.opParams, reco.sf, nothing, reco.arrayType, nothing, ffPos_, ffPosSF, freqs, Channel{Any}(Inf))
end
recoAlgorithmTypes(::Type{MultiPatchReconstruction}) = SystemMatrixBasedAlgorithm()
AbstractImageReconstruction.parameter(algo::MultiPatchReconstructionAlgorithm) = algo.origParam
Expand All @@ -50,7 +51,7 @@ AbstractImageReconstruction.take!(algo::MultiPatchReconstructionAlgorithm) = Bas
function AbstractImageReconstruction.put!(algo::MultiPatchReconstructionAlgorithm, data::MPIFile)
#consistenceCheck(algo.sf, data)

algo.ffOp = process(algo, algo.opParams, data, algo.freqs)
algo.ffOp, algo.weights = process(algo, algo.opParams, data, algo.freqs, algo.params.reco.weightingParams)

result = process(algo, algo.params, data, algo.freqs)

Expand All @@ -65,7 +66,7 @@ function AbstractImageReconstruction.put!(algo::MultiPatchReconstructionAlgorith
Base.put!(algo.output, result)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{OP, ProcessResultCache{OP}}, f::MPIFile, frequencies::Vector{CartesianIndex{2}}) where OP <: AbstractMultiPatchOperatorParameter
function process(algo::MultiPatchReconstructionAlgorithm, params::Union{OP, ProcessResultCache{OP}}, f::MPIFile, frequencies::Vector{CartesianIndex{2}}, weightingParams) where OP <: AbstractMultiPatchOperatorParameter
ffPos_ = ffPos(f)
periodsSortedbyFFPos = unflattenOffsetFieldShift(ffPos_)
idxFirstPeriod = getindex.(periodsSortedbyFFPos,1)
Expand All @@ -77,7 +78,11 @@ function process(algo::MultiPatchReconstructionAlgorithm, params::Union{OP, Proc
end

result = process(typeof(algo), params, algo.sf, frequencies, gradient, ffPos_, algo.ffPosSF)
return adapt(algo.arrayType, result)
# Kinda of hacky. MultiPatch parameters don't map nicely to the SinglePatch inspired pre, reco, post structure
# Have to create weights before ffop is (potentially) moved to GPU, as GPU arrays don't have efficient hash implementations
# Which makes this process expensive to cache
weights = process(typeof(algo), weightingParams, frequencies, result, nothing, algo.arrayType)
return adapt(algo.arrayType, result), weights
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{A, ProcessResultCache{<:A}}, f::MPIFile, args...) where A <: AbstractMPIPreProcessingParameters
Expand Down Expand Up @@ -121,9 +126,19 @@ function process(::Type{<:MultiPatchReconstructionAlgorithm}, params::ExternalPr
end

function process(algo::MultiPatchReconstructionAlgorithm, params::MultiPatchReconstructionParameter, u::AbstractArray)
solver = LeastSquaresParameters(S = algo.ffOp, reg = params.reg, solverParams = params.solverParams)
weights = process(algo, params.weightingParams, u, WeightingType(params.weightingParams))

solver = LeastSquaresParameters(S = algo.ffOp, reg = params.reg, solverParams = params.solverParams, weights = weights)

result = process(algo, solver, u)

return gridresult(result, algo.ffOp.grid, algo.sf)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u, ::MeasurementBasedWeighting) where W<:AbstractWeightingParameters
return process(typeof(algo), params, algo.freqs, algo.ffOp, u, algo.arrayType)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u, ::SystemMatrixBasedWeighting) where W<:AbstractWeightingParameters
return algo.weights
end
25 changes: 14 additions & 11 deletions src/Algorithms/MultiPatchAlgorithms/MultiPatchPeriodicMotion.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export PeriodicMotionPreProcessing, PeriodicMotionReconstructionParameter
Base.@kwdef struct PeriodicMotionPreProcessing{BG<:AbstractMPIBackgroundCorrectionParameters} <: AbstractMPIPreProcessingParameters{BG}
Base.@kwdef struct PeriodicMotionPreProcessing{BG<:AbstractMPIBackgroundCorrectionParameters, W <: AbstractWeightingParameters} <: AbstractMPIPreProcessingParameters{BG}
# Periodic Motion
frames::Union{Nothing, UnitRange{Int64}, Vector{Int64}} = nothing
alpha::Float64 = 3.0
Expand All @@ -10,21 +10,21 @@ Base.@kwdef struct PeriodicMotionPreProcessing{BG<:AbstractMPIBackgroundCorrecti
sf::MultiMPIFile
tfCorrection::Bool = false
bgParams::BG = NoBackgroundCorrectionParameters()
# weightingType::WeightingType = WeightingType.None
weightingParams::Union{W, ProcessResultCache{W}} = NoWeightingParameters()
end

Base.@kwdef struct PeriodicMotionReconstructionParameter{F<:AbstractFrequencyFilterParameter, S<:AbstractSolverParameters} <: AbstractMultiPatchReconstructionParameters
Base.@kwdef struct PeriodicMotionReconstructionParameter{F<:AbstractFrequencyFilterParameter, S<:AbstractSolverParameters, R <: AbstractRegularization, arrT <: AbstractArray} <: AbstractMultiPatchReconstructionParameters
sf::MultiMPIFile
freqFilter::F
solverParams::S
λ::Float32
# weightingType::WeightingType = WeightingType.None
reg::Vector{R} = AbstractRegularization[]
arrayType::Type{arrT} = Array
end

function MultiPatchReconstructionAlgorithm(params::MultiPatchParameters{<:PeriodicMotionPreProcessing,<:PeriodicMotionReconstructionParameter,<:AbstractMPIPostProcessingParameters})
reco = params.reco
freqs = process(MultiPatchReconstructionAlgorithm, reco.freqFilter, reco.sf)
return MultiPatchReconstructionAlgorithm(params, nothing, reco.sf, Array, nothing, nothing, nothing, freqs, Channel{Any}(Inf))
return MultiPatchReconstructionAlgorithm{typeof(params), reco.arrayType, typeof(reco.arrayType{Float32}(undef, 0))}(params, nothing, reco.sf, nothing, reco.arrayType, nothing, nothing, nothing, freqs, Channel{Any}(Inf))
end

function AbstractImageReconstruction.put!(algo::MultiPatchReconstructionAlgorithm{MultiPatchParameters{PT, R, T}}, data::MPIFile) where {R, T, PT <: PeriodicMotionPreProcessing}
Expand All @@ -43,8 +43,9 @@ end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{OP, ProcessResultCache{OP}},
f::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing) where OP <: PeriodicMotionPreProcessing
uReco, ffOp = process(typeof(algo), params, f, algo.sf, frequencies)
uReco, ffOp, weights = process(typeof(algo), params, f, algo.sf, frequencies)
algo.ffOp = adapt(algo.arrayType, ffOp)
algo.weights = adapt(algo.arrayType, weights)
return adapt(algo.arrayType, uReco)
end

Expand All @@ -67,26 +68,28 @@ function process(algoT::Type{<:MultiPatchReconstructionAlgorithm}, params::Perio
resortedInd[i,:] = unflattenOffsetFieldShift(ffPos_)[i][1:p]
end

# Can't adapt data here because it might be used in background correction
ffOp = MultiPatchOperator(sf, frequencies,
#indFFPos=resortedInd[:,1], unused keyword
FFPos=ffPos_[:,resortedInd[:,1]], mapping=mapping,
FFPosSF=ffPos_[:,resortedInd[:,1]], bgCorrection = false, tfCorrection = params.tfCorrection)

return uReco, ffOp
weights = process(algoT, params.weightingParams, frequencies, ffOp, nothing, Array)
return uReco, ffOp, weights
end

function process(algoT::Type{<:MultiPatchReconstructionAlgorithm},
params::PeriodicMotionPreProcessing{SimpleExternalBackgroundCorrectionParameters}, f::MPIFile, sf::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing)
# Foreground
fgParams = fromKwargs(PeriodicMotionPreProcessing; toKwargs(params)..., bgParams = NoBackgroundCorrectionParameters())
result, ffOp = process(algoT, fgParams, f, sf, frequencies)
result, ffOp, weights = process(algoT, fgParams, f, sf, frequencies)
# Background
bgParams = fromKwargs(ExternalPreProcessedBackgroundCorrectionParameters; toKwargs(params)..., bgParams = params.bgParams, spectralLeakageCorrection=true)
return process(algoT, bgParams, result, frequencies), ffOp
return process(algoT, bgParams, result, frequencies), ffOp, weights
end

function process(algo::MultiPatchReconstructionAlgorithm, params::PeriodicMotionReconstructionParameter, u::Array)
solver = LeastSquaresParameters(S = algo.ffOp, reg = [L2Regularization(params.λ)], solverParams = params.solverParams)
solver = LeastSquaresParameters(S = algo.ffOp, reg = params.reg, solverParams = params.solverParams, weights = algo.weights)

result = process(algo, solver, u)

Expand Down
28 changes: 20 additions & 8 deletions src/Algorithms/SinglePatchAlgorithms/SinglePatchAlgorithm.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
Base.@kwdef struct SinglePatchReconstructionParameter{L<:AbstractSystemMatrixLoadingParameter, SL<:AbstractLinearSolver,
matT <: AbstractArray, SP<:AbstractSolverParameters{SL}, R<:AbstractRegularization, W<:AbstractWeightingParameters} <: AbstractSinglePatchReconstructionParameters
arrT <: AbstractArray, SP<:AbstractSolverParameters{SL}, R<:AbstractRegularization, W<:AbstractWeightingParameters} <: AbstractSinglePatchReconstructionParameters
# File
sf::MPIFile
sfLoad::Union{L, ProcessResultCache{L}}
arrayType::Type{matT} = Array
arrayType::Type{arrT} = Array
# Solver
solverParams::SP
reg::Vector{R} = AbstractRegularization[]
weightingParams::Union{W, ProcessResultCache{W}} = NoWeightingParameters()
end

Base.@kwdef mutable struct SinglePatchReconstructionAlgorithm{P, matT <: AbstractArray} <: AbstractSinglePatchReconstructionAlgorithm where {P<:AbstractSinglePatchAlgorithmParameters}
Base.@kwdef mutable struct SinglePatchReconstructionAlgorithm{P, SM, arrT <: AbstractArray, vecT <: arrT} <: AbstractSinglePatchReconstructionAlgorithm where {P<:AbstractSinglePatchAlgorithmParameters}
params::P
# Could also do reconstruction progress meter here
sf::Union{MPIFile, Vector{MPIFile}}
S::AbstractArray
arrayType::Type{matT}
S::SM
weights::Union{Nothing, vecT} = nothing
arrayType::Type{arrT}
grid::RegularGridPositions
freqs::Vector{CartesianIndex{2}}
output::Channel{Any}
Expand All @@ -26,7 +27,8 @@ function SinglePatchReconstruction(params::SinglePatchParameters{<:AbstractMPIPr
end
function SinglePatchReconstructionAlgorithm(params::SinglePatchParameters{<:AbstractMPIPreProcessingParameters, R, PT}) where {R<:AbstractSinglePatchReconstructionParameters, PT <:AbstractMPIPostProcessingParameters}
freqs, S, grid, arrayType = prepareSystemMatrix(params.reco)
return SinglePatchReconstructionAlgorithm(params, params.reco.sf, S, arrayType, grid, freqs, Channel{Any}(Inf))
weights = prepareWeights(params.reco, freqs, S)
return SinglePatchReconstructionAlgorithm{typeof(params), typeof(S), arrayType, typeof(arrayType{real(eltype(S))}(undef, 0))}(params, params.reco.sf, S, weights, arrayType, grid, freqs, Channel{Any}(Inf))
end
recoAlgorithmTypes(::Type{SinglePatchReconstruction}) = SystemMatrixBasedAlgorithm()
AbstractImageReconstruction.parameter(algo::SinglePatchReconstructionAlgorithm) = algo.params
Expand All @@ -36,6 +38,9 @@ function prepareSystemMatrix(reco::SinglePatchReconstructionParameter{L,S}) wher
return freqs, sf, grid, reco.arrayType
end

function prepareWeights(reco::SinglePatchReconstructionParameter{L,S,arrT,SP,R,W}, freqs, sf) where {L, S, arrT, SP, R, W<:AbstractWeightingParameters}
return process(AbstractMPIRecoAlgorithm, reco.weightingParams, freqs, sf, nothing, reco.arrayType)
end

AbstractImageReconstruction.take!(algo::SinglePatchReconstructionAlgorithm) = Base.take!(algo.output)

Expand All @@ -51,7 +56,7 @@ end


function process(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchReconstructionParameter, u)
weights = adapt(algo.arrayType, process(algo, params.weightingParams, u))
weights = process(algo, params.weightingParams, u, WeightingType(params.weightingParams))

B = getLinearOperator(algo, params)

Expand All @@ -62,7 +67,14 @@ function process(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchRe
return gridresult(result, algo.grid, algo.sf)
end

process(algo::SinglePatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u) where {W <: Union{ChannelWeightingParameters, WhiteningWeightingParameters}} = map(real(eltype(algo.S)), process(typeof(algo), params, algo.freqs))
function process(algo::SinglePatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u, ::MeasurementBasedWeighting) where W<:AbstractWeightingParameters
return process(typeof(algo), params, algo.freqs, algo.S, u, algo.arrayType)
end


function process(algo::SinglePatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u, ::SystemMatrixBasedWeighting) where W<:AbstractWeightingParameters
return algo.weights
end

function getLinearOperator(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchReconstructionParameter{<:DenseSystemMatixLoadingParameter, S}) where {S}
return nothing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function process(algo::T, params::SinglePatchParameters, data::MPIFile, frequenc
end

function AbstractImageReconstruction.put!(algo::AbstractSinglePatchReconstructionAlgorithm, data)
consistenceCheck(algo.sf, data)
#consistenceCheck(algo.sf, data)

result = process(algo, algo.params, data, algo.freqs)

Expand Down
Loading

0 comments on commit 7ff3231

Please sign in to comment.