Skip to content

Commit

Permalink
Fix proximal maps for GPU (except LLR)
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel committed May 28, 2024
1 parent 4866734 commit c8259a7
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 6 deletions.
11 changes: 11 additions & 0 deletions ext/RegularizedLeastSquaresGPUArraysExt/ProxL21.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function RegularizedLeastSquares.proxL21!(x::vecT, λ::T, slices::Int64) where {T, vecT <: Union{AbstractGPUVector{T}, AbstractGPUVector{Complex{T}}}}
sliceLength = div(length(x),slices)
groupNorm = [norm(x[i:sliceLength:end]) for i=1:sliceLength]

gpu_call(x, λ, groupNorm, sliceLength) do ctx, x_, λ_, groupNorm_, sliceLength_
i = @linearidx(x_)
@inbounds x_[i] = x_[i]*max( (groupNorm_[mod1(i,sliceLength_)]-λ_)/groupNorm_[mod1(i,sliceLength_)],0)
return nothing
end
return x
end
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module RegularizedLeastSquaresGPUArraysExt

using RegularizedLeastSquares, GPUArrays
using RegularizedLeastSquares, RegularizedLeastSquares.LinearAlgebra, GPUArrays

include("Utils.jl")
include("ProxTV.jl")
include("ProxL21.jl")

end
2 changes: 1 addition & 1 deletion src/proximalMaps/ProxNuclear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ performs singular value soft-thresholding - i.e. the proximal map for the nuclea
function prox!(reg::NuclearRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
U,S,V = svd(reshape(x, reg.svtShape))
prox!(L1Regularization, S, λ)
copyto!(x, vec(U*Matrix(Diagonal(S))*V'))
copyto!(x, vec(U*Diagonal(S)*V'))
return x
end

Expand Down
4 changes: 2 additions & 2 deletions src/proximalMaps/ProxProj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ struct ProjectionRegularization <: AbstractProjectionRegularization
end
ProjectionRegularization(; projFunc::Function=x->x, kargs...) = ProjectionRegularization(projFunc)

function prox!(reg::ProjectionRegularization, x::Vector{Tc}) where {T, Tc <: Union{T, Complex{T}}}
function prox!(reg::ProjectionRegularization, x::AbstractArray{Tc}) where {T, Tc <: Union{T, Complex{T}}}
copyto!(x, reg.projFunc(x))
return x
end

function norm(reg::ProjectionRegularization, x::Vector{Tc}) where {T, Tc <: Union{T, Complex{T}}}
function norm(reg::ProjectionRegularization, x::AbstractArray{Tc}) where {T, Tc <: Union{T, Complex{T}}}
y = copy(x)
copyto!(y, prox!(reg, y))
if y != x
Expand Down
4 changes: 2 additions & 2 deletions test/testProxMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,12 @@ function testConversion()
true
catch e
false
end skip = in(prox, [LLRRegularization, NuclearRegularization])
end skip = in(prox, [LLRRegularization])
@test try norm(prox, zeros(xType, 10), lambdaType(0.0); shape=(2, 5), svtShape=(2, 5))
true
catch e
false
end skip = in(prox, [LLRRegularization, NuclearRegularization])
end skip = in(prox, [LLRRegularization])
end
end
end
Expand Down

0 comments on commit c8259a7

Please sign in to comment.