Skip to content

Commit

Permalink
Merge branch 'nestedRegTerms' of github.com:tknopp/RegularizedLeastSq…
Browse files Browse the repository at this point in the history
…uares.jl into nestedRegTerms
  • Loading branch information
nHackel committed Nov 6, 2023
2 parents 5ac1667 + db13951 commit d8c9d91
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 23 deletions.
11 changes: 6 additions & 5 deletions src/ADMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,6 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2));
# TODO: The constructor is not type stable

# unify Floating types
if typeof(ρ) <: Number
ρ_vec = [real(T).(ρ)]
else
ρ_vec = real(T).(ρ)
end
absTol = real(T)(absTol)
relTol = real(T)(relTol)
tolInner = real(T)(tolInner)
Expand All @@ -102,6 +97,12 @@ function ADMM(A::matT, x::Vector{T}=zeros(eltype(A),size(A,2));
end
end
regTrafo = identity.(regTrafo)

if typeof(ρ) <: Number
ρ_vec = [real(T).(ρ) for i = 1:length(reg)]
else
ρ_vec = real(T).(ρ)
end

xᵒˡᵈ = similar(x)

Expand Down
3 changes: 2 additions & 1 deletion src/Kaczmarz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,14 @@ function Kaczmarz(S; b=nothing, reg::Vector{<:AbstractRegularization} = [L2Regul
deleteat!(reg, idx)

indices = findsinks(AbstractProjectionRegularization, reg)
other = [reg[i] for i in indices]
other = AbstractRegularization[reg[i] for i in indices]
deleteat!(reg, indices)
if length(reg) == 1
pushfirst!(other, reg[1])
elseif length(reg) > 1
error("Kaczmarz does not allow for more than one additional regularization term, found $(length(reg))")
end
other = identity.(other)


# make sure weights are not empty
Expand Down
27 changes: 18 additions & 9 deletions src/Regularization/ScaledRegularization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,31 @@ end
nested(reg::FixedScaledRegularization) = reg.reg
factor(reg::FixedScaledRegularization) = reg.factor

#=

export AutoScaledRegularization
mutable struct AutoScaledRegularization{T, R<:AbstractParameterizedRegularization{T}} <: AbstractScaledRegularization{T}
mutable struct AutoScaledRegularization{T, S, R} <: AbstractScaledRegularization{T, S}
reg::R
factor::Union{Nothing, T}
AutoScaledRegularization(reg::R) where {T, R<:AbstractParameterizedRegularization{T}} = new{T,R}(reg, nothing)
AutoScaledRegularization(reg::R) where {T, R <: AbstractParameterizedRegularization{T}} = new{T, R, R}(reg, nothing)
AutoScaledRegularization(reg::R) where {T, RN <: AbstractParameterizedRegularization{T}, R<:AbstractNestedRegularization{RN}} = new{T, RN, R}(reg, nothing)
end
initFactor!(reg::AutoScaledRegularization, x::AbstractArray) = reg.factor = maximum(abs.(x))
nested(reg::AutoScaledRegularization) = reg.reg
# A bit hacky: Factor can only be computed once x is seen, therefore hide factor in λ and silently add it in prox!/norm calls
λ(reg::AutoScaledRegularization) = λ(reg.reg)
factor(reg::AutoScaledRegularization) = isnothing(reg.factor) ? 1.0 : reg.factor
function prox!(reg::AutoScaledRegularization, x, λ)
isnothing(reg.factor) && initFactor!(reg, x)
return prox!(reg.reg, x, λ * reg.factor)
if isnothing(reg.factor)
initFactor!(reg, x)
return prox!(reg.reg, x, λ * reg.factor)
else
return prox!(reg.reg, x, λ)
end
end
function norm(reg::AutoScaledRegularization, x, λ)
isnothing(reg.factor) && initFactor!(reg, x)
return norm(reg.reg, x, λ * reg.factor)
end=#
if isnothing(reg.factor)
initFactor!(reg, x)
return norm(reg.reg, x, λ * reg.factor)
else
return norm(reg.reg, x, λ)
end
end
15 changes: 7 additions & 8 deletions src/SplitBregman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function SplitBregman(A::matT, x::vecT=zeros(eltype(A),size(A,2)), b=nothing;
, relTol::Float64=eps()
, tolInner::Float64=1.e-6
, normalizeReg::AbstractRegularizationNormalization = NoNormalization()
, kargs...) where {matT, vecT<:AbstractVector}
, kargs...) where {T, matT, vecT<:AbstractVector{T}}

reg = vec(reg)
indices = findsinks(AbstractProjectionRegularization, reg)
Expand All @@ -90,6 +90,12 @@ function SplitBregman(A::matT, x::vecT=zeros(eltype(A),size(A,2)), b=nothing;
end
regTrafo = identity.(regTrafo)

# make sure that ρ is a vector
if typeof(ρ) <: Number
ρ_vec = [real(T).(ρ) for i = 1:length(reg)]
else
ρ_vec = real(T).(ρ)
end


if b==nothing
Expand Down Expand Up @@ -125,13 +131,6 @@ function SplitBregman(A::matT, x::vecT=zeros(eltype(A),size(A,2)), b=nothing;

iter_cnt = 1

# make sure that ρ is a vector
if typeof(ρ) <: Real
ρ_vec = similar(x, real(eltype(x)), 1)
ρ_vec .= ρ
else
ρ_vec = typeof(real.(x))(ρ)
end

# normalization parameters
reg = normalize(SplitBregman, normalizeReg, vec(reg), A, nothing)
Expand Down

0 comments on commit d8c9d91

Please sign in to comment.