Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dispatch on proximal map to correctly convert lambda to real(eltype(x)) #84

Merged
merged 5 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Regularization/PlugAndPlayRegularization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct PlugAndPlayRegularization{T, M, I} <: AbstractParameterizedRegularization
end
PlugAndPlayRegularization(model, shape; kwargs...) = PlugAndPlayRegularization(one(Float32); kwargs..., model = model, shape = shape)

function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Complex{T}}
function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Complex{T}}, λ::T) where {T <: Real}
if self.ignoreIm
x[:] = prox!(self, real.(x), λ) + imag.(x) * one(T)im
else
Expand All @@ -33,7 +33,7 @@ function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) whe
return x
end

function prox!(self::PlugAndPlayRegularization, x::AbstractArray{T}, λ::T) where {T}
function prox!(self::PlugAndPlayRegularization, x::AbstractArray{T}, λ::T) where {T <: Real}

if λ != self.λ && (λ < 0.0 || λ > 1.0)
temp = clamp(λ, zero(T), one(T))
Expand Down
6 changes: 3 additions & 3 deletions src/Regularization/Regularization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ norm(reg::AbstractParameterizedRegularization, x::AbstractArray) = norm(reg, x,
return the regularization parameter `λ` of `reg`
"""
λ(reg::AbstractParameterizedRegularization) = reg.λ
# Conversion
prox!(reg::AbstractParameterizedRegularization, x::AbstractArray{Tc}, λ) where {T, Tc<:Union{T, Complex{T}}} = prox!(reg, x, convert(T, λ))
norm(reg::AbstractParameterizedRegularization, x::AbstractArray{Tc}, λ) where {T, Tc<:Union{T, Complex{T}}} = norm(reg, x, convert(T, λ))
# Conversion (https://github.com/JuliaLang/julia/issues/52978#issuecomment-1900698430)
prox!(reg::R, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::otherT) where {R<:AbstractParameterizedRegularization, T <: Real, otherT} = prox!(reg, x, convert(T, λ))
norm(reg::R, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::otherT) where {R<:AbstractParameterizedRegularization, T <: Real, otherT} = norm(reg, x, convert(T, λ))

"""
prox!(regType::Type{<:AbstractParameterizedRegularization}, x, λ; kwargs...)
Expand Down
6 changes: 3 additions & 3 deletions src/proximalMaps/ProxL1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ end

performs soft-thresholding - i.e. proximal map for the Lasso problem.
"""
function prox!(::L1Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
ε = eps(typeof(λ))
function prox!(::L1Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
ε = eps(T)
x .= max.((abs.(x).-λ),0) .* (x.+ε)./(abs.(x).+ε)
return x
end
Expand All @@ -26,7 +26,7 @@ end

returns the value of the L1-regularization term.
"""
function norm(::L1Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function norm(::L1Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
l1Norm = λ*norm(x,1)
return l1Norm
end
6 changes: 3 additions & 3 deletions src/proximalMaps/ProxL2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

performs the proximal map for Tikhonov regularization.
"""
function prox!(::L2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function prox!(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
x[:] .*= 1. ./ (1. .+ 2. .*λ)#*x
return x
end
Expand All @@ -25,8 +25,8 @@

returns the value of the L2-regularization term
"""
norm(::L2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} = λ*norm(x,2)^2
function norm(::L2Regularization, x::AbstractArray{Tc}, λ::AbstractArray{T}) where {T, Tc <: Union{T, Complex{T}}}
norm(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} = λ*norm(x,2)^2
function norm(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::AbstractArray{T}) where {T <: Real}

Check warning on line 29 in src/proximalMaps/ProxL2.jl

View check run for this annotation

Codecov / codecov/patch

src/proximalMaps/ProxL2.jl#L29

Added line #L29 was not covered by tests
res = zero(real(eltype(x)))
for i in eachindex(x)
res+= λ[i]*abs2(x[i])
Expand Down
6 changes: 3 additions & 3 deletions src/proximalMaps/ProxL21.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ L21Regularization(λ; slices::Int64 = 1, kargs...) = L21Regularization(λ, slice

performs group-soft-thresholding for l1/l2-regularization.
"""
function prox!(reg::L21Regularization, x::AbstractArray{Tc},λ::T) where {T, Tc <: Union{T, Complex{T}}}
function prox!(reg::L21Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}},λ::T) where {T <: Real}
return proxL21!(x, λ, reg.slices)
end

function proxL21!(x::AbstractArray{T}, λ::Float64, slices::Int64) where T
function proxL21!(x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T, slices::Int64) where T
sliceLength = div(length(x),slices)
groupNorm = [norm(x[i:sliceLength:end]) for i=1:sliceLength]
x[:] = [ x[i]*max( (groupNorm[mod1(i,sliceLength)]-λ)/groupNorm[mod1(i,sliceLength)],0 ) for i=1:length(x)]
Expand All @@ -39,7 +39,7 @@ end

return the value of the L21-regularization term.
"""
function norm(reg::L21Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function norm(reg::L21Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
sliceLength = div(length(x),reg.slices)
groupNorm = [norm(x[i:sliceLength:end]) for i=1:sliceLength]
return λ*norm(groupNorm,1)
Expand Down
5 changes: 3 additions & 2 deletions src/proximalMaps/ProxLLR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

performs the proximal map for LLR regularization using singular-value-thresholding
"""
function prox!(reg::LLRRegularization{TR, N, TI}, x::AbstractArray{Tc}, λ::T) where {TR, N, TI, T, Tc <: Union{T, Complex{T}}}
function prox!(reg::LLRRegularization{TR, N, TI}, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {TR, N, TI, T <: Real}
Tc = eltype(x)

Check warning on line 32 in src/proximalMaps/ProxLLR.jl

View check run for this annotation

Codecov / codecov/patch

src/proximalMaps/ProxLLR.jl#L31-L32

Added lines #L31 - L32 were not covered by tests
shape = reg.shape
blockSize = reg.blockSize
randshift = reg.randshift
Expand Down Expand Up @@ -92,7 +93,7 @@

returns the value of the LLR-regularization term.
"""
function norm(reg::LLRRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function norm(reg::LLRRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}

Check warning on line 96 in src/proximalMaps/ProxLLR.jl

View check run for this annotation

Codecov / codecov/patch

src/proximalMaps/ProxLLR.jl#L96

Added line #L96 was not covered by tests
shape = reg.shape
blockSize = reg.blockSize
randshift = reg.randshift
Expand Down
4 changes: 2 additions & 2 deletions src/proximalMaps/ProxNuclear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ NuclearRegularization(λ; svtShape::NTuple=[], kargs...) = NuclearRegularization

performs singular value soft-thresholding - i.e. the proximal map for the nuclear norm regularization.
"""
function prox!(reg::NuclearRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function prox!(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
U,S,V = svd(reshape(x, reg.svtShape))
prox!(L1Regularization, S, λ)
x[:] = vec(U*Matrix(Diagonal(S))*V')
Expand All @@ -35,7 +35,7 @@ end

returns the value of the nuclear norm regularization term.
"""
function norm(reg::NuclearRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
function norm(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
U,S,V = svd( reshape(x, reg.svtShape) )
return λ*norm(S,1)
end
6 changes: 3 additions & 3 deletions src/proximalMaps/ProxTV.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ end

Proximal map for TV regularization. Calculated with the Condat algorithm if the TV is calculated only along one dimension and with the Fast Gradient Projection algorithm otherwise.
"""
prox!(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T,Tc<:Union{T,Complex{T}}} = proxTV!(x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV)
prox!(reg::TVRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} = proxTV!(x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV)

function proxTV!(x, λ; shape, dims=1:length(shape), kwargs...) # use kwargs for shape and dims
return proxTV!(x, λ, shape, dims; kwargs...) # define shape and dims w/o kwargs to enable multiple dispatch on dims
Expand Down Expand Up @@ -132,7 +132,7 @@ end

returns the value of the TV-regularization term.
"""
function norm(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T<:Real,Tc<:Union{T,Complex{T}}}
∇ = GradientOp(Tc; shape=reg.shape, dims=reg.dims)
function norm(reg::TVRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
∇ = GradientOp(eltype(x); shape=reg.shape, dims=reg.dims)
return λ * norm(∇ * x, 1)
end
12 changes: 7 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ using RegularizedLeastSquares, LinearAlgebra, RegularizedLeastSquares.LinearOper
using Random, Test
using FFTW

include("testCreation.jl")
include("testKaczmarz.jl")
include("testProxMaps.jl")
include("testSolvers.jl")
include("testRegularization.jl")
@testset "RegularizedLeastSquares" begin
include("testCreation.jl")
include("testKaczmarz.jl")
include("testProxMaps.jl")
include("testSolvers.jl")
include("testRegularization.jl")
end
41 changes: 30 additions & 11 deletions test/testProxMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,16 +277,35 @@ function testLLR_3D(shape=(32,32,32,80),blockSize=(4,4,4);σ=0.05)
# @test 0.5*norm(xNoisy-x_llr)^2+normLLR(x_llr,10*σ,shape=shape[1:3],blockSize=blockSize,randshift=false) <= normLLR(xNoisy,10*σ,shape=shape[1:3],blockSize=blockSize,randshift=false)
end

function testConversion()
for (xType, lambdaType) in [(Float32, Float64), (Float64, Float32), (Complex{Float32}, Float64), (Complex{Float64}, Float32)]
for prox in [L1Regularization, L21Regularization, L2Regularization, LLRRegularization, NuclearRegularization, TVRegularization]
@info "Test λ conversion for prox!($prox, $xType, $lambdaType)"
@test try prox!(prox, zeros(xType, 10), lambdaType(0.0); shape=(2, 5), svtShape=(2, 5))
true
catch e
false
end skip = in(prox, [LLRRegularization, NuclearRegularization])
@test try norm(prox, zeros(xType, 10), lambdaType(0.0); shape=(2, 5), svtShape=(2, 5))
true
catch e
false
end skip = in(prox, [LLRRegularization, NuclearRegularization])
end
end
end

@testset "Proximal Maps" begin
testL2Prox()
testL1Prox()
testL21Prox()
testTVprox()
testDirectionalTVprox()
testPositive()
testProj()
testNuclear()
testLLR()
#testLLROverlapping()
testLLR_3D()
@testset "L2 Prox" testL2Prox()
@testset "L1 Prox" testL1Prox()
@testset "L21 Prox" testL21Prox()
@testset "TV Prox" testTVprox()
@testset "TV Prox Directional" testDirectionalTVprox()
@testset "Positive Prox" testPositive()
@testset "Projection Prox" testProj()
@testset "Nuclear Prox" testNuclear()
#@testset "LLR Prox" testLLR()
#@testset "LLR Prox Overlapping" #testLLROverlapping()
#@testset "LLR Prox 3D" testLLR_3D()
@testset "Prox Lambda Conversion" testConversion()
end
Loading