Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinearOperator based Weighting for Kaczmarz & CGNR #74

Merged
merged 21 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ julia = "1.9"
StatsBase = "0.33, 0.34"
FLoops = "0.2"
VectorizationBase = "0.19, 0.21"
LinearOperatorCollection = "1.0"
LinearOperatorCollection = "1.2"
LinearOperators = "2.3.3"
FFTW = "1.0"

Expand Down
19 changes: 6 additions & 13 deletions src/CGNR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
αl::T
βl::T
ζl::T
weights::vecT
iterations::Int64
relTol::Float64
z0::Float64
Expand Down Expand Up @@ -46,7 +45,6 @@
; AHA = A'*A
, reg = L2Regularization(zero(real(eltype(AHA))))
, normalizeReg::AbstractRegularizationNormalization = NoNormalization()
, weights::AbstractVector = similar(AHA, 0)
, iterations::Int = 10
, relTol::Real = eps(real(eltype(AHA)))
)
Expand Down Expand Up @@ -83,7 +81,7 @@


return CGNR(A, AHA,
L2, other, x, x₀, pl, vl, αl, βl, ζl, weights, iterations, relTol, 0.0, normalizeReg)
L2, other, x, x₀, pl, vl, αl, βl, ζl, iterations, relTol, 0.0, normalizeReg)
end

"""
Expand All @@ -106,16 +104,7 @@
end

#x₀ = Aᶜ*rl, where ᶜ denotes complex conjugation
if solver.A === nothing
!isempty(solver.weights) && @info "weights are being ignored if the backprojection is pre-computed"
solver.x₀ .= b
else
if isempty(solver.weights)
mul!(solver.x₀, adjoint(solver.A), b)
else
mul!(solver.x₀, adjoint(solver.A), b .* solver.weights)
end
end
initCGNR(solver.x₀, solver.A, b)

solver.z0 = norm(solver.x₀)
copyto!(solver.pl, solver.x₀)
Expand All @@ -124,6 +113,10 @@
solver.L2 = normalize(solver, solver.normalizeReg, solver.L2, solver.A, b)
end

# Generic start value for CGNR: writes the backprojection Aᴴ * b into x₀ (in place).
function initCGNR(x₀, A, b)
  return mul!(x₀, A', b)
end
# Start value for a weighted system given as ProdOp(WeightingOp, B):
# the data are multiplied element-wise by the weights before applying Bᴴ.
function initCGNR(x₀, prod::ProdOp{T, <:WeightingOp, matT}, b) where {T, matT}
  weighted = b .* prod.A.weights
  return mul!(x₀, adjoint(prod.B), weighted)
end

Check warning on line 117 in src/CGNR.jl

View check run for this annotation

Codecov / codecov/patch

src/CGNR.jl#L117

Added line #L117 was not covered by tests
# No system matrix available: b is copied element-wise into x₀ as the start value.
function initCGNR(x₀, ::Nothing, b)
  return x₀ .= b
end

# Convergence measure of CGNR: a named tuple holding the norm of solver.x₀ as `residual`.
function solverconvergence(solver::CGNR)
  return (; residual = norm(solver.x₀))
end

"""
Expand Down
119 changes: 69 additions & 50 deletions src/Kaczmarz.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export kaczmarz
export Kaczmarz

mutable struct Kaczmarz{matT,T,U,R,RN} <: AbstractRowActionSolver
mutable struct Kaczmarz{matT,R,T,U,RN} <: AbstractRowActionSolver
A::matT
u::Vector{T}
L2::R
Expand All @@ -11,17 +11,15 @@
rowIndexCycle::Vector{Int64}
x::Vector{T}
vl::Vector{T}
εw::Vector{T}
εw::T
τl::T
αl::T
weights::Vector{U}
randomized::Bool
subMatrixSize::Int64
probabilities::Vector{U}
shuffleRows::Bool
seed::Int64
iterations::Int64
regMatrix::Union{Nothing,Vector{U}} # Tikhonov regularization matrix
normalizeReg::AbstractRegularizationNormalization
end

