Skip to content

Commit

Permalink
Merge pull request #31 from MagneticParticleImaging/nh/whitening
Browse files Browse the repository at this point in the history
Add Noise Whitening Weighting
  • Loading branch information
nHackel authored May 2, 2024
2 parents 2b910f0 + f127e24 commit d7355ed
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 15 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ ImageUtils = "0.2"
IniFile = "0.5"
LinearAlgebra = "1"
LinearOperators = "2.3.3"
LinearOperatorCollection = "1.0"
LinearOperatorCollection = "1.2"
MPIFiles = "0.13, 0.14, 0.15"
ProgressMeter = "1.2"
Reexport = "1.0"
RegularizedLeastSquares = "0.13"
RegularizedLeastSquares = "0.14"
SparseArrays = "1"
Statistics = "1"
ThreadPools = "2.1.1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function process(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchRe
return gridresult(result, algo.grid, algo.sf)
end

process(algo::SinglePatchReconstructionAlgorithm, params::ChannelWeightingParameters, u) = map(real(eltype(algo.S)), process(typeof(algo), params, algo.freqs))
process(algo::SinglePatchReconstructionAlgorithm, params::Union{ChannelWeightingParameters, WhiteningWeightingParameters}, u) = map(real(eltype(algo.S)), process(typeof(algo), params, algo.freqs))

