diff --git a/src/AbstractImageReconstruction.jl b/src/AbstractImageReconstruction.jl index e85a1ed..d27e9c4 100644 --- a/src/AbstractImageReconstruction.jl +++ b/src/AbstractImageReconstruction.jl @@ -5,7 +5,7 @@ using ThreadPools using Scratch using RegularizedLeastSquares -import Base: put!, take!, fieldtypes, fieldtype, ismissing, propertynames, parent +import Base: put!, take!, fieldtypes, fieldtype, ismissing, propertynames, parent, hash include("AlgorithmInterface.jl") include("StructTransforms.jl") diff --git a/src/RecoPlans/Cache.jl b/src/RecoPlans/Cache.jl index 1b67c50..34013c6 100644 --- a/src/RecoPlans/Cache.jl +++ b/src/RecoPlans/Cache.jl @@ -4,10 +4,26 @@ Base.@kwdef mutable struct CachedProcessParameter{P <: AbstractImageReconstructi cache::Dict{UInt64, Any} = Dict{UInt64, Any}() lock::ReentrantLock = ReentrantLock() end -function process(algo, param::CachedProcessParameter, inputs...) +function process(algo::AbstractImageReconstructionAlgorithm, param::CachedProcessParameter, inputs...) lock(param.lock) do - id = hash(param, hash(inputs)) - result = get(param.cache, id, process(algo, param.param, inputs...)) + id = hash(param.param, hash(inputs)) + if haskey(param.cache, id) + result = param.cache[id] + else + result = process(algo, param.param, inputs...) + end + param.cache[id] = result + return result + end +end +function process(algo::Type{<:AbstractImageReconstructionAlgorithm}, param::CachedProcessParameter, inputs...) + lock(param.lock) do + id = hash(param.param, hash(inputs)) + if haskey(param.cache, id) + result = param.cache[id] + else + result = process(algo, param.param, inputs...) + end param.cache[id] = result return result end @@ -38,4 +54,17 @@ function Base.empty!(cache::CachedProcessParameter) lock(cache.lock) do empty!(cache.cache) end +end + +""" + hash(parameter::AbstractImageReconstructionParameters, h) + +Default hash function for image reconstruction paramters. Uses `nameof` the parameter and all fields not starting with `_` to compute the hash. +""" +function Base.hash(parameter::T, h::UInt64) where T <: AbstractImageReconstructionParameters + h = hash(nameof(T), h) + for field in filter(f -> !startswith(string(f), "_"), fieldnames(T)) + h = hash(hash(getproperty(parameter, field)), h) + end + return h end \ No newline at end of file