Expand All @@ -36,7 +34,6 @@
# Optional Keyword Arguments
* `reg::AbstractParameterizedRegularization` - regularization term
* `normalizeReg::AbstractRegularizationNormalization` - regularization normalization scheme; options are `NoNormalization()`, `MeasurementBasedNormalization()`, `SystemMatrixBasedNormalization()`
* `weights::AbstractVector` - weights for the data term
* `randomized::Bool` - randomize Kaczmarz algorithm
* `subMatrixFraction::Real` - fraction of rows used in randomized Kaczmarz algorithm
* `shuffleRows::Bool` - shuffle the order in which the rows are processed by the Kaczmarz algorithm
Expand All @@ -46,25 +43,17 @@
See also [`createLinearSolver`](@ref), [`solve!`](@ref).
"""
function Kaczmarz(A
; reg = L2Regularization(0)
; reg = L2Regularization(zero(real(eltype(A))))
, normalizeReg::AbstractRegularizationNormalization = NoNormalization()
, weights = nothing
, randomized::Bool = false
, subMatrixFraction::Real = 0.15
, shuffleRows::Bool = false
, seed::Int = 1234
, iterations::Int = 10
, regMatrix = nothing
)

T = real(eltype(A))

# Apply Tikhonov regularization matrix
if regMatrix !== nothing
regMatrix = T.(regMatrix) # make sure regMatrix has the same element type as A
A = transpose(1 ./ sqrt.(regMatrix)) .* A # apply Tikhonov regularization to system matrix
end

# Prepare regularization terms
reg = isa(reg, AbstractVector) ? reg : [reg]
reg = normalize(Kaczmarz, normalizeReg, reg, A, nothing)
Expand All @@ -76,6 +65,11 @@
deleteat!(reg, idx)
end

# Tikhonov matrix is only valid with NoNormalization or SystemMatrixBasedNormalization
if λ(L2) isa Vector && !(normalizeReg isa NoNormalization || normalizeReg isa SystemMatrixBasedNormalization)
error("Tikhonov matrix for Kaczmarz is only valid with no or system matrix based normalization")

Check warning on line 70 in src/Kaczmarz.jl

View check run for this annotation

Codecov / codecov/patch

src/Kaczmarz.jl#L70

Added line #L70 was not covered by tests
end

indices = findsinks(AbstractProjectionRegularization, reg)
other = AbstractRegularization[reg[i] for i in indices]
deleteat!(reg, indices)
Expand All @@ -86,28 +80,27 @@
end
other = identity.(other)


# make sure weights are not empty
w = (weights!=nothing ? weights : ones(T,size(A,1)))

# setup denom and rowindex
denom, rowindex = initkaczmarz(A, λ(L2), w)
A, denom, rowindex = initkaczmarz(A, λ(L2))
rowIndexCycle = collect(1:length(rowindex))
probabilities = T.(rowProbabilities(A, rowindex))
probabilities = eltype(denom)[]
if randomized
probabilities = T.(rowProbabilities(A, rowindex))
end

M,N = size(A)
subMatrixSize = round(Int, subMatrixFraction*M)

u = zeros(eltype(A),M)
x = zeros(eltype(A),N)
vl = zeros(eltype(A),M)
εw = zeros(eltype(A),length(rowindex))
εw = zero(eltype(A))
τl = zero(eltype(A))
αl = zero(eltype(A))

return Kaczmarz(A, u, L2, other, denom, rowindex, rowIndexCycle, x, vl, εw, τl, αl,
T.(w), randomized, subMatrixSize, probabilities, shuffleRows,
Int64(seed), iterations, regMatrix,
randomized, subMatrixSize, probabilities, shuffleRows,
Int64(seed), iterations,
normalizeReg)
end

Expand All @@ -117,36 +110,46 @@
(re-) initializes the Kaczmarz iterator
"""
function init!(solver::Kaczmarz, b; x0 = 0)
  # Remember λ before normalization so that a change can be detected afterwards.
  λ_prev = λ(solver.L2)
  solver.L2  = normalize(solver, solver.normalizeReg, solver.L2, solver.A, b)
  solver.reg = normalize(solver, solver.normalizeReg, solver.reg, solver.A, b)

  λ_ = λ(solver.L2)

  # λ changed => recompute denoms
  if λ_ != λ_prev
    # A must be unchanged, since we do not store the original SM
    _, solver.denom, solver.rowindex = initkaczmarz(solver.A, λ_)
    # Use the freshly stored row indices; there is no local `rowindex` in this function.
    solver.rowIndexCycle = collect(1:length(solver.rowindex))
    if solver.randomized
      # Keep the element type of the existing probabilities field
      # (there is no type variable `T` in scope here).
      P = eltype(solver.probabilities)
      solver.probabilities = P.(rowProbabilities(solver.A, solver.rowindex))
    end
  end

  # Reseed so that shuffling / random row selection is reproducible per init! call.
  if solver.shuffleRows || solver.randomized
    Random.seed!(solver.seed)
  end
  if solver.shuffleRows
    shuffle!(solver.rowIndexCycle)
  end

  # start vector
  solver.x .= x0
  solver.vl .= 0

  solver.u .= b
  # A vector-valued λ (Tikhonov matrix) is passed on to initkaczmarz above,
  # so no scalar regularization weight is used in that case.
  if λ_ isa Vector
    solver.ɛw = 0
  else
    solver.ɛw = sqrt(λ_)
  end
