Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dispatch on proximal map to correctly convert lambda to real(eltype(x)) #84

Merged
merged 5 commits into from
May 24, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix lambda conversion dispatch for complex x
  • Loading branch information
nHackel committed May 24, 2024
commit fad7bbe63eded534c29ec6e916e9f31d12fc63eb
4 changes: 2 additions & 2 deletions src/Regularization/PlugAndPlayRegularization.jl
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ struct PlugAndPlayRegularization{T, M, I} <: AbstractParameterizedRegularization
end
PlugAndPlayRegularization(model, shape; kwargs...) = PlugAndPlayRegularization(one(Float32); kwargs..., model = model, shape = shape)

function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Complex{T}}
function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) where {T <: Real, Tc <: Complex{T}}
if self.ignoreIm
x[:] = prox!(self, real.(x), λ) + imag.(x) * one(T)im
else
@@ -33,7 +33,7 @@ function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) whe
return x
end

function prox!(self::PlugAndPlayRegularization, x::AbstractArray{T}, λ::T) where {T}
function prox!(self::PlugAndPlayRegularization, x::AbstractArray{T}, λ::T) where {T <: Real}

if λ != self.λ && (λ < 0.0 || λ > 1.0)
temp = clamp(λ, zero(T), one(T))
6 changes: 2 additions & 4 deletions src/Regularization/Regularization.jl
Original file line number Diff line number Diff line change
@@ -28,10 +28,8 @@ return the regularization parameter `λ` of `reg`
"""
λ(reg::AbstractParameterizedRegularization) = reg.λ
# Conversion (https://github.com/JuliaLang/julia/issues/52978#issuecomment-1900698430)
prox!(reg::R, x::AbstractArray{T}, λ::otherT) where {R<:AbstractParameterizedRegularization, T, otherT} = prox!(reg, x, convert(T, λ))
prox!(reg::R, x::AbstractArray{Complex{T}}, λ::otherT) where {R<:AbstractParameterizedRegularization, T, otherT} = prox!(reg, x, convert(T, λ))
norm(reg::R, x::AbstractArray{T}, λ::otherT) where {R<:AbstractParameterizedRegularization, T, otherT} = norm(reg, x, convert(T, λ))
norm(reg::R, x::AbstractArray{Complex{T}}, λ::otherT) where {R<:AbstractParameterizedRegularization, T, otherT} = norm(reg, x, convert(T, λ))
prox!(reg::R, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::otherT) where {R<:AbstractParameterizedRegularization, T <: Real, otherT} = prox!(reg, x, convert(T, λ))
norm(reg::R, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::otherT) where {R<:AbstractParameterizedRegularization, T <: Real, otherT} = norm(reg, x, convert(T, λ))

"""
prox!(regType::Type{<:AbstractParameterizedRegularization}, x, λ; kwargs...)
4 changes: 2 additions & 2 deletions src/proximalMaps/ProxL1.jl
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ end

