diff --git a/src/Regularization/PlugAndPlayRegularization.jl b/src/Regularization/PlugAndPlayRegularization.jl index ae2858fd..2772afb2 100644 --- a/src/Regularization/PlugAndPlayRegularization.jl +++ b/src/Regularization/PlugAndPlayRegularization.jl @@ -24,7 +24,7 @@ struct PlugAndPlayRegularization{T, M, I} <: AbstractParameterizedRegularization end PlugAndPlayRegularization(model, shape; kwargs...) = PlugAndPlayRegularization(one(Float32); kwargs..., model = model, shape = shape) -function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Complex{T}} +function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Complex{T}}, λ::T) where {T <: Real} if self.ignoreIm x[:] = prox!(self, real.(x), λ) + imag.(x) * one(T)im else @@ -33,7 +33,7 @@ function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) whe return x end -function prox!(self::PlugAndPlayRegularization, x::AbstractArray{T}, λ::T) where {T} +function prox!(self::PlugAndPlayRegularization, x::AbstractArray{T}, λ::T) where {T <: Real} if λ != self.λ && (λ < 0.0 || λ > 1.0) temp = clamp(λ, zero(T), one(T)) diff --git a/src/Regularization/Regularization.jl b/src/Regularization/Regularization.jl index 760114b5..768c37a5 100644 --- a/src/Regularization/Regularization.jl +++ b/src/Regularization/Regularization.jl @@ -27,9 +27,9 @@ norm(reg::AbstractParameterizedRegularization, x::AbstractArray) = norm(reg, x, return the regularization parameter `λ` of `reg` """ λ(reg::AbstractParameterizedRegularization) = reg.λ -# Conversion -prox!(reg::AbstractParameterizedRegularization, x::AbstractArray{Tc}, λ) where {T, Tc<:Union{T, Complex{T}}} = prox!(reg, x, convert(T, λ)) -norm(reg::AbstractParameterizedRegularization, x::AbstractArray{Tc}, λ) where {T, Tc<:Union{T, Complex{T}}} = norm(reg, x, convert(T, λ)) +# Conversion (https://github.com/JuliaLang/julia/issues/52978#issuecomment-1900698430) +prox!(reg::R, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::otherT) where {R<:AbstractParameterizedRegularization, T <: Real, otherT} = prox!(reg, x, convert(T, λ)) +norm(reg::R, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::otherT) where {R<:AbstractParameterizedRegularization, T <: Real, otherT} = norm(reg, x, convert(T, λ)) """ prox!(regType::Type{<:AbstractParameterizedRegularization}, x, λ; kwargs...) diff --git a/src/proximalMaps/ProxL1.jl b/src/proximalMaps/ProxL1.jl index 53c59634..61ca38f3 100644 --- a/src/proximalMaps/ProxL1.jl +++ b/src/proximalMaps/ProxL1.jl @@ -15,8 +15,8 @@ end performs soft-thresholding - i.e. proximal map for the Lasso problem. """ -function prox!(::L1Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} - ε = eps(typeof(λ)) +function prox!(::L1Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} + ε = eps(T) x .= max.((abs.(x).-λ),0) .* (x.+ε)./(abs.(x).+ε) return x end @@ -26,7 +26,7 @@ end returns the value of the L1-regularization term. """ -function norm(::L1Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} +function norm(::L1Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} l1Norm = λ*norm(x,1) return l1Norm end diff --git a/src/proximalMaps/ProxL2.jl b/src/proximalMaps/ProxL2.jl index fd1abe08..bc8ecc7c 100644 --- a/src/proximalMaps/ProxL2.jl +++ b/src/proximalMaps/ProxL2.jl @@ -15,7 +15,7 @@ end performs the proximal map for Tikhonov regularization. """ -function prox!(::L2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} +function prox!(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} x[:] .*= 1. ./ (1. .+ 2. .*λ)#*x return x end @@ -25,8 +25,8 @@ end returns the value of the L2-regularization term """ -norm(::L2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} = λ*norm(x,2)^2 -function norm(::L2Regularization, x::AbstractArray{Tc}, λ::AbstractArray{T}) where {T, Tc <: Union{T, Complex{T}}} +norm(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} = λ*norm(x,2)^2 +function norm(::L2Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::AbstractArray{T}) where {T <: Real} res = zero(real(eltype(x))) for i in eachindex(x) res+= λ[i]*abs2(x[i]) diff --git a/src/proximalMaps/ProxL21.jl b/src/proximalMaps/ProxL21.jl index 40d2a77f..4452a9f4 100644 --- a/src/proximalMaps/ProxL21.jl +++ b/src/proximalMaps/ProxL21.jl @@ -23,11 +23,11 @@ L21Regularization(λ; slices::Int64 = 1, kargs...) = L21Regularization(λ, slice performs group-soft-thresholding for l1/l2-regularization. """ -function prox!(reg::L21Regularization, x::AbstractArray{Tc},λ::T) where {T, Tc <: Union{T, Complex{T}}} +function prox!(reg::L21Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}},λ::T) where {T <: Real} return proxL21!(x, λ, reg.slices) end -function proxL21!(x::AbstractArray{T}, λ::Float64, slices::Int64) where T +function proxL21!(x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T, slices::Int64) where T sliceLength = div(length(x),slices) groupNorm = [norm(x[i:sliceLength:end]) for i=1:sliceLength] x[:] = [ x[i]*max( (groupNorm[mod1(i,sliceLength)]-λ)/groupNorm[mod1(i,sliceLength)],0 ) for i=1:length(x)] @@ -39,7 +39,7 @@ end return the value of the L21-regularization term. """ -function norm(reg::L21Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} +function norm(reg::L21Regularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} sliceLength = div(length(x),reg.slices) groupNorm = [norm(x[i:sliceLength:end]) for i=1:sliceLength] return λ*norm(groupNorm,1) diff --git a/src/proximalMaps/ProxLLR.jl b/src/proximalMaps/ProxLLR.jl index f65815ef..232a7ac6 100644 --- a/src/proximalMaps/ProxLLR.jl +++ b/src/proximalMaps/ProxLLR.jl @@ -28,7 +28,8 @@ LLRRegularization(λ; shape::NTuple{N,TI}, blockSize::NTuple{N,TI} = ntuple(_ - performs the proximal map for LLR regularization using singular-value-thresholding """ -function prox!(reg::LLRRegularization{TR, N, TI}, x::AbstractArray{Tc}, λ::T) where {TR, N, TI, T, Tc <: Union{T, Complex{T}}} +function prox!(reg::LLRRegularization{TR, N, TI}, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {TR, N, TI, T <: Real} + Tc = eltype(x) shape = reg.shape blockSize = reg.blockSize randshift = reg.randshift @@ -92,7 +93,7 @@ end returns the value of the LLR-regularization term. """ -function norm(reg::LLRRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} +function norm(reg::LLRRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} shape = reg.shape blockSize = reg.blockSize randshift = reg.randshift diff --git a/src/proximalMaps/ProxNuclear.jl b/src/proximalMaps/ProxNuclear.jl index 0d7fc5dd..e14898dc 100644 --- a/src/proximalMaps/ProxNuclear.jl +++ b/src/proximalMaps/ProxNuclear.jl @@ -23,7 +23,7 @@ NuclearRegularization(λ; svtShape::NTuple=[], kargs...) = NuclearRegularization performs singular value soft-thresholding - i.e. the proximal map for the nuclear norm regularization. """ -function prox!(reg::NuclearRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} +function prox!(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} U,S,V = svd(reshape(x, reg.svtShape)) prox!(L1Regularization, S, λ) x[:] = vec(U*Matrix(Diagonal(S))*V') @@ -35,7 +35,7 @@ end returns the value of the nuclear norm regularization term. """ -function norm(reg::NuclearRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} +function norm(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} U,S,V = svd( reshape(x, reg.svtShape) ) return λ*norm(S,1) end diff --git a/src/proximalMaps/ProxTV.jl b/src/proximalMaps/ProxTV.jl index 48cf39f1..a17d21f0 100644 --- a/src/proximalMaps/ProxTV.jl +++ b/src/proximalMaps/ProxTV.jl @@ -61,7 +61,7 @@ end Proximal map for TV regularization. Calculated with the Condat algorithm if the TV is calculated only along one dimension and with the Fast Gradient Projection algorithm otherwise. """ -prox!(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T,Tc<:Union{T,Complex{T}}} = proxTV!(x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV) +prox!(reg::TVRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} = proxTV!(x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV) function proxTV!(x, λ; shape, dims=1:length(shape), kwargs...) # use kwargs for shape and dims return proxTV!(x, λ, shape, dims; kwargs...) # define shape and dims w/o kwargs to enable multiple dispatch on dims @@ -132,7 +132,7 @@ end returns the value of the TV-regularization term. """ -function norm(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T<:Real,Tc<:Union{T,Complex{T}}} - ∇ = GradientOp(Tc; shape=reg.shape, dims=reg.dims) +function norm(reg::TVRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} + ∇ = GradientOp(eltype(x); shape=reg.shape, dims=reg.dims) return λ * norm(∇ * x, 1) end diff --git a/test/runtests.jl b/test/runtests.jl index 8240fbd6..00d5f82d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,8 +3,10 @@ using RegularizedLeastSquares, LinearAlgebra, RegularizedLeastSquares.LinearOper using Random, Test using FFTW -include("testCreation.jl") -include("testKaczmarz.jl") -include("testProxMaps.jl") -include("testSolvers.jl") -include("testRegularization.jl") \ No newline at end of file +@testset "RegularizedLeastSquares" begin + include("testCreation.jl") + include("testKaczmarz.jl") + include("testProxMaps.jl") + include("testSolvers.jl") + include("testRegularization.jl") +end \ No newline at end of file diff --git a/test/testProxMaps.jl b/test/testProxMaps.jl index df7bdb64..788f1148 100644 --- a/test/testProxMaps.jl +++ b/test/testProxMaps.jl @@ -277,16 +277,35 @@ function testLLR_3D(shape=(32,32,32,80),blockSize=(4,4,4);σ=0.05) # @test 0.5*norm(xNoisy-x_llr)^2+normLLR(x_llr,10*σ,shape=shape[1:3],blockSize=blockSize,randshift=false) <= normLLR(xNoisy,10*σ,shape=shape[1:3],blockSize=blockSize,randshift=false) end +function testConversion() + for (xType, lambdaType) in [(Float32, Float64), (Float64, Float32), (Complex{Float32}, Float64), (Complex{Float64}, Float32)] + for prox in [L1Regularization, L21Regularization, L2Regularization, LLRRegularization, NuclearRegularization, TVRegularization] + @info "Test λ conversion for prox!($prox, $xType, $lambdaType)" + @test try prox!(prox, zeros(xType, 10), lambdaType(0.0); shape=(2, 5), svtShape=(2, 5)) + true + catch e + false + end skip = in(prox, [LLRRegularization, NuclearRegularization]) + @test try norm(prox, zeros(xType, 10), lambdaType(0.0); shape=(2, 5), svtShape=(2, 5)) + true + catch e + false + end skip = in(prox, [LLRRegularization, NuclearRegularization]) + end + end +end + @testset "Proximal Maps" begin - testL2Prox() - testL1Prox() - testL21Prox() - testTVprox() - testDirectionalTVprox() - testPositive() - testProj() - testNuclear() - testLLR() - #testLLROverlapping() - testLLR_3D() + @testset "L2 Prox" testL2Prox() + @testset "L1 Prox" testL1Prox() + @testset "L21 Prox" testL21Prox() + @testset "TV Prox" testTVprox() + @testset "TV Prox Directional" testDirectionalTVprox() + @testset "Positive Prox" testPositive() + @testset "Projection Prox" testProj() + @testset "Nuclear Prox" testNuclear() + #@testset "LLR Prox" testLLR() + #@testset "LLR Prox Overlapping" #testLLROverlapping() + #@testset "LLR Prox 3D" testLLR_3D() + @testset "Prox Lambda Conversion" testConversion() end