Skip to content

Commit

Permalink
Merge pull request #37 from MagneticParticleImaging/nh/caching
Browse files Browse the repository at this point in the history
Add caching of RecoPlans
  • Loading branch information
nHackel authored Jul 26, 2024
2 parents 7b80497 + 60ca654 commit 9a40108
Show file tree
Hide file tree
Showing 14 changed files with 110 additions and 44 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ IniFile = "83e8ac13-25f8-5344-8a64-a9f2b223428f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
LinearOperatorCollection = "a4a2c56f-fead-462a-a3ab-85921a5f2575"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
MPIFiles = "371237a9-e6c1-5201-9adb-3d8cfa78fa9f"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
AbstractImageReconstruction = "a4b4fdbf-6459-4ec9-990d-77e1fa24a91b"
Expand All @@ -40,6 +41,7 @@ KernelAbstractions = "0.8, 0.9"
LinearAlgebra = "1"
LinearOperators = "2.3"
LinearOperatorCollection = "2"
LRUCache = "1.6"
MPIFiles = "0.13, 0.14, 0.15, 0.16"
ProgressMeter = "1.2"
Reexport = "1.0"
Expand Down
8 changes: 6 additions & 2 deletions config/MultiPatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ _module = "MPIReco"
_module = "MPIReco"

[parameter.reco.opParams]
_type = "RecoPlan{RegularMultiPatchOperatorParameter}"
_module = "MPIReco"
_module = "AbstractImageReconstruction"
_type = "RecoPlan{ProcessResultCache}"

[parameter.reco.opParams.param]
_type = "RecoPlan{RegularMultiPatchOperatorParameter}"
_module = "MPIReco"

[parameter.reco.ffPosSF]
_type = "RecoPlan{DefaultFocusFieldPositions}"
Expand Down
8 changes: 6 additions & 2 deletions config/MultiPatchMotion.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ _module = "MPIReco"
_module = "MPIReco"

[parameter.pre]
_type = "RecoPlan{PeriodicMotionPreProcessing}"
_module = "MPIReco"
_module = "AbstractImageReconstruction"
_type = "RecoPlan{ProcessResultCache}"

[parameter.pre.param]
_type = "RecoPlan{PeriodicMotionPreProcessing}"
_module = "MPIReco"

[parameter.post]
_type = "RecoPlan{NoPostProcessing}"
Expand Down
49 changes: 46 additions & 3 deletions src/AlgorithmInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export planpath, plandir
plandir() = abspath(homedir(), ".mpi", "RecoPlans")

function planpath(name::AbstractString)
for dir in [joinpath(@__DIR__, "..", "config"),plandir()]
for dir in [plandir(), joinpath(@__DIR__, "..", "config")]
filename = joinpath(dir, string(name, ".toml"))
if isfile(filename)
return filename
Expand All @@ -55,12 +55,55 @@ function planpath(name::AbstractString)
throw(ArgumentError("Could not find a suitable MPI reconstruction plan with name $name.\nCustom plans can be stored in $(plandir())."))
end

const recoPlans = LRU{UInt64, RecoPlan}(maxsize = 3)

export reconstruct
function reconstruct(name::AbstractString, data::MPIFile; kwargs...)
plan = loadPlan(MPIReco, name, [AbstractImageReconstruction, MPIFiles, MPIReco, RegularizedLeastSquares])
"""
reconstruct(name::AbstractString, data::MPIFile, cache::Bool = true; kwargs...)
Perform a reconstruction with the `RecoPlan` specified by `name` and given `data`. If `cache` is `true` the reconstruction plan is cached and reused if the plan file has not changed.
Additional keyword arguments can be passed to the reconstruction plan.
`RecoPlans` can be stored in the directory `$(plandir())` or in the MPIReco package config folder. The first plan found is used. The cache considers the last modification time of the plan file.
If a keyword argument changes the structure of the plan the cache is bypassed.
The cache can be emptied with `emptyRecoCache!()`.
# Examples
```julia
julia> mdf = MPIFile("data.mdf");
julia> reconstruct("SinglePatch", mdf; solver = Kaczmarz, reg = [L2Regularization(0.3f0)], iterations = 10, frames = 1:10, ...)
```
"""
function reconstruct(name::AbstractString, data::MPIFile, cache::Bool = true; kwargs...)
plan = loadRecoPlan(name, cache; kwargs...)
setAll!(plan; kwargs...)
return reconstruct(build(plan), data)
end
function loadRecoPlan(name::AbstractString, cache::Bool; kwargs...)
planfile = AbstractImageReconstruction.planpath(MPIReco, name)

