Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add row norm scaling #38

Merged
merged 10 commits into from
Jul 30, 2024
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module MPIRecoKernelAbstractionsExt

using MPIReco, MPIReco.Adapt, MPIReco.LinearAlgebra, MPIReco.RegularizedLeastSquares
using MPIReco, MPIReco.Adapt, MPIReco.LinearAlgebra, MPIReco.RegularizedLeastSquares, MPIReco.LinearOperatorCollection
using KernelAbstractions, GPUArrays
using KernelAbstractions.Extras: @unroll
using Atomix
Expand Down
80 changes: 77 additions & 3 deletions ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ function Adapt.adapt_structure(::Type{arrT}, op::MultiPatchOperator) where {arrT

# Ideally we create a DenseMultiPatchOperator on the GPU
if validSMs && validXCC && validXSS
S = adapt(arrT, stack(op.S))
S = stack(adapt.(arrT, op.S))
# We want to use Int32 for better GPU performance
xcc = Int32.(adapt(arrT, stack(op.xcc)))
xss = Int32.(adapt(arrT, stack(op.xss)))
xcc = Int32.(stack(adapt.(arrT, op.xcc)))
xss = Int32.(stack(adapt.(arrT, op.xss)))
sign = Int32.(adapt(arrT, op.sign))
RowToPatch = Int32.(adapt(arrT, op.RowToPatch))
patchToSMIdx = Int32.(adapt(arrT, op.patchToSMIdx))
Expand Down Expand Up @@ -140,4 +140,78 @@ function RegularizedLeastSquares.kaczmarz_update!(op::DenseMultiPatchOperator{T,
kernel(x, op.S, row, beta, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = size(op.xss, 1))
synchronize(backend)
return x
end

function RegularizedLeastSquares.normalize(::SystemMatrixBasedNormalization, op::OP, x) where {T, V <: AbstractGPUArray{T}, OP <: DenseMultiPatchOperator{T, V}}
weights = one(real(eltype(op)))
energy = normalize_dense_op(op, weights)
return norm(energy)^2/size(op, 2)
end

function RegularizedLeastSquares.normalize(::SystemMatrixBasedNormalization, prod::ProdOp{T, <:WeightingOp, OP}, x) where {T, V <: AbstractGPUArray{T}, OP <: DenseMultiPatchOperator{T, V}}
op = prod.B
weights = prod.A.weights
energy = normalize_dense_op(op, weights)
return norm(energy)^2/size(prod, 2)
end

function normalize_dense_op(op::DenseMultiPatchOperator{T, V}, weights) where {T, V <: AbstractGPUArray{T}}
backend = get_backend(op.S)
kernel = normalize_kernel!(backend, 256)
energy = KernelAbstractions.zeros(backend, real(eltype(op)), size(op, 1))
kernel(energy, weights, op.S, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (256, size(op, 1)))
synchronize(backend)
return energy
end

# The normalization kernels are structured the same as the mul!-kernel. The multiplication with x is replaced by abs2 for the rownorm²
@kernel cpu = false inbounds = true function normalize_kernel!(energy, weights, @Const(S), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx))
# Each group/block handles a single row of the operator
operator_row = @index(Group, Linear) # k
patch = RowToPatch[operator_row] # p
patch_row = mod1(operator_row, M) # j
smIdx = patchToSMIdx[patch]
sign = eltype(energy)(signs[patch_row, smIdx])
grid_stride = prod(@groupsize())
N = Int32(size(xss, 1))

localIdx = @index(Local, Linear)
shared = @localmem eltype(energy) grid_stride
shared[localIdx] = zero(eltype(energy))

@unroll for i = localIdx:grid_stride:N
shared[localIdx] = shared[localIdx] + abs2(sign * S[patch_row, xss[i, patch], smIdx])
end
@synchronize

@private s = div(min(grid_stride, N), Int32(2))
while s > Int32(0)
if localIdx <= s
shared[localIdx] = shared[localIdx] + shared[localIdx + s]
end
s >>= 1
@synchronize
end

if localIdx == 1
energy[operator_row] = sqrt(get_kernel_weights(weights, operator_row)^2 * shared[localIdx])
end
end

@inline get_kernel_weights(weights::AbstractArray, operator_row) = weights[operator_row]
@inline get_kernel_weights(weights::Number, operator_row) = weights

function Base.hash(op::DenseMultiPatchOperator{T, V}, h::UInt64) where {T, V <: AbstractGPUArray{T}}
@warn "Hashing of GPU DenseMultiPatchOperator is inefficient"
h = hash(typeof(op), h)
h = @allowscalar hash(op.S, h)
h = hash(op.grid, h)
h = hash(op.N, h)
h = hash(op.M, h)
h = @allowscalar hash(op.RowToPatch, h)
h = @allowscalar hash(op.xcc, h)
h = @allowscalar hash(op.xss, h)
h = @allowscalar hash(op.sign, h)
h = hash(op.nPatches, h)
h = @allowscalar hash(op.patchToSMIdx, h)
end
35 changes: 25 additions & 10 deletions src/Algorithms/MultiPatchAlgorithms/MultiPatchAlgorithm.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export MultiPatchReconstructionAlgorithm, MultiPatchReconstructionParameter
Base.@kwdef struct MultiPatchReconstructionParameter{matT <: AbstractArray,F<:AbstractFrequencyFilterParameter,O<:AbstractMultiPatchOperatorParameter, S<:AbstractSolverParameters, FF<:AbstractFocusFieldPositions, FFSF<:AbstractFocusFieldPositions, R <: AbstractRegularization} <: AbstractMultiPatchReconstructionParameters
arrayType::Type{matT} = Array
Base.@kwdef struct MultiPatchReconstructionParameter{arrT <: AbstractArray,F<:AbstractFrequencyFilterParameter,O<:AbstractMultiPatchOperatorParameter, S<:AbstractSolverParameters, FF<:AbstractFocusFieldPositions, FFSF<:AbstractFocusFieldPositions, R <: AbstractRegularization, W<:AbstractWeightingParameters} <: AbstractMultiPatchReconstructionParameters
arrayType::Type{arrT} = Array
# File
sf::MultiMPIFile
freqFilter::F
Expand All @@ -9,15 +9,16 @@ Base.@kwdef struct MultiPatchReconstructionParameter{matT <: AbstractArray,F<:Ab
ffPosSF::FFSF = DefaultFocusFieldPositions()
solverParams::S
reg::Vector{R} = AbstractRegularization[]
# weightingType::WeightingType = WeightingType.None
weightingParams::Union{W, ProcessResultCache{W}} = NoWeightingParameters()
end

Base.@kwdef mutable struct MultiPatchReconstructionAlgorithm{P, matT <: AbstractArray} <: AbstractMultiPatchReconstructionAlgorithm where {P<:AbstractMultiPatchAlgorithmParameters}
Base.@kwdef mutable struct MultiPatchReconstructionAlgorithm{P, arrT <: AbstractArray, vecT <: AbstractArray} <: AbstractMultiPatchReconstructionAlgorithm where {P<:AbstractMultiPatchAlgorithmParameters}
params::P
# Could also do reconstruction progress meter here
opParams::Union{AbstractMultiPatchOperatorParameter, ProcessResultCache{<:AbstractMultiPatchOperatorParameter},Nothing} = nothing
sf::MultiMPIFile
arrayType::Type{matT}
weights::Union{Nothing, vecT} = nothing
arrayType::Type{arrT}
ffOp::Union{Nothing, AbstractMultiPatchOperator}
ffPos::Union{Nothing,AbstractArray}
ffPosSF::Union{Nothing,AbstractArray}
Expand All @@ -40,7 +41,7 @@ function MultiPatchReconstructionAlgorithm(params::MultiPatchParameters{<:Abstra
ffPosSF = [vec(ffPos(SF))[l] for l=1:L, SF in reco.sf]
end

return MultiPatchReconstructionAlgorithm(params, reco.opParams, reco.sf, reco.arrayType, nothing, ffPos_, ffPosSF, freqs, Channel{Any}(Inf))
return MultiPatchReconstructionAlgorithm{typeof(params), reco.arrayType, typeof(reco.arrayType{Float32}(undef, 0))}(params, reco.opParams, reco.sf, nothing, reco.arrayType, nothing, ffPos_, ffPosSF, freqs, Channel{Any}(Inf))
end
recoAlgorithmTypes(::Type{MultiPatchReconstruction}) = SystemMatrixBasedAlgorithm()
AbstractImageReconstruction.parameter(algo::MultiPatchReconstructionAlgorithm) = algo.origParam
Expand All @@ -50,7 +51,7 @@ AbstractImageReconstruction.take!(algo::MultiPatchReconstructionAlgorithm) = Bas
function AbstractImageReconstruction.put!(algo::MultiPatchReconstructionAlgorithm, data::MPIFile)
#consistenceCheck(algo.sf, data)

algo.ffOp = process(algo, algo.opParams, data, algo.freqs)
algo.ffOp, algo.weights = process(algo, algo.opParams, data, algo.freqs, algo.params.reco.weightingParams)

result = process(algo, algo.params, data, algo.freqs)

Expand All @@ -65,7 +66,7 @@ function AbstractImageReconstruction.put!(algo::MultiPatchReconstructionAlgorith
Base.put!(algo.output, result)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{OP, ProcessResultCache{OP}}, f::MPIFile, frequencies::Vector{CartesianIndex{2}}) where OP <: AbstractMultiPatchOperatorParameter
function process(algo::MultiPatchReconstructionAlgorithm, params::Union{OP, ProcessResultCache{OP}}, f::MPIFile, frequencies::Vector{CartesianIndex{2}}, weightingParams) where OP <: AbstractMultiPatchOperatorParameter
ffPos_ = ffPos(f)
periodsSortedbyFFPos = unflattenOffsetFieldShift(ffPos_)
idxFirstPeriod = getindex.(periodsSortedbyFFPos,1)
Expand All @@ -77,7 +78,11 @@ function process(algo::MultiPatchReconstructionAlgorithm, params::Union{OP, Proc
end

result = process(typeof(algo), params, algo.sf, frequencies, gradient, ffPos_, algo.ffPosSF)
return adapt(algo.arrayType, result)
# Kinda of hacky. MultiPatch parameters don't map nicely to the SinglePatch inspired pre, reco, post structure
# Have to create weights before ffop is (potentially) moved to GPU, as GPU arrays don't have efficient hash implementations
# Which makes this process expensive to cache
weights = process(typeof(algo), weightingParams, frequencies, result, nothing, algo.arrayType)
return adapt(algo.arrayType, result), weights
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{A, ProcessResultCache{<:A}}, f::MPIFile, args...) where A <: AbstractMPIPreProcessingParameters
Expand Down Expand Up @@ -121,9 +126,19 @@ function process(::Type{<:MultiPatchReconstructionAlgorithm}, params::ExternalPr
end

function process(algo::MultiPatchReconstructionAlgorithm, params::MultiPatchReconstructionParameter, u::AbstractArray)
solver = LeastSquaresParameters(S = algo.ffOp, reg = params.reg, solverParams = params.solverParams)
weights = process(algo, params.weightingParams, u, WeightingType(params.weightingParams))

solver = LeastSquaresParameters(S = algo.ffOp, reg = params.reg, solverParams = params.solverParams, weights = weights)

result = process(algo, solver, u)

return gridresult(result, algo.ffOp.grid, algo.sf)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u, ::MeasurementBasedWeighting) where W<:AbstractWeightingParameters
return process(typeof(algo), params, algo.freqs, algo.ffOp, u, algo.arrayType)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u, ::SystemMatrixBasedWeighting) where W<:AbstractWeightingParameters
return algo.weights
end
25 changes: 14 additions & 11 deletions src/Algorithms/MultiPatchAlgorithms/MultiPatchPeriodicMotion.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export PeriodicMotionPreProcessing, PeriodicMotionReconstructionParameter
Base.@kwdef struct PeriodicMotionPreProcessing{BG<:AbstractMPIBackgroundCorrectionParameters} <: AbstractMPIPreProcessingParameters{BG}
Base.@kwdef struct PeriodicMotionPreProcessing{BG<:AbstractMPIBackgroundCorrectionParameters, W <: AbstractWeightingParameters} <: AbstractMPIPreProcessingParameters{BG}
# Periodic Motion
frames::Union{Nothing, UnitRange{Int64}, Vector{Int64}} = nothing
alpha::Float64 = 3.0
Expand All @@ -10,21 +10,21 @@ Base.@kwdef struct PeriodicMotionPreProcessing{BG<:AbstractMPIBackgroundCorrecti
sf::MultiMPIFile
tfCorrection::Bool = false
bgParams::BG = NoBackgroundCorrectionParameters()
# weightingType::WeightingType = WeightingType.None
weightingParams::Union{W, ProcessResultCache{W}} = NoWeightingParameters()
end

Base.@kwdef struct PeriodicMotionReconstructionParameter{F<:AbstractFrequencyFilterParameter, S<:AbstractSolverParameters} <: AbstractMultiPatchReconstructionParameters
Base.@kwdef struct PeriodicMotionReconstructionParameter{F<:AbstractFrequencyFilterParameter, S<:AbstractSolverParameters, R <: AbstractRegularization, arrT <: AbstractArray} <: AbstractMultiPatchReconstructionParameters
sf::MultiMPIFile
freqFilter::F
solverParams::S
λ::Float32
# weightingType::WeightingType = WeightingType.None
reg::Vector{R} = AbstractRegularization[]
arrayType::Type{arrT} = Array
end

function MultiPatchReconstructionAlgorithm(params::MultiPatchParameters{<:PeriodicMotionPreProcessing,<:PeriodicMotionReconstructionParameter,<:AbstractMPIPostProcessingParameters})
reco = params.reco
freqs = process(MultiPatchReconstructionAlgorithm, reco.freqFilter, reco.sf)
return MultiPatchReconstructionAlgorithm(params, nothing, reco.sf, Array, nothing, nothing, nothing, freqs, Channel{Any}(Inf))
return MultiPatchReconstructionAlgorithm{typeof(params), reco.arrayType, typeof(reco.arrayType{Float32}(undef, 0))}(params, nothing, reco.sf, nothing, reco.arrayType, nothing, nothing, nothing, freqs, Channel{Any}(Inf))
end

function AbstractImageReconstruction.put!(algo::MultiPatchReconstructionAlgorithm{MultiPatchParameters{PT, R, T}}, data::MPIFile) where {R, T, PT <: PeriodicMotionPreProcessing}
Expand All @@ -43,8 +43,9 @@ end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{OP, ProcessResultCache{OP}},
f::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing) where OP <: PeriodicMotionPreProcessing
uReco, ffOp = process(typeof(algo), params, f, algo.sf, frequencies)
uReco, ffOp, weights = process(typeof(algo), params, f, algo.sf, frequencies)
algo.ffOp = adapt(algo.arrayType, ffOp)
algo.weights = adapt(algo.arrayType, weights)
return adapt(algo.arrayType, uReco)
end

Expand All @@ -67,26 +68,28 @@ function process(algoT::Type{<:MultiPatchReconstructionAlgorithm}, params::Perio
resortedInd[i,:] = unflattenOffsetFieldShift(ffPos_)[i][1:p]
end

# Can't adapt data here because it might be used in background correction
ffOp = MultiPatchOperator(sf, frequencies,
#indFFPos=resortedInd[:,1], unused keyword
FFPos=ffPos_[:,resortedInd[:,1]], mapping=mapping,
FFPosSF=ffPos_[:,resortedInd[:,1]], bgCorrection = false, tfCorrection = params.tfCorrection)

return uReco, ffOp
weights = process(algoT, params.weightingParams, frequencies, ffOp, nothing, Array)
return uReco, ffOp, weights
end

function process(algoT::Type{<:MultiPatchReconstructionAlgorithm},
params::PeriodicMotionPreProcessing{SimpleExternalBackgroundCorrectionParameters}, f::MPIFile, sf::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing)
# Foreground
fgParams = fromKwargs(PeriodicMotionPreProcessing; toKwargs(params)..., bgParams = NoBackgroundCorrectionParameters())
result, ffOp = process(algoT, fgParams, f, sf, frequencies)
result, ffOp, weights = process(algoT, fgParams, f, sf, frequencies)
# Background
bgParams = fromKwargs(ExternalPreProcessedBackgroundCorrectionParameters; toKwargs(params)..., bgParams = params.bgParams, spectralLeakageCorrection=true)
return process(algoT, bgParams, result, frequencies), ffOp
return process(algoT, bgParams, result, frequencies), ffOp, weights
end

function process(algo::MultiPatchReconstructionAlgorithm, params::PeriodicMotionReconstructionParameter, u::Array)
solver = LeastSquaresParameters(S = algo.ffOp, reg = [L2Regularization(params.λ)], solverParams = params.solverParams)
solver = LeastSquaresParameters(S = algo.ffOp, reg = params.reg, solverParams = params.solverParams, weights = algo.weights)

result = process(algo, solver, u)

Expand Down
28 changes: 20 additions & 8 deletions src/Algorithms/SinglePatchAlgorithms/SinglePatchAlgorithm.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
Base.@kwdef struct SinglePatchReconstructionParameter{L<:AbstractSystemMatrixLoadingParameter, SL<:AbstractLinearSolver,
matT <: AbstractArray, SP<:AbstractSolverParameters{SL}, R<:AbstractRegularization, W<:AbstractWeightingParameters} <: AbstractSinglePatchReconstructionParameters
arrT <: AbstractArray, SP<:AbstractSolverParameters{SL}, R<:AbstractRegularization, W<:AbstractWeightingParameters} <: AbstractSinglePatchReconstructionParameters
# File
sf::MPIFile
sfLoad::Union{L, ProcessResultCache{L}}
arrayType::Type{matT} = Array
arrayType::Type{arrT} = Array
# Solver
solverParams::SP
reg::Vector{R} = AbstractRegularization[]
weightingParams::Union{W, ProcessResultCache{W}} = NoWeightingParameters()
end

Base.@kwdef mutable struct SinglePatchReconstructionAlgorithm{P, matT <: AbstractArray} <: AbstractSinglePatchReconstructionAlgorithm where {P<:AbstractSinglePatchAlgorithmParameters}
Base.@kwdef mutable struct SinglePatchReconstructionAlgorithm{P, SM, arrT <: AbstractArray, vecT <: arrT} <: AbstractSinglePatchReconstructionAlgorithm where {P<:AbstractSinglePatchAlgorithmParameters}
params::P
# Could also do reconstruction progress meter here
sf::Union{MPIFile, Vector{MPIFile}}
S::AbstractArray
arrayType::Type{matT}
S::SM
weights::Union{Nothing, vecT} = nothing
arrayType::Type{arrT}
grid::RegularGridPositions
freqs::Vector{CartesianIndex{2}}
output::Channel{Any}
Expand All @@ -26,7 +27,8 @@ function SinglePatchReconstruction(params::SinglePatchParameters{<:AbstractMPIPr
end
function SinglePatchReconstructionAlgorithm(params::SinglePatchParameters{<:AbstractMPIPreProcessingParameters, R, PT}) where {R<:AbstractSinglePatchReconstructionParameters, PT <:AbstractMPIPostProcessingParameters}
freqs, S, grid, arrayType = prepareSystemMatrix(params.reco)
return SinglePatchReconstructionAlgorithm(params, params.reco.sf, S, arrayType, grid, freqs, Channel{Any}(Inf))
weights = prepareWeights(params.reco, freqs, S)
return SinglePatchReconstructionAlgorithm{typeof(params), typeof(S), arrayType, typeof(arrayType{real(eltype(S))}(undef, 0))}(params, params.reco.sf, S, weights, arrayType, grid, freqs, Channel{Any}(Inf))
end
recoAlgorithmTypes(::Type{SinglePatchReconstruction}) = SystemMatrixBasedAlgorithm()
AbstractImageReconstruction.parameter(algo::SinglePatchReconstructionAlgorithm) = algo.params
Expand All @@ -36,6 +38,9 @@ function prepareSystemMatrix(reco::SinglePatchReconstructionParameter{L,S}) wher
return freqs, sf, grid, reco.arrayType
end

function prepareWeights(reco::SinglePatchReconstructionParameter{L,S,arrT,SP,R,W}, freqs, sf) where {L, S, arrT, SP, R, W<:AbstractWeightingParameters}
return process(AbstractMPIRecoAlgorithm, reco.weightingParams, freqs, sf, nothing, reco.arrayType)
end

AbstractImageReconstruction.take!(algo::SinglePatchReconstructionAlgorithm) = Base.take!(algo.output)

Expand All @@ -51,7 +56,7 @@ end


function process(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchReconstructionParameter, u)
weights = adapt(algo.arrayType, process(algo, params.weightingParams, u))
weights = process(algo, params.weightingParams, u, WeightingType(params.weightingParams))

B = getLinearOperator(algo, params)

Expand All @@ -62,7 +67,14 @@ function process(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchRe
return gridresult(result, algo.grid, algo.sf)
end

process(algo::SinglePatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u) where {W <: Union{ChannelWeightingParameters, WhiteningWeightingParameters}} = map(real(eltype(algo.S)), process(typeof(algo), params, algo.freqs))
function process(algo::SinglePatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u, ::MeasurementBasedWeighting) where W<:AbstractWeightingParameters
return process(typeof(algo), params, algo.freqs, algo.S, u, algo.arrayType)
end


function process(algo::SinglePatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u, ::SystemMatrixBasedWeighting) where W<:AbstractWeightingParameters
return algo.weights
end

function getLinearOperator(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchReconstructionParameter{<:DenseSystemMatixLoadingParameter, S}) where {S}
return nothing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function process(algo::T, params::SinglePatchParameters, data::MPIFile, frequenc
end

function AbstractImageReconstruction.put!(algo::AbstractSinglePatchReconstructionAlgorithm, data)
consistenceCheck(algo.sf, data)
#consistenceCheck(algo.sf, data)

result = process(algo, algo.params, data, algo.freqs)

Expand Down
Loading
Loading