end

function solversolution(solver::Kaczmarz)
# backtransformation of solution with Tikhonov matrix
if solver.regMatrix !== nothing
return solver.x .* (1 ./ sqrt.(solver.regMatrix))
end
return solver.x

function solversolution(solver::Kaczmarz{matT, RN}) where {matT, R<:L2Regularization{<:Vector}, RN <: Union{R, AbstractNestedRegularization{<:R}}}
  # backtransformation of the solution with the (vector-valued) Tikhonov weights
  tikhonov = λ(solver.L2)
  return solver.x .* (1 ./ sqrt.(tikhonov))
end
# Default case: the current iterate solver.x is returned as the solution (no copy is made).
solversolution(solver::Kaczmarz) = solver.x
# Convergence measure of Kaczmarz: a named tuple holding the norm of solver.vl as `residual`.
function solverconvergence(solver::Kaczmarz)
  return (; residual = norm(solver.vl))
end

function iterate(solver::Kaczmarz, iteration::Int=0)
Expand All @@ -159,11 +162,8 @@
end

for i in usedIndices
j = solver.rowindex[i]
solver.τl = dot_with_matrix_row(solver.A,solver.x,j)
solver.αl = solver.denom[i]*(solver.u[j]-solver.τl-solver.ɛw[i]*solver.vl[j])
kaczmarz_update!(solver.A,solver.x,j,solver.αl)
solver.vl[j] += solver.αl*solver.ɛw[i]
row = solver.rowindex[i]
iterate_row_index(solver, solver.A, row, i)
end

for r in solver.reg
Expand All @@ -173,48 +173,62 @@
return solver.vl, iteration+1
end

# Fallback that materializes the selected row via Matrix(A[row, :]) and forwards
# it to the generic row update.
# NOTE(review): the dispatch type is `AbstractLinearSolver`; a system matrix is
# presumably never a solver, so this may have been meant for a linear-operator
# type — confirm. Codecov reports this line as not covered by tests.
# NOTE(review): `row` is still passed on as the row index into the extracted
# data — verify this is intended.
iterate_row_index(solver::Kaczmarz, A::AbstractLinearSolver, row, index) = iterate_row_index(solver, Matrix(A[row, :]), row, index)