# If the user disables caching or changes the plan structure we bypass the cache
kwargValues = values(values(kwargs))
if !cache || any(val -> isa(val, RecoPlan) || isa(val, AbstractImageReconstructionParameters), kwargValues)
return loadRecoPlan(planfile)
end

key = hash(planfile, hash(mtime(planfile)))
return get!(recoPlans, key) do
loadRecoPlan(planfile)
end
end
loadRecoPlan(planfile::AbstractString) = loadPlan(planfile, [AbstractImageReconstruction, MPIFiles, MPIReco, RegularizedLeastSquares])

export emptyRecoCache!
"""
emptyRecoCache!()
Empty the cache of `RecoPlans`. This is useful if the cache is too large.
"""
emptyRecoCache!() = Base.empty!(recoPlans)

# Check if contains
isSystemMatrixBased(::T) where T <: AbstractImageReconstructionAlgorithm = recoAlgorithmTypes(T) isa SystemMatrixBasedAlgorithm
Expand Down
14 changes: 6 additions & 8 deletions src/Algorithms/MultiPatchAlgorithms/MultiPatchAlgorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Base.@kwdef struct MultiPatchReconstructionParameter{matT <: AbstractArray,F<:Ab
# File
sf::MultiMPIFile
freqFilter::F
opParams::O
opParams::Union{O, ProcessResultCache{O}}
ffPos::FF = DefaultFocusFieldPositions()
ffPosSF::FFSF = DefaultFocusFieldPositions()
solverParams::S
Expand All @@ -15,7 +15,7 @@ end
Base.@kwdef mutable struct MultiPatchReconstructionAlgorithm{P, matT <: AbstractArray} <: AbstractMultiPatchReconstructionAlgorithm where {P<:AbstractMultiPatchAlgorithmParameters}
params::P
# Could also do reconstruction progress meter here
opParams::Union{AbstractMultiPatchOperatorParameter, Nothing} = nothing
opParams::Union{AbstractMultiPatchOperatorParameter, ProcessResultCache{<:AbstractMultiPatchOperatorParameter},Nothing} = nothing
sf::MultiMPIFile
arrayType::Type{matT}
ffOp::Union{Nothing, AbstractMultiPatchOperator}
Expand Down Expand Up @@ -65,7 +65,7 @@ function AbstractImageReconstruction.put!(algo::MultiPatchReconstructionAlgorith
Base.put!(algo.output, result)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::AbstractMultiPatchOperatorParameter, f::MPIFile, frequencies::Vector{CartesianIndex{2}})
function process(algo::MultiPatchReconstructionAlgorithm, params::Union{OP, ProcessResultCache{OP}}, f::MPIFile, frequencies::Vector{CartesianIndex{2}}) where OP <: AbstractMultiPatchOperatorParameter
ffPos_ = ffPos(f)
periodsSortedbyFFPos = unflattenOffsetFieldShift(ffPos_)
idxFirstPeriod = getindex.(periodsSortedbyFFPos,1)
Expand All @@ -75,11 +75,9 @@ function process(algo::MultiPatchReconstructionAlgorithm, params::AbstractMultiP
if !isnothing(algo.ffPos)
ffPos_[:] = algo.ffPos
end

ffPosSF = algo.ffPosSF

operator = MultiPatchOperator(algo.sf, frequencies; toKwargs(params)..., gradient = gradient, FFPos = ffPos_, FFPosSF = ffPosSF)
return adapt(algo.arrayType, operator)

result = process(typeof(algo), params, algo.sf, frequencies, gradient, ffPos_, algo.ffPosSF)
return adapt(algo.arrayType, result)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::Union{A, ProcessResultCache{<:A}}, f::MPIFile, args...) where A <: AbstractMPIPreProcessingParameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ positions(ffPos::CustomFocusFieldPositions) = ffPos.positions

Base.@kwdef mutable struct MultiPatchParameters{PR<:AbstractMPIPreProcessingParameters,
R<:AbstractMultiPatchReconstructionParameters, PT<:AbstractMPIPostProcessingParameters} <: AbstractMultiPatchAlgorithmParameters
pre::PR
pre::Union{PR, ProcessResultCache{PR}}
reco::R
post::PT = NoPostProcessing()
end
Expand Down
24 changes: 16 additions & 8 deletions src/Algorithms/MultiPatchAlgorithms/MultiPatchPeriodicMotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,16 @@ function AbstractImageReconstruction.put!(algo::MultiPatchReconstructionAlgorith
Base.put!(algo.output, result)
end

function process(algo::MultiPatchReconstructionAlgorithm, params::PeriodicMotionPreProcessing{NoBackgroundCorrectionParameters},
f::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing)
function process(algo::MultiPatchReconstructionAlgorithm, params::Union{OP, ProcessResultCache{OP}},
f::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing) where OP <: PeriodicMotionPreProcessing
uReco, ffOp = process(typeof(algo), params, f, algo.sf, frequencies)
algo.ffOp = adapt(algo.arrayType, ffOp)
return adapt(algo.arrayType, uReco)
end

function process(algoT::Type{<:MultiPatchReconstructionAlgorithm}, params::PeriodicMotionPreProcessing{NoBackgroundCorrectionParameters},
f::MPIFile, sf::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing)
@info "Loading Multi Patch motion operator"
ffPos_ = ffPos(f)
motFreq = getMotionFreq(params.sf, f, params.choosePeak) ./ params.higherHarmonic
tmot = getRepetitionsOfSameState(f, motFreq, params.frames)
Expand All @@ -59,22 +67,22 @@ function process(algo::MultiPatchReconstructionAlgorithm, params::PeriodicMotion
resortedInd[i,:] = unflattenOffsetFieldShift(ffPos_)[i][1:p]
end

algo.ffOp = MultiPatchOperator(algo.sf, frequencies,
ffOp = MultiPatchOperator(sf, frequencies,
#indFFPos=resortedInd[:,1], unused keyword
FFPos=ffPos_[:,resortedInd[:,1]], mapping=mapping,
FFPosSF=ffPos_[:,resortedInd[:,1]], bgCorrection = false, tfCorrection = params.tfCorrection)

return uReco
return uReco, ffOp
end

function process(algo::MultiPatchReconstructionAlgorithm,
params::PeriodicMotionPreProcessing{SimpleExternalBackgroundCorrectionParameters}, f::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing)
function process(algoT::Type{<:MultiPatchReconstructionAlgorithm},
params::PeriodicMotionPreProcessing{SimpleExternalBackgroundCorrectionParameters}, f::MPIFile, sf::MPIFile, frequencies::Union{Vector{CartesianIndex{2}}, Nothing} = nothing)
# Foreground
fgParams = fromKwargs(PeriodicMotionPreProcessing; toKwargs(params)..., bgParams = NoBackgroundCorrectionParameters())
result = process(algo, fgParams, f, frequencies)
result, ffOp = process(algoT, fgParams, f, sf, frequencies)
# Background
bgParams = fromKwargs(ExternalPreProcessedBackgroundCorrectionParameters; toKwargs(params)..., bgParams = params.bgParams, spectralLeakageCorrection=true)
return process(algo, bgParams, result, frequencies)
return process(algoT, bgParams, result, frequencies), ffOp
end

function process(algo::MultiPatchReconstructionAlgorithm, params::PeriodicMotionReconstructionParameter, u::Array)
Expand Down
7 changes: 3 additions & 4 deletions src/Algorithms/SinglePatchAlgorithms/SinglePatchAlgorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Base.@kwdef struct SinglePatchReconstructionParameter{L<:AbstractSystemMatrixLoa
# Solver
solverParams::SP
reg::Vector{R} = AbstractRegularization[]
weightingParams::W = NoWeightingParameters()
weightingParams::Union{W, ProcessResultCache{W}} = NoWeightingParameters()
end

Base.@kwdef mutable struct SinglePatchReconstructionAlgorithm{P, matT <: AbstractArray} <: AbstractSinglePatchReconstructionAlgorithm where {P<:AbstractSinglePatchAlgorithmParameters}
Expand All @@ -32,8 +32,7 @@ recoAlgorithmTypes(::Type{SinglePatchReconstruction}) = SystemMatrixBasedAlgorit
AbstractImageReconstruction.parameter(algo::SinglePatchReconstructionAlgorithm) = algo.params

function prepareSystemMatrix(reco::SinglePatchReconstructionParameter{L,S}) where {L<:AbstractSystemMatrixLoadingParameter, S<:AbstractLinearSolver}
freqs, sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, reco.sf)
sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, S, sf, grid, reco.arrayType)
freqs, sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, reco.sf, S, reco.arrayType)
return freqs, sf, grid, reco.arrayType
end

Expand Down Expand Up @@ -63,7 +62,7 @@ function process(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchRe
return gridresult(result, algo.grid, algo.sf)
end

process(algo::SinglePatchReconstructionAlgorithm, params::Union{ChannelWeightingParameters, WhiteningWeightingParameters}, u) = map(real(eltype(algo.S)), process(typeof(algo), params, algo.freqs))
process(algo::SinglePatchReconstructionAlgorithm, params::Union{W, ProcessResultCache{W}}, u) where {W <: Union{ChannelWeightingParameters, WhiteningWeightingParameters}} = map(real(eltype(algo.S)), process(typeof(algo), params, algo.freqs))

function getLinearOperator(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchReconstructionParameter{<:DenseSystemMatixLoadingParameter, S}) where {S}
return nothing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ recoAlgorithmTypes(::Type{SinglePatchBGEstimationAlgorithm}) = SystemMatrixBased
AbstractImageReconstruction.parameter(algo::SinglePatchBGEstimationAlgorithm) = algo.origParam

function prepareSystemMatrix(reco::SinglePatchBGEstimationReconstructionParameter{L}) where {L<:AbstractSystemMatrixLoadingParameter}
freqs, sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, reco.sf)
sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, Kaczmarz, sf, grid, reco.arrayType)
freqs, sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, reco.sf, Kaczmarz, reco.arrayType)
return freqs, sf, grid, reco.arrayType
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ recoAlgorithmTypes(::Type{SinglePatchTemporalRegularizationAlgorithm}) = SystemM
AbstractImageReconstruction.parameter(algo::SinglePatchTemporalRegularizationAlgorithm) = algo.origParam

