diff --git a/ext/RegularizedLeastSquaresGPUArraysExt/ProxL21.jl b/ext/RegularizedLeastSquaresGPUArraysExt/ProxL21.jl new file mode 100644 index 00000000..15759ced --- /dev/null +++ b/ext/RegularizedLeastSquaresGPUArraysExt/ProxL21.jl @@ -0,0 +1,11 @@ +function RegularizedLeastSquares.proxL21!(x::vecT, λ::T, slices::Int64) where {T, vecT <: Union{AbstractGPUVector{T}, AbstractGPUVector{Complex{T}}}} + sliceLength = div(length(x),slices) + groupNorm = [norm(x[i:sliceLength:end]) for i=1:sliceLength] + + gpu_call(x, λ, groupNorm, sliceLength) do ctx, x_, λ_, groupNorm_, sliceLength_ + i = @linearidx(x_) + @inbounds x_[i] = x_[i]*max( (groupNorm_[mod1(i,sliceLength_)]-λ_)/groupNorm_[mod1(i,sliceLength_)],0) + return nothing + end + return x +end \ No newline at end of file diff --git a/ext/RegularizedLeastSquaresGPUArraysExt/RegularizedLeastSquaresGPUArraysExt.jl b/ext/RegularizedLeastSquaresGPUArraysExt/RegularizedLeastSquaresGPUArraysExt.jl index 9cffd2ad..83581153 100644 --- a/ext/RegularizedLeastSquaresGPUArraysExt/RegularizedLeastSquaresGPUArraysExt.jl +++ b/ext/RegularizedLeastSquaresGPUArraysExt/RegularizedLeastSquaresGPUArraysExt.jl @@ -1,8 +1,9 @@ module RegularizedLeastSquaresGPUArraysExt -using RegularizedLeastSquares, GPUArrays +using RegularizedLeastSquares, RegularizedLeastSquares.LinearAlgebra, GPUArrays include("Utils.jl") include("ProxTV.jl") +include("ProxL21.jl") end \ No newline at end of file diff --git a/src/proximalMaps/ProxNuclear.jl b/src/proximalMaps/ProxNuclear.jl index 46ac0e6d..f8e7c14d 100644 --- a/src/proximalMaps/ProxNuclear.jl +++ b/src/proximalMaps/ProxNuclear.jl @@ -26,7 +26,7 @@ performs singular value soft-thresholding - i.e. the proximal map for the nuclea function prox!(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real} U,S,V = svd(reshape(x, reg.svtShape)) prox!(L1Regularization, S, λ) - copyto!(x, vec(U*Matrix(Diagonal(S))*V')) + copyto!(x, vec(U*Diagonal(S)*V')) return x end diff --git a/src/proximalMaps/ProxProj.jl b/src/proximalMaps/ProxProj.jl index df974382..6a6bc99b 100644 --- a/src/proximalMaps/ProxProj.jl +++ b/src/proximalMaps/ProxProj.jl @@ -5,12 +5,12 @@ struct ProjectionRegularization <: AbstractProjectionRegularization end ProjectionRegularization(; projFunc::Function=x->x, kargs...) = ProjectionRegularization(projFunc) -function prox!(reg::ProjectionRegularization, x::Vector{Tc}) where {T, Tc <: Union{T, Complex{T}}} +function prox!(reg::ProjectionRegularization, x::AbstractArray{Tc}) where {T, Tc <: Union{T, Complex{T}}} copyto!(x, reg.projFunc(x)) return x end -function norm(reg::ProjectionRegularization, x::Vector{Tc}) where {T, Tc <: Union{T, Complex{T}}} +function norm(reg::ProjectionRegularization, x::AbstractArray{Tc}) where {T, Tc <: Union{T, Complex{T}}} y = copy(x) copyto!(y, prox!(reg, y)) if y != x diff --git a/test/testProxMaps.jl b/test/testProxMaps.jl index eaf1522b..ced2f06d 100644 --- a/test/testProxMaps.jl +++ b/test/testProxMaps.jl @@ -285,12 +285,12 @@ function testConversion() true catch e false - end skip = in(prox, [LLRRegularization, NuclearRegularization]) + end skip = in(prox, [LLRRegularization]) @test try norm(prox, zeros(xType, 10), lambdaType(0.0); shape=(2, 5), svtShape=(2, 5)) true catch e false - end skip = in(prox, [LLRRegularization, NuclearRegularization]) + end skip = in(prox, [LLRRegularization]) end end end