Check warning on line 176 in src/Kaczmarz.jl

View check run for this annotation

Codecov / codecov/patch

src/Kaczmarz.jl#L176

Added line #L176 was not covered by tests
function iterate_row_index(solver::Kaczmarz, A, row, index)
  # inner product of the selected row of A with the current iterate
  solver.τl = dot_with_matrix_row(A, solver.x, row)
  # step size: precomputed denominator times the regularized row residual
  mismatch = solver.u[row] - solver.τl - solver.ɛw * solver.vl[row]
  solver.αl = solver.denom[index] * mismatch
  # update the iterate with the selected row, then the residual entry of that row
  kaczmarz_update!(A, solver.x, row, solver.αl)
  solver.vl[row] += solver.αl * solver.ɛw
end

# Stopping criterion: true once `iteration` has reached solver.iterations.
@inline done(solver::Kaczmarz,iteration::Int) = iteration>=solver.iterations


"""
This function calculates the probabilities of the rows of the system matrix
"""

function rowProbabilities(A::AbstractMatrix, rowindex)
M,N = size(A)
normS = norm(A)
function rowProbabilities(A, rowindex)
normA² = rownorm²(A, 1:size(A, 1))
p = zeros(length(rowindex))
for i=1:length(rowindex)
j = rowindex[i]
p[i] = (norm(A[j,:]))^2 / (normS)^2
p[i] = rownorm²(A, j) / (normA²)
end

return p
end


### initkaczmarz ###

"""
initkaczmarz(A::AbstractMatrix,λ,weights::Vector)
initkaczmarz(A::AbstractMatrix,λ)

This function saves the denominators to compute αl in denom and the rowindices,
which lead to an update of x in rowindex.
"""
function initkaczmarz(A, λ)
  # λ must be an argument: it is used in the denominators below, and all callers
  # invoke initkaczmarz with two arguments.
  T = real(eltype(A))
  denom = T[]
  rowindex = Int64[]

  for i = 1:size(A, 1)
    s² = rownorm²(A, i)
    # only rows with a nonzero norm are kept
    if s² > 0
      push!(denom, 1/(s²+λ))
      push!(rowindex, i)
    end
  end
  # A is returned unchanged so that all methods share the (A, denom, rowindex) shape.
  return A, denom, rowindex
end
function initkaczmarz(A, λ::Vector)
  # Vector-valued λ: convert to the real element type of A, apply it to A via
  # initikhonov, and initialize the transformed system with λ = 0.
  tikhonov = real(eltype(A)).(λ)
  return initkaczmarz(initikhonov(A, tikhonov), 0)
end

# Scales column j of A by 1/sqrt(λ[j]). The result is a lazy transpose of the
# scaled transposed matrix.
function initikhonov(A, λ)
  scaling = 1 ./ sqrt.(λ)
  return transpose(scaling .* transpose(A)) # optimize structure for row access
end
# Weighted system ProdOp(WeightingOp, B): the Tikhonov scaling is applied to the
# inner operator B only; the weighting operator prod.A is reused unchanged.
initikhonov(prod::ProdOp{Tc, WeightingOp{T}, matT}, λ) where {T, Tc<:Union{T, Complex{T}}, matT} = ProdOp(prod.A, initikhonov(prod.B, λ))

Check warning on line 231 in src/Kaczmarz.jl

View check run for this annotation

Codecov / codecov/patch

src/Kaczmarz.jl#L231

Added line #L231 was not covered by tests
### kaczmarz_update! ###