function getLinearOperator(algo::SinglePatchReconstructionAlgorithm, params::SinglePatchReconstructionParameter{<:DenseSystemMatixLoadingParameter, S}) where {S}
return nothing
Expand Down
10 changes: 9 additions & 1 deletion src/LeastSquares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,15 @@ function process(t::Type{<:AbstractMPIRecoAlgorithm}, params::LeastSquaresParame
reg, args = prepareRegularization(params.reg, params)
args[:reg] = reg

solv = createLinearSolver(params.solver, params.S; args..., weights = params.weights)
S = params.S
if !isnothing(params.weights)
S = ProdOp(WeightingOp(params.weights), S)
for l = 1:L
u[:, l] = params.weights.*u[:, l]
end
end

solv = createLinearSolver(params.solver, S; args...)

for l=1:L
d = solve!(solv, u[:, l])
Expand Down
20 changes: 10 additions & 10 deletions src/MultiPatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ abstract type AbstractMultiPatchOperatorParameter <: AbstractMPIRecoParameters e
# MultiPatchOperator is a type that acts as the MPI system matrix but exploits
# its sparse structure.
# Its very important to keep this type typestable
mutable struct MultiPatchOperator{V<:AbstractMatrix, U<:Positions}
mutable struct MultiPatchOperator{T, V <: AbstractMatrix{T}, U<:Positions}
S::Vector{V}
grid::U
N::Int
Expand Down Expand Up @@ -294,7 +294,7 @@ function MultiPatchOperatorExpliciteMapping(SFs::MultiMPIFile, freq; bgCorrectio
# now that we have all grids we can calculate the indices within the recoGrid
xcc, xss = calculateLUT(grids, recoGrid)

return MultiPatchOperator(S, recoGrid, length(recoGrid), M*numPatches,
return MultiPatchOperator{eltype(first(S)), reduce(promote_type, typeof.(S)), typeof(recoGrid)}(S, recoGrid, length(recoGrid), M*numPatches,
RowToPatch, xcc, xss, sign, numPatches, patchToSMIdx)
end

Expand Down Expand Up @@ -411,7 +411,7 @@ function MultiPatchOperatorRegular(SFs::MultiMPIFile, freq; bgCorrection::Bool,

sign = ones(Int, M, numPatches)

return MultiPatchOperator(S, recoGrid, length(recoGrid), M*numPatches,
return MultiPatchOperator{eltype(first(S)), reduce(promote_type, typeof.(S)), typeof(recoGrid)}(S, recoGrid, length(recoGrid), M*numPatches,
RowToPatch, xcc, xss, sign, numPatches, patchToSMIdx)
end

Expand Down Expand Up @@ -509,38 +509,38 @@ function kaczmarz_update_!(A,x,beta,xs,xc,j,sign)
end
end

function initkaczmarz(Op::MultiPatchOperator,λ,weights::Vector)
T = typeof(real(Op.S[1][1]))
# TODO implement for ProdOp{WeightingOp, MultiPatchOperator}
function initkaczmarz(Op::MultiPatchOperator{T},λ) where T
denom = T[] #zeros(T,Op.M)
rowindex = Int64[] #zeros(Int64,Op.M)

MSub = div(Op.M,Op.nPatches)

if length(Op.S) == 1
for i=1:MSub
= rownorm²(Op.S[1],i)*weights[i]^2
= rownorm²(Op.S[1],i)#*weights[i]^2
if>0
for l=1:Op.nPatches
k = i+MSub*(l-1)
push!(denom,weights[i]^2/(s²+λ)) #denom[k] = weights[i]^2/(s²+λ)
push!(denom,1/(s²+λ)) #denom[k] = weights[i]^2/(s²+λ)
push!(rowindex,k) #rowindex[k] = k
end
end
end
else
for l=1:Op.nPatches
for i=1:MSub
= rownorm²(Op.S[Op.patchToSMIdx[l]],i)*weights[i]^2
= rownorm²(Op.S[Op.patchToSMIdx[l]],i)#*weights[i]^2
if>0
k = i+MSub*(l-1)
push!(denom,weights[i]^2/(s²+λ)) #denom[k] = weights[i]^2/(s²+λ)
push!(denom,1/(s²+λ)) #denom[k] = weights[i]^2/(s²+λ)
push!(rowindex,k) #rowindex[k] = k
end
end
end
end

denom, rowindex
Op, denom, rowindex
end


Expand Down
13 changes: 12 additions & 1 deletion src/Weighting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,19 @@ export ChannelWeightingParameters
Base.@kwdef struct ChannelWeightingParameters <: AbstractWeightingParameters
channelWeights::Vector{Float64} = [1.0, 1.0, 1.0]
end
process(::Type{<:AbstractMPIRecoAlgorithm}, params::ChannelWeightingParameters, data::Vector{CartesianIndex{2}}) = map(x-> params.channelWeights[x[2]], data)
process(::Type{<:AbstractMPIRecoAlgorithm}, params::ChannelWeightingParameters, freqs::Vector{CartesianIndex{2}}) = map(x-> params.channelWeights[x[2]], freqs)

export WhiteningWeightingParameters
Base.@kwdef struct WhiteningWeightingParameters <: AbstractWeightingParameters
whiteningMeas::MPIFile
tfCorrection::Bool = false
end
function process(::Type{<:AbstractMPIRecoAlgorithm}, params::WhiteningWeightingParameters, freqs::Vector{CartesianIndex{2}})
u_bg = getMeasurementsFD(params.whiteningMeas, false, frequencies=freqs, frames=measBGFrameIdx(params.whiteningMeas), bgCorrection = false, tfCorrection=false)
bg_std = std(u_bg, dims=3)
weights = minimum(abs.(vec(bg_std))) ./ abs.(vec(bg_std))
return weights
end
#=
baremodule WeightingType
None = 0
Expand Down
21 changes: 21 additions & 0 deletions test/Reconstruction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,25 @@ using MPIReco
exportImage(joinpath(imgdir, "Reconstruction7c.png"), Array(c7c[1,:,:,1,1]))
@test compareImg("Reconstruction7c.png")


setAll!(plan, :weightingParams, ChannelWeightingParameters(channelWeights = [1.0, 0.0, 1.0]))
c7d = reconstruct(build(plan), b)
setAll!(plan, :weightingParams, NoWeightingParameters())
setAll!(plan, :recChannels, 1:1)
c7e = reconstruct(build(plan), b)
setAll!(plan, :recChannels, 1:2)
@test isapprox(arraydata(c7d), arraydata(c7e))

setAll!(plan, :weightingParams, ChannelWeightingParameters(channelWeights = [0.0, 1.0, 1.0]))
c7f = reconstruct(build(plan), b)
setAll!(plan, :weightingParams, NoWeightingParameters())
setAll!(plan, :recChannels, 2:2)
c7g = reconstruct(build(plan), b)
setAll!(plan, :recChannels, 1:2)
@test isapprox(arraydata(c7f), arraydata(c7g))

setAll!(plan, :weightingParams, WhiteningWeightingParameters(whiteningMeas = bSF))
c8 = reconstruct(build(plan), b)
exportImage(joinpath(imgdir, "Reconstruction8.png"), Array(c8[1,:,:,1,1]))
@test compareImg("Reconstruction8.png")
end
Binary file added test/correct/Reconstruction8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit d7355ed

Please sign in to comment.