Skip to content

Commit

Permalink
Consolidate solve
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobAsslaender committed Dec 5, 2023
1 parent 738a544 commit 5bb0cd8
Show file tree
Hide file tree
Showing 12 changed files with 192 additions and 448 deletions.
29 changes: 0 additions & 29 deletions src/ADMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,35 +176,6 @@ function init!(solver::ADMM, b::AbstractVector{T}; x0=0) where T

end

"""
solve(solver::ADMM, b; startVector=similar(b,0), solverInfo=nothing)
solves an inverse problem using ADMM.
# Arguments
* `solver::ADMM` - the solver containing both system matrix and regularizer
* `b::AbstractVector` - data vector if `A` was supplied to the solver, back-projection of the data otherwise
# Keyword Arguments
* `startVector::AbstractVector` - initial guess for the solution
* `solverInfo::SolverInfo` - solverInfo object
when a `SolverInfo` object is passed, the residuals are stored in `solverInfo.convMeas`.
"""
function solve(solver::ADMM, b; x0=0, solverInfo=nothing)
# initialize solver parameters
init!(solver, b; x0)

# log solver information
solverInfo !== nothing && storeInfo(solverInfo,solver.x,solver.rᵏ...,solver.sᵏ...)

# perform ADMM iterations
for (iteration, item) = enumerate(solver)
solverInfo !== nothing && storeInfo(solverInfo,solver.x,solver.rᵏ...,solver.sᵏ...)
end

return solver.x
end

"""
iterate(it::ADMM, iteration::Int=0)
Expand Down
46 changes: 7 additions & 39 deletions src/CGNR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end
creates an `CGNR` object for the forward operator `A` or normal operator `AHA`.
# Required Keyword Arguments
# Required Arguments
* `A` - forward operator
OR
* `AHA` - normal operator (as a keyword argument)
Expand Down Expand Up @@ -87,21 +87,20 @@ function CGNR(A
end

"""
init!(solver::CGNR{vecT,T,Tsparse}, b::vecT; x0::vecT=similar(b,0)) where {vecT,T,Tsparse,matT}
init!(solver::CGNR, b; x0 = 0)
(re-) initializes the CGNR iterator
"""
function init!(solver::CGNR, b; x0=0)
function init!(solver::CGNR, b; x0 = 0)
solver.pl .= 0 #temporary vector
solver.vl .= 0 #temporary vector
solver.αl = 0 #temporary scalar
solver.βl = 0 #temporary scalar
solver.ζl = 0 #temporary scalar

solver.αl = 0 #temporary scalar
solver.βl = 0 #temporary scalar
solver.ζl = 0 #temporary scalar
if all(x0 .== 0)
solver.x .= 0
else
solver.A === nothing && error("providing a x0 requires solver.A to be defined")
solver.A === nothing && error("providing x0 requires solver.A to be defined")
solver.x .= x0
mul!(b, solver.A, solver.x, -1, 1)
end
Expand All @@ -126,37 +125,6 @@ function init!(solver::CGNR, b; x0=0)
end


"""
solve(solver::CGNR, b; x0=0, solverInfo=nothing)
solves an inverse problem using CGNR.
# Arguments
* `solver::CGNR` - the solver containing both system matrix and regularizer
* `b::AbstractVector` - data vector if `A` was supplied to the solver, back-projection of the data otherwise
# Keyword Arguments
* `x0::AbstractVector` - initial guess for the solution
* `solverInfo::SolverInfo` - solverInfo object
when a `SolverInfo` object is passed, the residuals are stored in `solverInfo.convMeas`.
"""
function solve(solver::CGNR, b; x0=0, solverInfo=nothing)
# initialize solver parameters
init!(solver, b; x0=x0)

# log solver information
solverInfo !== nothing && storeInfo(solverInfo, solver.x, norm(solver.x₀))

# perform CGNR iterations
for (iteration, item) = enumerate(solver)
solverInfo !== nothing && storeInfo(solverInfo, solver.x, norm(solver.x₀))
end

return solver.x
end


"""
iterate(solver::CGNR{vecT,T,Tsparse}, iteration::Int=0) where {vecT,T,Tsparse}
Expand Down
109 changes: 37 additions & 72 deletions src/DAXConstrained.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mutable struct DaxConstrained{matT,T,Tsparse,U} <: AbstractRowActionSolver
Bnorm²::Vector{Float64}
denom::Vector{U}
rowindex::Vector{Int64}
zk::Vector{T}
x::Vector{T}
bk::Vector{T}
bc::Vector{T}
xl::Vector{T}
Expand Down Expand Up @@ -51,99 +51,64 @@ function DaxConstrained(A
, iterationsInner::Int=2
)

T = typeof(real(A[1]))
T = eltype(A)
rT = real(eltype(A))
M,N = size(A)

# setup denom and rowindex
denom, rowindex = initkaczmarzconstraineddax(A,λ,weights)

# set basis transformation
sparseTrafo==nothing ? B=Matrix{T}(I, size(A,2), size(A,2)) : B=sparseTrafo
sparseTrafo === nothing ? B=Matrix{rT}(I, size(A,2), size(A,2)) : B=sparseTrafo
Bnorm² = [rownorm²(B,i) for i=1:size(B,2)]

