From d34d4a40284398f0050c18e13edcc2a6b52e98a3 Mon Sep 17 00:00:00 2001 From: nHackel Date: Wed, 28 Feb 2024 18:02:17 +0100 Subject: [PATCH] Update Tikhonov matrix for Kaczmarz --- src/Kaczmarz.jl | 58 +++++++++++++++++++++++--------------------- test/testKaczmarz.jl | 4 +-- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/src/Kaczmarz.jl b/src/Kaczmarz.jl index ea2eb66d..0a46e5fb 100644 --- a/src/Kaczmarz.jl +++ b/src/Kaczmarz.jl @@ -11,7 +11,7 @@ mutable struct Kaczmarz{matT,R,T,U,RN} <: AbstractRowActionSolver rowIndexCycle::Vector{Int64} x::Vector{T} vl::Vector{T} - εw::Vector{T} + εw::T τl::T αl::T randomized::Bool @@ -20,7 +20,6 @@ mutable struct Kaczmarz{matT,R,T,U,RN} <: AbstractRowActionSolver shuffleRows::Bool seed::Int64 iterations::Int64 - regMatrix::Union{Nothing,Vector{U}} # Tikhonov regularization matrix normalizeReg::AbstractRegularizationNormalization end @@ -51,17 +50,10 @@ function Kaczmarz(A , shuffleRows::Bool = false , seed::Int = 1234 , iterations::Int = 10 - , regMatrix = nothing ) T = real(eltype(A)) - # Apply Tikhonov regularization matrix - if regMatrix !== nothing - regMatrix = T.(regMatrix) # make sure regMatrix has the same element type as A - A = transpose(1 ./ sqrt.(regMatrix)) .* A # apply Tikhonov regularization to system matrix - end - # Prepare regularization terms reg = isa(reg, AbstractVector) ? reg : [reg] reg = normalize(Kaczmarz, normalizeReg, reg, A, nothing) @@ -73,6 +65,11 @@ function Kaczmarz(A deleteat!(reg, idx) end + # Tikhonov matrix is only valid with NoNormalization or SystemMatrixBasedNormalization + if λ(L2) isa Vector && !(normalizeReg isa NoNormalization || normalizeReg isa SystemMatrixBasedNormalization) + error("Tikhonov matrix for Kaczmarz is only valid with no or system matrix based normalization") + end + indices = findsinks(AbstractProjectionRegularization, reg) other = AbstractRegularization[reg[i] for i in indices] deleteat!(reg, indices) @@ -84,7 +81,7 @@ function Kaczmarz(A other = identity.(other) # setup denom and rowindex - denom, rowindex = initkaczmarz(A, λ(L2)) + A, denom, rowindex = initkaczmarz(A, λ(L2)) rowIndexCycle = collect(1:length(rowindex)) probabilities = eltype(denom)[] if randomized @@ -97,13 +94,13 @@ function Kaczmarz(A u = zeros(eltype(A),M) x = zeros(eltype(A),N) vl = zeros(eltype(A),M) - εw = zeros(eltype(A), length(rowindex)) + εw = zero(eltype(A)) τl = zero(eltype(A)) αl = zero(eltype(A)) return Kaczmarz(A, u, L2, other, denom, rowindex, rowIndexCycle, x, vl, εw, τl, αl, randomized, subMatrixSize, probabilities, shuffleRows, - Int64(seed), iterations, regMatrix, + Int64(seed), iterations, normalizeReg) end @@ -121,7 +118,8 @@ function init!(solver::Kaczmarz, b; x0 = 0) # λ changed => recompute denoms if λ_ != λ_prev - solver.denom, solver.rowindex = initkaczmarz(solver.A, λ_) + # A must be unchanged, since we do not store the original SM + _, solver.denom, solver.rowindex = initkaczmarz(solver.A, λ_) solver.rowIndexCycle = collect(1:length(rowindex)) if solver.randomized solver.probabilities = T.(rowProbabilities(solver.A, rowindex)) @@ -140,17 +138,18 @@ function init!(solver::Kaczmarz, b; x0 = 0) solver.vl .= 0 solver.u .= b - solver.ɛw .= sqrt.(λ_) + if λ_ isa Vector + solver.ɛw = 0 + else + solver.ɛw = sqrt(λ_) + end end -function solversolution(solver::Kaczmarz) - # backtransformation of solution with Tikhonov matrix - if solver.regMatrix !== nothing - return solver.x .* (1 ./ sqrt.(solver.regMatrix)) - end - return solver.x +function solversolution(solver::Kaczmarz{matT, RN}) where {matT, R<:L2Regularization{<:Vector}, RN <: Union{R, AbstractNestedRegularization{<:R}}} + return solver.x .* (1 ./ sqrt.(λ(solver.L2))) end +solversolution(solver::Kaczmarz) = solver.x solverconvergence(solver::Kaczmarz) = (; :residual => norm(solver.vl)) function iterate(solver::Kaczmarz, iteration::Int=0) @@ -177,9 +176,9 @@ end iterate_row_index(solver::Kaczmarz, A::AbstractLinearSolver, row, index) = iterate_row_index(solver, Matrix(A[row, :]), row, index) function iterate_row_index(solver::Kaczmarz, A, row, index) solver.τl = dot_with_matrix_row(A,solver.x,row) - solver.αl = solver.denom[index]*(solver.u[row]-solver.τl-solver.ɛw[index]*solver.vl[row]) + solver.αl = solver.denom[index]*(solver.u[row]-solver.τl-solver.ɛw*solver.vl[row]) kaczmarz_update!(A,solver.x,row,solver.αl) - solver.vl[row] += solver.αl*solver.ɛw[index] + solver.vl[row] += solver.αl*solver.ɛw end @inline done(solver::Kaczmarz,iteration::Int) = iteration>=solver.iterations @@ -208,23 +207,28 @@ end This function saves the denominators to compute αl in denom and the rowindices, which lead to an update of x in rowindex. """ -initkaczmarz(A, λ::Number) = initkaczmarz(A, Iterators.repeated(λ, size(A, 2))) function initkaczmarz(A,λ) T = real(eltype(A)) denom = T[] rowindex = Int64[] - @assert length(λ) == size(A, 2) - for (i, λrow) in enumerate(λ) + for i = 1:size(A, 1) s² = rownorm²(A,i) if s²>0 - push!(denom,1/(s²+λrow)) + push!(denom,1/(s²+λ)) push!(rowindex,i) end end - denom, rowindex + return A, denom, rowindex +end +function initkaczmarz(A, λ::Vector) + λ = real(eltype(A)).(λ) + A = initikhonov(A, λ) + return initkaczmarz(A, 0) end +initikhonov(A, λ) = transpose((1 ./ sqrt.(λ)) .* transpose(A)) # optimize structure for row access +initikhonov(prod::ProdOp{Tc, WeightingOp{T}, matT}, λ) where {T, Tc<:Union{T, Complex{T}}, matT} = ProdOp(prod.A, initikhonov(prod.B, λ)) ### kaczmarz_update! ### """ diff --git a/test/testKaczmarz.jl b/test/testKaczmarz.jl index 2121d2c8..0e85c8c9 100644 --- a/test/testKaczmarz.jl +++ b/test/testKaczmarz.jl @@ -42,7 +42,7 @@ end regMatrix = rand(2) # Tikhonov matrix solver = Kaczmarz - S = createLinearSolver(solver, A, iterations=200, regMatrix=regMatrix) + S = createLinearSolver(solver, A, iterations=200, reg=[L2Regularization(regMatrix)]) x_approx = solve!(S,b) #@info "Testing solver $solver ...: $x == $x_approx" @test norm(x - x_approx) / norm(x) ≈ 0 atol=0.1 @@ -61,7 +61,7 @@ end # @show A, x, regMatrix # use regularization matrix - S = createLinearSolver(solver, A, iterations=100, regMatrix=regMatrix) + S = createLinearSolver(solver, A, iterations=100, reg=[L2Regularization(regMatrix)]) x_matrix = solve!(S,b) # use standard reconstruction