function prepareSystemMatrix(reco::SinglePatchTemporalRegularizationReconstructionParameter{L}) where {L<:AbstractSystemMatrixLoadingParameter}
freqs, sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, reco.sf)
sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, Kaczmarz, sf, grid, reco.arrayType)
freqs, sf, grid = process(AbstractMPIRecoAlgorithm, reco.sfLoad, reco.sf, Kaczmarz, reco.arrayType)
return freqs, sf, grid, reco.arrayType
end

Expand Down
1 change: 1 addition & 0 deletions src/MPIReco.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module MPIReco
@reexport using MPIFiles
const shape = MPIFiles.shape
using AbstractImageReconstruction
using LRUCache
using Adapt
@reexport using DSP
using ProgressMeter
Expand Down
4 changes: 4 additions & 0 deletions src/MultiPatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ function MultiPatchOperatorHighLevel(bSF::MultiMPIFile, bMeas, freq, bgCorrectio
return FFOp
end

function process(::Type{<:AbstractMPIRecoAlgorithm}, params::AbstractMultiPatchOperatorParameter, bSF::MultiMPIFile, freq, gradient, FFPos, FFPosSF)
@info "Loading Multi Patch operator"
return MultiPatchOperator(bSF, freq; toKwargs(params)..., FFPos = FFPos, FFPosSF = FFPosSF, gradient = gradient)
end

function MultiPatchOperator(SF::MPIFile, freq, bgCorrection::Bool; kargs...)
return MultiPatchOperator(MultiMPIFile([SF]), freq, bgCorrection; kargs...)
Expand Down
27 changes: 16 additions & 11 deletions src/SystemMatrix/SystemMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,11 @@ export AbstractSystemMatrixLoadingParameter
abstract type AbstractSystemMatrixLoadingParameter <: AbstractSystemMatrixParameter end

export DenseSystemMatixLoadingParameter
Base.@kwdef struct DenseSystemMatixLoadingParameter{F<:AbstractFrequencyFilterParameter, matT <: AbstractArray, G<:AbstractSystemMatrixGriddingParameter} <: AbstractSystemMatrixLoadingParameter
Base.@kwdef struct DenseSystemMatixLoadingParameter{F<:AbstractFrequencyFilterParameter, G<:AbstractSystemMatrixGriddingParameter} <: AbstractSystemMatrixLoadingParameter
freqFilter::F
gridding::G
bgCorrection::Bool = false
loadasreal::Bool = false
arrayType::Type{matT} = Array
end
function process(t::Type{<:AbstractMPIRecoAlgorithm}, params::DenseSystemMatixLoadingParameter, sf::MPIFile)
# Construct freqFilter
Expand Down Expand Up @@ -124,18 +123,24 @@ end
getSF(bSF, frequencies, sparseTrafo, solver::AbstractLinearSolver; kargs...) = getSF(bSF, frequencies, sparseTrafo, typeof(solver); kargs...)
function getSF(bSF, frequencies, sparseTrafo, solver::Type{<:AbstractLinearSolver}; arrayType = Array, kargs...)
SF, grid = getSF(bSF, frequencies, sparseTrafo; kargs...)
return prepareSF(solver, SF, grid, arrayType)
SF, grid = prepareSF(solver, SF, grid)
SF = adaptSF(arrayType, SF)
return SF, grid
end


function AbstractImageReconstruction.process(::Type{<:AbstractMPIRecoAlgorithm}, params::AbstractSystemMatrixLoadingParameter, solver::Type{<:AbstractLinearSolver}, sf, grid, arrayType)
return prepareSF(solver, sf, grid, arrayType)
end
function prepareSF(solver::Type{<:AbstractLinearSolver}, SF, grid, arrayType)
SF, grid = prepareSF(solver, SF, grid)
# adapt(Array, Sparse-CPU) results in dense array
return arrayType != Array ? adapt(arrayType, SF) : SF, grid
function AbstractImageReconstruction.process(type::Type{<:AbstractMPIRecoAlgorithm}, params::Union{L, ProcessResultCache{L}}, sf::MPIFile, solverT, arrayType = Array) where L <: AbstractSystemMatrixLoadingParameter
freqs, sf, grid = process(type, params, sf) # Cachable process
sf, grid = prepareSF(solverT, sf, grid)
sf = adaptSF(arrayType, sf)
return freqs, sf, grid
end


# Assumption SF is a (wrapped) CPU-array
# adapt(Array, Sparse-CPU) results in dense array, so we only want to adapt if necessary
adaptSF(arrayType, SF) = adapt(arrayType, SF)
adaptSF(arrayType::Type{<:Array}, SF) = SF

prepareSF(solver::Type{Kaczmarz}, SF, grid) = transpose(SF), grid
prepareSF(solver::Type{PseudoInverse}, SF, grid) = SVD(svd(transpose(SF))...), grid
prepareSF(solver::Type{DirectSolver}, SF, grid) = RegularizedLeastSquares.tikhonovLU(copy(transpose(SF))), grid
Expand Down
2 changes: 1 addition & 1 deletion test/ReconstructionGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ for arrayType in arrayTypes
@test isapprox(arraydata(c3), arraydata(c4))
end

@testset "Multi Patch Reconstruction: $arrayType" begin
arrayType == JLArray || @testset "Multi Patch Reconstruction: $arrayType" begin
dirs = ["1.mdf", "2.mdf", "3.mdf", "4.mdf"]
b = MultiMPIFile(joinpath.(datadir, "measurements", "20211226_203916_MultiPatch", dirs))

Expand Down

0 comments on commit 9a40108

Please sign in to comment.