if b !== nothing
u = b
else
u = zeros(eltype(A),M)
u = zeros(T,M)
end

zk = zeros(eltype(A),N)
bk = zeros(eltype(A),M)
bc = zeros(T,size(B,2))
xl = zeros(eltype(A),N)
yl = zeros(eltype(A),M)
yc = zeros(eltype(A),N)
δc = zeros(eltype(A),N)
εw = zeros(eltype(A),length(rowindex))
τl = zero(eltype(A))
αl = zero(eltype(A))

return DaxConstrained(A,u,Float64(λ),B,Bnorm²,denom,rowindex,zk,bk,bc,xl,yl,yc,δc,εw,τl,αl
,T.(weights),iterations,iterationsInner)
x = zeros(T,N)
bk = zeros(T,M)
bc = zeros(rT,size(B,2))
xl = zeros(T,N)
yl = zeros(T,M)
yc = zeros(T,N)
δc = zeros(T,N)
εw = zeros(T,length(rowindex))
τl = zero(T)
αl = zero(T)

return DaxConstrained(A,u,Float64(λ),B,Bnorm²,denom,rowindex,x,bk,bc,xl,yl,yc,δc,εw,τl,αl
,rT.(weights),iterations,iterationsInner)
end

function init!(solver::DaxConstrained
; A::matT=solver.A
, λ::Real=solver.λ
, u::Vector{T}=eltype(A)[]
, zk::Vector{T}=eltype(A)[]
, weights::Vector{Float64}=solver.weights) where {matT,T}
function init!(solver::DaxConstrained, b; x0 = 0)
solver.u .= b
solver.x .= x0

if A != solver.A
denom, rowindex = initkaczmarzconstraineddax(A,λ,weights)
end
solver.λ = Float64(λ)

solver.u[:] .= u
solver.weights=weights

# start vector
if isempty(zk)
solver.zk[:] .= zeros(T,size(A,2))
else
solver.zk[:] .= x
end

solver.bk[:] .= zero(T)
solver.bc[:] .= zero(T)
solver.xl[:] .= zero(T)
solver.yl[:] .= zero(T)
solver.yc[:] .= zero(T)
solver.δc[:] .= zero(T)
solver.αl = zero(T) #temporary scalar
solver.τl = zero(T) #temporary scalar
solver.bk .= 0
solver.bc .= 0
solver.xl .= 0
solver.yl .= 0
solver.yc .= 0
solver.δc .= 0
solver.αl = 0
solver.τl = 0

for i=1:length(solver.rowindex)
j = solver.rowindex[i]
solver.ɛw[i] = sqrt(solver.λ)/weights[j]
solver.ɛw[i] = sqrt(solver.λ)/solver.weights[j]
end
end

function solve(solver::DaxConstrained, u::Vector{T}; λ::Real=solver.λ
, A::matT=solver.A, startVector::Vector{T}=eltype(A)[]
, weights::Vector=solver.weights
, solverInfo=nothing, kargs...) where {T,matT}

# initialize solver parameters
init!(solver; A=A, λ=λ, u=u, zk=startVector, weights=weights)

# log solver information
solverInfo != nothing && storeInfo(solverInfo,solver.zk,norm(solver.bk))

# perform CGNR iterations
for (iteration, item) = enumerate(solver)
solverInfo != nothing && storeInfo(solverInfo,solver.zk,norm(solver.bk))
end

return solver.zk
end

function iterate(solver::DaxConstrained, iteration::Int=0)
if done(solver,iteration) return nothing end

# bk = u-A'*zk
# bk = u-A'*x
copyto!(solver.bk,solver.u)
gemv!('N',-1.0,solver.A,solver.zk,1.0,solver.bk)
gemv!('N',-1.0,solver.A,solver.x,1.0,solver.bk)

# solve min ɛ|x|²+|W*A*x-W*bk|² with weightingmatrix W=diag(wᵢ), i=1,...,M.
for l=1:solver.iterationsInner
Expand All @@ -155,9 +120,9 @@ function iterate(solver::DaxConstrained, iteration::Int=0)
solver.yl[j] += solver.αl*solver.ɛw[i]
end

#Lent-Censor scheme for ensuring B(xl+zk) >= 0
#Lent-Censor scheme for ensuring B(xl+x) >= 0
# copyto!(solver.δc,solver.xl)
# BLAS.axpy!(1.0,solver.zk,solver.δc)
# BLAS.axpy!(1.0,solver.x,solver.δc)
# lmul!(solver.B, solver.δc)
# lentcensormin!(solver.δc,solver.yc)
#
Expand All @@ -167,8 +132,8 @@ function iterate(solver::DaxConstrained, iteration::Int=0)
# BLAS.axpy!(1.0,solver.solver.δc,solver.xl) # xl += Bᵀ*δc

#Lent-Censor scheme for solving Bx >= 0
# bc = xl + zk
copyto!(solver.bc,solver.zk)
# bc = xl + x
copyto!(solver.bc,solver.x)
BLAS.axpy!(1.0,solver.xl,solver.bc)