"""
Expand Down Expand Up @@ -242,6 +256,11 @@
end
end

function kaczmarz_update!(prod::ProdOp{Tc, WeightingOp{T}, matT}, x::Vector, k, beta) where {T, Tc<:Union{T, Complex{T}}, matT}
  # The weight of row k is folded into the step size and the update is delegated
  # to the inner operator prod.B.
  return kaczmarz_update!(prod.B, x, k, prod.A.weights[k] * beta) # only for real weights
end

# kaczmarz_update! with manual simd optimization
for (T,W, WS,shufflevectorMask,vσ) in [(Float32,:WF32,:WF32S,:shufflevectorMaskF32,:vσF32),(Float64,:WF64,:WF64S,:shufflevectorMaskF64,:vσF64)]
eval(quote
Expand Down
6 changes: 3 additions & 3 deletions src/Regularization/NormalizedRegularization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
struct NormalizedRegularization{T, S, R} <: AbstractScaledRegularization{T, S}
reg::R
factor::T
NormalizedRegularization(reg::R, factor) where {T, R <: AbstractParameterizedRegularization{<:AbstractArray{T}}} = new{T, R, R}(reg, factor)

Check warning on line 33 in src/Regularization/NormalizedRegularization.jl

View check run for this annotation

Codecov / codecov/patch

src/Regularization/NormalizedRegularization.jl#L33

Added line #L33 was not covered by tests
NormalizedRegularization(reg::R, factor) where {T, R <: AbstractParameterizedRegularization{T}} = new{T, R, R}(reg, factor)
NormalizedRegularization(reg::R, factor) where {T, RN <: AbstractParameterizedRegularization{T}, R<:AbstractNestedRegularization{RN}} = new{T, RN, R}(reg, factor)
end
Expand All @@ -43,17 +44,16 @@

normalize(::SystemMatrixBasedNormalization, ::Nothing, _) = error("SystemMatrixBasedNormalization requires supplying A to the constructor of the solver")

function normalize(::SystemMatrixBasedNormalization, A::AbstractArray{T}, b) where {T}
function normalize(::SystemMatrixBasedNormalization, A, b)

Check warning on line 47 in src/Regularization/NormalizedRegularization.jl

View check run for this annotation

Codecov / codecov/patch

src/Regularization/NormalizedRegularization.jl#L47

Added line #L47 was not covered by tests
M = size(A, 1)
N = size(A, 2)

energy = zeros(T, M)
energy = zeros(real(eltype(A)), M)

Check warning on line 51 in src/Regularization/NormalizedRegularization.jl

View check run for this annotation

Codecov / codecov/patch

src/Regularization/NormalizedRegularization.jl#L51

Added line #L51 was not covered by tests
for m=1:M
energy[m] = sqrt(rownorm²(A,m))
end

trace = norm(energy)^2/N
# TODO where setlamda? here we dont know λ
return trace
end
normalize(::NoNormalization, A, b) = nothing
Expand Down
4 changes: 2 additions & 2 deletions src/Regularization/ScaledRegularization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Nested regularization term that applies a `scalefactor` to the regularization pa

See also [`scalefactor`](@ref), [`λ`](@ref), [`innerreg`](@ref).
"""
abstract type AbstractScaledRegularization{T, S<:AbstractParameterizedRegularization{T}} <: AbstractNestedRegularization{S} end
abstract type AbstractScaledRegularization{T, S<:AbstractParameterizedRegularization{<:Union{T, <:AbstractArray{T}}}} <: AbstractNestedRegularization{S} end
"""
scalescalefactor(reg::AbstractScaledRegularization)

Expand All @@ -20,7 +20,7 @@ return `λ` of `inner` regularization term scaled by `scalefactor(reg)`.

See also [`scalefactor`](@ref), [`innerreg`](@ref).
"""
λ(reg::AbstractScaledRegularization) = λ(innerreg(reg)) * scalefactor(reg)
λ(reg::AbstractScaledRegularization) = λ(innerreg(reg)) .* scalefactor(reg)

export FixedScaledRegularization
struct FixedScaledRegularization{T, S, R} <: AbstractScaledRegularization{T, S}
Expand Down
Loading
Loading