performs soft-thresholding - i.e. proximal map for the Lasso problem.
"""
function prox!(::L1Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T}
function prox!(::L1Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
ε = eps(T)
x .= max.((abs.(x).-λ),0) .* (x.+ε)./(abs.(x).+ε)
return x
@@ -26,7 +26,7 @@ end

returns the value of the L1-regularization term.
"""
function norm(::L1Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T}
function norm(::L1Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
l1Norm = λ*norm(x,1)
return l1Norm
end
6 changes: 3 additions & 3 deletions src/proximalMaps/ProxL2.jl
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ end

performs the proximal map for Tikhonov regularization.
"""
function prox!(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T}
function prox!(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
x[:] .*= 1. ./ (1. .+ 2. .*λ)#*x
return x
end
@@ -25,8 +25,8 @@ end

returns the value of the L2-regularization term
"""
norm(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T} = λ*norm(x,2)^2
function norm(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::AbstractArray{T}) where {T}
norm(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} = λ*norm(x,2)^2
function norm(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::AbstractArray{T}) where {T <: Real}
res = zero(real(eltype(x)))
for i in eachindex(x)
res+= λ[i]*abs2(x[i])
4 changes: 2 additions & 2 deletions src/proximalMaps/ProxL21.jl
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ L21Regularization(λ; slices::Int64 = 1, kargs...) = L21Regularization(λ, slice

performs group-soft-thresholding for l1/l2-regularization.
"""
function prox!(reg::L21Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}},λ::T) where {T}
function prox!(reg::L21Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}},λ::T) where {T <: Real}
return proxL21!(x, λ, reg.slices)
end

@@ -39,7 +39,7 @@ end

return the value of the L21-regularization term.
"""
function norm(reg::L21Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T}
function norm(reg::L21Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
sliceLength = div(length(x),reg.slices)
groupNorm = [norm(x[i:sliceLength:end]) for i=1:sliceLength]
return λ*norm(groupNorm,1)
5 changes: 3 additions & 2 deletions src/proximalMaps/ProxLLR.jl
Original file line number Diff line number Diff line change
@@ -28,7 +28,8 @@ LLRRegularization(λ; shape::NTuple{N,TI}, blockSize::NTuple{N,TI} = ntuple(_ -

performs the proximal map for LLR regularization using singular-value-thresholding
"""
function prox!(reg::LLRRegularization{TR, N, TI}, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {TR, N, TI, T}
function prox!(reg::LLRRegularization{TR, N, TI}, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {TR, N, TI, T <: Real}
Tc = eltype(x)
shape = reg.shape
blockSize = reg.blockSize
randshift = reg.randshift
@@ -92,7 +93,7 @@ end

returns the value of the LLR-regularization term.
"""
function norm(reg::LLRRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T}
function norm(reg::LLRRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
shape = reg.shape
blockSize = reg.blockSize
randshift = reg.randshift
4 changes: 2 additions & 2 deletions src/proximalMaps/ProxNuclear.jl
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ NuclearRegularization(λ; svtShape::NTuple=[], kargs...) = NuclearRegularization

performs singular value soft-thresholding - i.e. the proximal map for the nuclear norm regularization.
"""
function prox!(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T}
function prox!(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
U,S,V = svd(reshape(x, reg.svtShape))
prox!(L1Regularization, S, λ)
x[:] = vec(U*Matrix(Diagonal(S))*V')
@@ -35,7 +35,7 @@ end

returns the value of the nuclear norm regularization term.
"""
function norm(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T}
function norm(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
U,S,V = svd( reshape(x, reg.svtShape) )
return λ*norm(S,1)
end
4 changes: 2 additions & 2 deletions src/proximalMaps/ProxTV.jl
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@ end

Proximal map for TV regularization. Calculated with the Condat algorithm if the TV is calculated only along one dimension and with the Fast Gradient Projection algorithm otherwise.
"""
prox!(reg::TVRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T} = proxTV!(x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV)
prox!(reg::TVRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} = proxTV!(x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV)

function proxTV!(x, λ; shape, dims=1:length(shape), kwargs...) # use kwargs for shape and dims
return proxTV!(x, λ, shape, dims; kwargs...) # define shape and dims w/o kwargs to enable multiple dispatch on dims
@@ -132,7 +132,7 @@ end

returns the value of the TV-regularization term.
"""
function norm(reg::TVRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T<:Real}
function norm(reg::TVRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
∇ = GradientOp(eltype(x); shape=reg.shape, dims=reg.dims)
return λ * norm(∇ * x, 1)
end
47 changes: 26 additions & 21 deletions test/testProxMaps.jl
Original file line number Diff line number Diff line change
@@ -278,29 +278,34 @@ function testLLR_3D(shape=(32,32,32,80),blockSize=(4,4,4);σ=0.05)
end

function testConversion()
xF32 = zeros(Float32, 10)
xF64 = zeros(Float64, 10)
# None should throw errors
for prox in [L1Regularization, L21Regularization, L2Regularization, LLRRegularization, NuclearRegularization, TVRegularization]
@info "Test λ conversion for $prox"
@test prox!(prox, xF32, Float64(0.0); shape = (2, 5), svtShape = (2,5)) isa Vector skip = in(prox, [LLRRegularization, NuclearRegularization])
@test prox!(prox, xF64, Float32(0.0); shape = (2, 5), svtShape = (2,5)) isa Vector skip = in(prox, [LLRRegularization, NuclearRegularization])
@test RegularizedLeastSquares.norm(prox, xF32, Float64(0.0); shape = (2, 5), svtShape = (2,5)) isa Number skip = in(prox, [LLRRegularization, NuclearRegularization])
@test RegularizedLeastSquares.norm(prox, xF64, Float32(0.0); shape = (2, 5), svtShape = (2,5)) isa Number skip = in(prox, [LLRRegularization, NuclearRegularization])
for (xType, lambdaType) in [(Float32, Float64), (Float64, Float32), (Complex{Float32}, Float64), (Complex{Float64}, Float32)]
for prox in [L1Regularization, L21Regularization, L2Regularization, LLRRegularization, NuclearRegularization, TVRegularization]
@info "Test λ conversion for prox!($prox, $xType, $lambdaType)"
@test try prox!(prox, zeros(xType, 10), lambdaType(0.0); shape=(2, 5), svtShape=(2, 5))
true
catch e
false
end skip = in(prox, [LLRRegularization, NuclearRegularization])
@test try norm(prox, zeros(xType, 10), lambdaType(0.0); shape=(2, 5), svtShape=(2, 5))
true
catch e
false
end skip = in(prox, [LLRRegularization, NuclearRegularization])
end
end
end

@testset "Proximal Maps" begin
testL2Prox()
testL1Prox()
testL21Prox()
testTVprox()
testDirectionalTVprox()
testPositive()
testProj()
testNuclear()
testLLR()
#testLLROverlapping()
testLLR_3D()
testConversion()
@testset "L2 Prox" testL2Prox()
@testset "L1 Prox" testL1Prox()
@testset "L21 Prox" testL21Prox()
@testset "TV Prox" testTVprox()
@testset "TV Prox Directional" testDirectionalTVprox()
@testset "Positive Prox" testPositive()
@testset "Projection Prox" testProj()
@testset "Nuclear Prox" testNuclear()
#@testset "LLR Prox" testLLR()
#@testset "LLR Prox Overlapping" #testLLROverlapping()
#@testset "LLR Prox 3D" testLLR_3D()
@testset "Prox Lambda Conversion" testConversion()
end
Loading