for i=1:size(solver.B,2)
Expand All @@ -179,7 +144,7 @@ function iterate(solver::DaxConstrained, iteration::Int=0)
kaczmarz_update!(solver.B,solver.bc,i,δ) # update bc
end
end
BLAS.axpy!(1.0,solver.xl,solver.zk) # zk += xl
BLAS.axpy!(1.0,solver.xl,solver.x) # x += xl

# reset xl and yl for next Kaczmarz run
rmul!(solver.xl,0.0)
Expand Down
67 changes: 15 additions & 52 deletions src/DAXKaczmarz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mutable struct DaxKaczmarz{matT,T,U} <: AbstractRowActionSolver
denom::Vector{U}
rowindex::Vector{Int64}
sumrowweights::Vector{Float64}
zk::Vector{T}
x::Vector{T}
bk::Vector{T}
xl::Vector{T}
yl::Vector{T}
Expand Down Expand Up @@ -63,7 +63,7 @@ function DaxKaczmarz(A
else
u = zeros(eltype(A),M)
end
zk = zeros(eltype(A),N)
x = zeros(eltype(A),N)
bk = zeros(eltype(A),M)
xl = zeros(eltype(A),N)
yl = zeros(eltype(A),M)
Expand All @@ -80,73 +80,36 @@ function DaxKaczmarz(A
if !isempty(reg) && !isnothing(sparseTrafo)
reg = map(r -> TransformedRegularization(r, sparseTrafo), reg)
end
return DaxKaczmarz(A,u,reg, Float64(λ), denom,rowindex,sumrowweights,zk,bk,xl,yl,εw,τl,αl
return DaxKaczmarz(A,u,reg, Float64(λ), denom,rowindex,sumrowweights,x,bk,xl,yl,εw,τl,αl
,T.(weights) ,iterations,iterationsInner)
end

function init!(solver::DaxKaczmarz
; A::matT=solver.A
, λ::Real=solver.λ
, u::Vector{T}=eltype(A)[]
, zk::Vector{T}=eltype(A)[]
, weights::Vector{Float64}=solver.weights) where {matT,T}
function init!(solver::DaxKaczmarz, b; x0 = 0)
solver.u .= b
solver.x .= x0

if A != solver.A
solver.sumrowweights, solver.denom, solver.rowindex = initkaczmarzdax(A,solver.λ,solver.weights)
end
solver.λ = Float64(λ)

solver.u[:] .= u
solver.weights=weights

# start vector
if isempty(zk)
solver.zk[:] .= zeros(T,size(A,2))
else
solver.zk[:] .= x
end

solver.bk[:] .= zero(T)
solver.xl[:] .= zero(T)
solver.yl[:] .= zero(T)
solver.αl = zero(T) #temporary scalar
solver.τl = zero(T) #temporary scalar
solver.bk .= 0
solver.xl .= 0
solver.yl .= 0
solver.αl = 0
solver.τl = 0

for i=1:length(solver.rowindex)
j = solver.rowindex[i]
solver.ɛw[i] = sqrt(solver.λ)/weights[j]
solver.ɛw[i] = sqrt(solver.λ)/solver.weights[j]
end
end

function solve(solver::DaxKaczmarz, u::Vector{T}; λ::Real=solver.λ
, A::matT=solver.A, startVector::Vector{T}=eltype(A)[]
, weights::Vector=solver.weights
, solverInfo=nothing, kargs...) where {T,matT}

# initialize solver parameters
init!(solver; A=A, λ=λ, u=u, zk=startVector, weights=weights)

# log solver information
solverInfo != nothing && storeInfo(solverInfo,solver.zk,norm(bk))

# perform CGNR iterations
for (iteration, item) = enumerate(solver)
solverInfo != nothing && storeInfo(solverInfo,solver.zk,norm(bk))
end

return solver.zk
end

function iterate(solver::DaxKaczmarz, iteration::Int=0)
if done(solver,iteration)
for r in solver.reg
prox!(r, solver.zk)
prox!(r, solver.x)
end
return nothing
end

copyto!(solver.bk, solver.u)
gemv!('N',-1.0,solver.A,solver.zk,1.0,solver.bk)
gemv!('N',-1.0,solver.A,solver.x,1.0,solver.bk)

# solve min ɛ|x|²+|W*A*x-W*bk|² with weightingmatrix W=diag(wᵢ), i=1,...,M.
for l=1:length(solver.rowindex)*solver.iterationsInner
Expand All @@ -158,7 +121,7 @@ function iterate(solver::DaxKaczmarz, iteration::Int=0)
solver.yl[j] += solver.αl*solver.ɛw[i]
end

BLAS.axpy!(1.0,solver.xl,solver.zk) # zk += xl
BLAS.axpy!(1.0,solver.xl,solver.x) # x += xl
# reset xl and yl for next Kaczmarz run
rmul!(solver.xl,0.0)
rmul!(solver.yl,0.0)
Expand Down
Loading

0 comments on commit 5bb0cd8

Please sign in to comment.