Skip to content

Commit

Permalink
Merge branch 'master' into gpuStates
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel committed May 28, 2024
2 parents 3f2db97 + fd4de54 commit ba840ec
Show file tree
Hide file tree
Showing 15 changed files with 82 additions and 50 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RegularizedLeastSquares"
uuid = "1e9c538a-f78c-5de5-8ffb-0b6dbe892d23"
authors = ["Tobias Knopp <[email protected]>"]
version = "0.13.1-DEV"
version = "0.14.1"

[deps]
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
5 changes: 2 additions & 3 deletions src/CGNR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ mutable struct CGNRState{T, Tc, vecTc} <: AbstractSolverState{CGNR} where {T, Tc
end

"""
CGNR(A; AHA = A' * A, reg = L2Regularization(zero(real(eltype(AHA)))), normalizeReg = NoNormalization(), weights = similar(AHA, 0), iterations = 10, relTol = eps(real(eltype(AHA))))
CGNR( ; AHA = , reg = L2Regularization(zero(real(eltype(AHA)))), normalizeReg = NoNormalization(), weights = similar(AHA, 0), iterations = 10, relTol = eps(real(eltype(AHA))))
CGNR(A; AHA = A' * A, reg = L2Regularization(zero(real(eltype(AHA)))), normalizeReg = NoNormalization(), iterations = 10, relTol = eps(real(eltype(AHA))))
CGNR( ; AHA = , reg = L2Regularization(zero(real(eltype(AHA)))), normalizeReg = NoNormalization(), iterations = 10, relTol = eps(real(eltype(AHA))))
creates an `CGNR` object for the forward operator `A` or normal operator `AHA`.
Expand All @@ -38,7 +38,6 @@ creates an `CGNR` object for the forward operator `A` or normal operator `AHA`.
* `AHA` - normal operator is optional if `A` is supplied
* `reg::AbstractParameterizedRegularization` - regularization term; can also be a vector of regularization terms
* `normalizeReg::AbstractRegularizationNormalization` - regularization normalization scheme; options are `NoNormalization()`, `MeasurementBasedNormalization()`, `SystemMatrixBasedNormalization()`
* `weights::AbstactVector` - weights for the data term; must be of same length and type as the data term
* `iterations::Int` - maximum number of iterations
* `relTol::Real` - tolerance for stopping criterion
Expand Down
2 changes: 1 addition & 1 deletion src/Kaczmarz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ mutable struct KaczmarzState{T, vecT <: AbstractArray{T}} <: AbstractSolverState
end

"""
Kaczmarz(A; reg = L2Regularization(0), normalizeReg = NoNormalization(), weights=nothing, randomized=false, subMatrixFraction=0.15, shuffleRows=false, seed=1234, iterations=10, regMatrix=nothing)
Kaczmarz(A; reg = L2Regularization(0), normalizeReg = NoNormalization(), randomized=false, subMatrixFraction=0.15, shuffleRows=false, seed=1234, iterations=10)
Creates a Kaczmarz object for the forward operator `A`.
Expand Down
4 changes: 2 additions & 2 deletions src/Regularization/PlugAndPlayRegularization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct PlugAndPlayRegularization{T, M, I} <: AbstractParameterizedRegularization
end
PlugAndPlayRegularization(model, shape; kwargs...) = PlugAndPlayRegularization(one(Float32); kwargs..., model = model, shape = shape)

function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Complex{T}}
function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Complex{T}}, λ::T) where {T <: Real}
if self.ignoreIm
copyto!(x, prox!(self, real.(x), λ) + imag.(x) * one(T)im)
else
Expand All @@ -33,7 +33,7 @@ function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) whe
return x
end

function prox!(self::PlugAndPlayRegularization, x::AbstractArray{T}, λ::T) where {T}
function prox!(self::PlugAndPlayRegularization, x::AbstractArray{T}, λ::T) where {T <: Real}

if λ != self.λ &&< 0.0 || λ > 1.0)
temp = clamp(λ, zero(T), one(T))
Expand Down
6 changes: 3 additions & 3 deletions src/Regularization/Regularization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ norm(reg::AbstractParameterizedRegularization, x::AbstractArray) = norm(reg, x,
return the regularization parameter `λ` of `reg`
"""
λ(reg::AbstractParameterizedRegularization) = reg.λ
# Conversion
prox!(reg::AbstractParameterizedRegularization, x::AbstractArray{Tc}, λ) where {T, Tc<:Union{T, Complex{T}}} = prox!(reg, x, convert(T, λ))
norm(reg::AbstractParameterizedRegularization, x::AbstractArray{Tc}, λ) where {T, Tc<:Union{T, Complex{T}}} = norm(reg, x, convert(T, λ))
# Conversion (https://github.com/JuliaLang/julia/issues/52978#issuecomment-1900698430)
prox!(reg::R, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::otherT) where {R<:AbstractParameterizedRegularization, T <: Real, otherT} = prox!(reg, x, convert(T, λ))
norm(reg::R, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::otherT) where {R<:AbstractParameterizedRegularization, T <: Real, otherT} = norm(reg, x, convert(T, λ))

"""
prox!(regType::Type{<:AbstractParameterizedRegularization}, x, λ; kwargs...)
Expand Down
27 changes: 17 additions & 10 deletions src/RegularizedLeastSquares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,19 @@ See also [`isapplicable`](@ref), [`linearSolverList`](@ref).
"""
applicableSolverList(args...) = filter(solver -> isapplicable(solver, args...), linearSolverListReal())

function filterKwargs(T::Type, kwargs)
table = methods(T)
keywords = union(Base.kwarg_decl.(table)...)
filtered = filter(in(keywords), keys(kwargs))

if length(filtered) < length(kwargs)
filteredout = filter(!in(keywords), keys(kwargs))
@warn "The following arguments were passed but filtered out: $(join(filteredout, ", ")). Please watch closely if this introduces unexpexted behaviour in your code."
end

return [key=>kwargs[key] for key in filtered]
end

"""
createLinearSolver(solver::AbstractLinearSolver, A; kargs...)
Expand All @@ -253,18 +266,12 @@ regularized linear systems. All solvers return an approximate solution to Ax = b
TODO: give a hint what solvers are available
"""
function createLinearSolver(solver::Type{T}, A; kargs...) where {T<:AbstractLinearSolver}
table = methods(T)
keywords = union(Base.kwarg_decl.(table)...)
filtered = filter(in(keywords), keys(kargs))
return solver(A; [key=>kargs[key] for key in filtered]...)
function createLinearSolver(solver::Type{T}, A; kwargs...) where {T<:AbstractLinearSolver}
return solver(A; filterKwargs(T, kwargs)...)
end

function createLinearSolver(solver::Type{T}; AHA, kargs...) where {T<:AbstractLinearSolver}
table = methods(T)
keywords = union(Base.kwarg_decl.(table)...)
filtered = filter(in(keywords), keys(kargs))
return solver(; [key=>kargs[key] for key in filtered]..., AHA = AHA)
function createLinearSolver(solver::Type{T}; AHA, kwargs...) where {T<:AbstractLinearSolver}
return solver(; filterKwargs(T, kwargs)..., AHA = AHA)
end

end
4 changes: 2 additions & 2 deletions src/proximalMaps/ProxL1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
performs soft-thresholding - i.e. proximal map for the Lasso problem.
"""
function prox!(::L1Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function prox!(::L1Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
ε = eps(T)
x .= max.((abs.(x).-λ),0) .* (x.+ε)./(abs.(x).+ε)
return x
Expand All @@ -26,7 +26,7 @@ end
returns the value of the L1-regularization term.
"""
function norm(::L1Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function norm(::L1Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
l1Norm = λ*norm(x,1)
return l1Norm
end
6 changes: 3 additions & 3 deletions src/proximalMaps/ProxL2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
performs the proximal map for Tikhonov regularization.
"""
function prox!(::L2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function prox!(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
x[:] .*= 1. ./ (1. .+ 2. .*λ)#*x
return x
end
Expand All @@ -25,8 +25,8 @@ end
returns the value of the L2-regularization term
"""
norm(::L2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} = λ*norm(x,2)^2
function norm(::L2Regularization, x::AbstractArray{Tc}, λ::AbstractArray{T}) where {T, Tc <: Union{T, Complex{T}}}
norm(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} = λ*norm(x,2)^2
function norm(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::AbstractArray{T}) where {T <: Real}
res = zero(real(eltype(x)))
for i in eachindex(x)
res+= λ[i]*abs2(x[i])
Expand Down
6 changes: 3 additions & 3 deletions src/proximalMaps/ProxL21.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ L21Regularization(λ; slices::Int64 = 1, kargs...) = L21Regularization(λ, slice
performs group-soft-thresholding for l1/l2-regularization.
"""
function prox!(reg::L21Regularization, x::AbstractArray{Tc},λ::T) where {T, Tc <: Union{T, Complex{T}}}
function prox!(reg::L21Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}},λ::T) where {T <: Real}
return proxL21!(x, λ, reg.slices)
end

function proxL21!(x::AbstractArray{T}, λ::Float64, slices::Int64) where T
function proxL21!(x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T, slices::Int64) where T
sliceLength = div(length(x),slices)
groupNorm = [norm(x[i:sliceLength:end]) for i=1:sliceLength]
copyto!(x, [ x[i]*max( (groupNorm[mod1(i,sliceLength)]-λ)/groupNorm[mod1(i,sliceLength)],0 ) for i=1:length(x)])
Expand All @@ -39,7 +39,7 @@ end
return the value of the L21-regularization term.
"""
function norm(reg::L21Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function norm(reg::L21Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
sliceLength = div(length(x),reg.slices)
groupNorm = [norm(x[i:sliceLength:end]) for i=1:sliceLength]
return λ*norm(groupNorm,1)
Expand Down
5 changes: 3 additions & 2 deletions src/proximalMaps/ProxLLR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ LLRRegularization(λ; shape::NTuple{N,TI}, blockSize::NTuple{N,TI} = ntuple(_ -
performs the proximal map for LLR regularization using singular-value-thresholding
"""
function prox!(reg::LLRRegularization{TR, N, TI}, x::AbstractArray{Tc}, λ::T) where {TR, N, TI, T, Tc <: Union{T, Complex{T}}}
function prox!(reg::LLRRegularization{TR, N, TI}, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {TR, N, TI, T <: Real}
Tc = eltype(x)
shape = reg.shape
blockSize = reg.blockSize
randshift = reg.randshift
Expand Down Expand Up @@ -92,7 +93,7 @@ end
returns the value of the LLR-regularization term.
"""
function norm(reg::LLRRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function norm(reg::LLRRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
shape = reg.shape
blockSize = reg.blockSize
randshift = reg.randshift
Expand Down
4 changes: 2 additions & 2 deletions src/proximalMaps/ProxNuclear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ NuclearRegularization(λ; svtShape::NTuple=[], kargs...) = NuclearRegularization
performs singular value soft-thresholding - i.e. the proximal map for the nuclear norm regularization.
"""
function prox!(reg::NuclearRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function prox!(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
U,S,V = svd(reshape(x, reg.svtShape))
prox!(L1Regularization, S, λ)
copyto!(x, vec(U*Matrix(Diagonal(S))*V'))
Expand All @@ -35,7 +35,7 @@ end
returns the value of the nuclear norm regularization term.
"""
function norm(reg::NuclearRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function norm(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
U,S,V = svd( reshape(x, reg.svtShape) )
return λ*norm(S,1)
end
6 changes: 3 additions & 3 deletions src/proximalMaps/ProxTV.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ end
Proximal map for TV regularization. Calculated with the Condat algorithm if the TV is calculated only along one dimension and with the Fast Gradient Projection algorithm otherwise.
"""
prox!(reg::TVRegularization, x::AbstractVector{Tc}, λ::T) where {T,Tc<:Union{T,Complex{T}}} = proxTV!(reg, x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV)
prox!(reg::TVRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} = proxTV!(reg, x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV)

function proxTV!(reg, x, λ; shape, dims=1:length(shape), kwargs...) # use kwargs for shape and dims
return proxTV!(reg, x, λ, shape, dims; kwargs...) # define shape and dims w/o kwargs to enable multiple dispatch on dims
Expand Down Expand Up @@ -149,7 +149,7 @@ end
returns the value of the TV-regularization term.
"""
function norm(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T<:Real,Tc<:Union{T,Complex{T}}}
= GradientOp(Tc; shape=reg.shape, dims=reg.dims)
function norm(reg::TVRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
= GradientOp(eltype(x); shape=reg.shape, dims=reg.dims)
return λ * norm(∇ * x, 1)
end
11 changes: 7 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ using JLArrays

arrayTypes = [Array, JLArray]

include("testKaczmarz.jl")
include("testProxMaps.jl")
include("testSolvers.jl")
include("testRegularization.jl")
@testset "RegularizedLeastSquares" begin
include("testCreation.jl")
include("testKaczmarz.jl")
include("testProxMaps.jl")
include("testSolvers.jl")
include("testRegularization.jl")
end
3 changes: 3 additions & 0 deletions test/testCreation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
@testset "Creation of solvers" begin
@test_logs (:warn, Regex("The following arguments were passed but filtered out: testKwarg*")) createLinearSolver(Kaczmarz, zeros(42, 42), testKwarg=1337)
end
41 changes: 30 additions & 11 deletions test/testProxMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,16 +277,35 @@ function testLLR_3D(shape=(32,32,32,80),blockSize=(4,4,4);σ=0.05)
# @test 0.5*norm(xNoisy-x_llr)^2+normLLR(x_llr,10*σ,shape=shape[1:3],blockSize=blockSize,randshift=false) <= normLLR(xNoisy,10*σ,shape=shape[1:3],blockSize=blockSize,randshift=false)
end

function testConversion()
for (xType, lambdaType) in [(Float32, Float64), (Float64, Float32), (Complex{Float32}, Float64), (Complex{Float64}, Float32)]
for prox in [L1Regularization, L21Regularization, L2Regularization, LLRRegularization, NuclearRegularization, TVRegularization]
@info "Test λ conversion for prox!($prox, $xType, $lambdaType)"
@test try prox!(prox, zeros(xType, 10), lambdaType(0.0); shape=(2, 5), svtShape=(2, 5))
true
catch e
false
end skip = in(prox, [LLRRegularization, NuclearRegularization])
@test try norm(prox, zeros(xType, 10), lambdaType(0.0); shape=(2, 5), svtShape=(2, 5))
true
catch e
false
end skip = in(prox, [LLRRegularization, NuclearRegularization])
end
end
end

@testset "Proximal Maps" begin
testL2Prox()
testL1Prox()
testL21Prox()
testTVprox()
testDirectionalTVprox()
testPositive()
testProj()
testNuclear()
testLLR()
#testLLROverlapping()
testLLR_3D()
@testset "L2 Prox" testL2Prox()
@testset "L1 Prox" testL1Prox()
@testset "L21 Prox" testL21Prox()
@testset "TV Prox" testTVprox()
@testset "TV Prox Directional" testDirectionalTVprox()
@testset "Positive Prox" testPositive()
@testset "Projection Prox" testProj()
@testset "Nuclear Prox" testNuclear()
#@testset "LLR Prox" testLLR()
#@testset "LLR Prox Overlapping" #testLLROverlapping()
#@testset "LLR Prox 3D" testLLR_3D()
@testset "Prox Lambda Conversion" testConversion()
end

0 comments on commit ba840ec

Please sign in to comment.