Skip to content

Commit

Permalink
Merge pull request #82 from SebastianFlassbeck/master
Browse files Browse the repository at this point in the history
Allow the use of fully overlapping blocks for LLR
  • Loading branch information
nHackel authored Jun 22, 2024
2 parents fd4de54 + e163400 commit 724b93e
Showing 1 changed file with 26 additions and 38 deletions.
64 changes: 26 additions & 38 deletions src/proximalMaps/ProxLLR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,37 @@ Regularization term implementing the proximal map for locally low rank (LLR) reg
* `λ` - regularization paramter
# Keywords
* `shape::Tuple{Int}=[]` - dimensions of the image
* `blockSize::Tuple{Int}=[2;2]` - size of patches to perform singular value thresholding on
* `randshift::Bool=true` - randomly shifts the patches to ensure translation invariance
* `shape::Tuple{Int}` - dimensions of the image
* `blockSize::Tuple{Int}=(2,2)` - size of patches to perform singular value thresholding on
* `randshift::Bool=true` - randomly shifts the patches to ensure translation invariance
* `fullyOverlapping::Bool=false` - choose between fully overlapping block or non-overlapping blocks
"""
struct LLRRegularization{T, N, TI} <: AbstractParameterizedRegularization{T} where {N, TI<:Integer}
λ::T
shape::NTuple{N,TI}
blockSize::NTuple{N,TI}
randshift::Bool
fullyOverlapping::Bool
L::Int64
end
LLRRegularization(λ; shape::NTuple{N,TI}, blockSize::NTuple{N,TI} = ntuple(_ -> 2, N), randshift::Bool = true, L::Int64 = 1, kargs...) where {N,TI<:Integer} =
LLRRegularization(λ, shape, blockSize, randshift, L)
LLRRegularization(λ; shape::NTuple{N,TI}, blockSize::NTuple{N,TI} = ntuple(_ -> 2, N), randshift::Bool = true, fullyOverlapping::Bool = false, L::Int64 = 1, kargs...) where {N,TI<:Integer} =
LLRRegularization(λ, shape, blockSize, randshift, fullyOverlapping, L)

"""
prox!(reg::LLRRegularization, x, λ)
performs the proximal map for LLR regularization using singular-value-thresholding
"""
function prox!(reg::LLRRegularization{TR, N, TI}, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {TR, N, TI, T <: Real}
Tc = eltype(x)
function prox!(reg::LLRRegularization{TR, N, TI}, x::AbstractArray{Tc}, λ::T) where {TR, N, TI, T, Tc <: Union{T, Complex{T}}}
reg.fullyOverlapping ? proxLLROverlapping!(reg, x, λ) : proxLLRNonOverlapping!(reg, x, λ)
end

"""
proxLLRNonOverlapping!(reg::LLRRegularization, x, λ)
performs the proximal map for LLR regularization using singular-value-thresholding on non-overlapping blocks
"""
function proxLLRNonOverlapping!(reg::LLRRegularization{TR, N, TI}, x::AbstractArray{Tc}, λ::T) where {TR, N, TI, T, Tc <: Union{T, Complex{T}}}
shape = reg.shape
blockSize = reg.blockSize
randshift = reg.randshift
Expand Down Expand Up @@ -91,7 +101,7 @@ end
"""
norm(reg::LLRRegularization, x, λ)
returns the value of the LLR-regularization term.
returns the value of the LLR-regularization term. The norm is only implemented for 2D, non-fully overlapping blocks.
"""
function norm(reg::LLRRegularization, x::Union{AbstractArray{T}, AbstractArray{Complex{T}}}, λ::T) where {T <: Real}
shape = reg.shape
Expand Down Expand Up @@ -154,23 +164,14 @@ end


"""
proxLLROverlapping!(x::Vector{T}, λ=1e-6; kargs...) where T
proximal map for LLR regularization with fully overlapping blocks
proxLLROverlapping!(reg::LLRRegularization, x, λ)
# Arguments
* `x::Vector{T}` - Vector to apply proximal map to
* `λ` - regularization parameter
* `shape::Tuple{Int}=[]` - dimensions of the image
* `blockSize::NTuple{Int}=ntuple(_ -> 2, N)` - size of patches to perform singular value thresholding on
performs the proximal map for LLR regularization using singular-value-thresholding with fully overlapping blocks
"""
function proxLLROverlapping!(
x::Vector{T},
λ;
shape::NTuple{N,TI},
blockSize::NTuple{N,TI} = ntuple(_ -> 2, N),
) where {T,N,TI<:Integer}

function proxLLROverlapping!(reg::LLRRegularization{TR, N, TI}, x::AbstractArray{Tc}, λ::T) where {TR, N, TI, T, Tc <: Union{T, Complex{T}}}
shape = reg.shape
blockSize = reg.blockSize

x = reshape(x, tuple(shape..., length(x) ÷ prod(shape)))

block_idx = CartesianIndices(blockSize)
Expand All @@ -179,7 +180,7 @@ function proxLLROverlapping!(
ext = mod.(shape, blockSize)
pad = mod.(blockSize .- ext, blockSize)
if any(pad .!= 0)
xp = zeros(T, (shape .+ pad)..., K)
xp = zeros(Tc, (shape .+ pad)..., K)
xp[CartesianIndices(x)] .= x
else
xp = copy(x)
Expand All @@ -190,23 +191,10 @@ function proxLLROverlapping!(
bthreads = BLAS.get_num_threads()
try
BLAS.set_num_threads(1)
xᴸᴸᴿ = [Array{T}(undef, prod(blockSize), K) for _ = 1:Threads.nthreads()]
for is block_idx
shift_idx = (Tuple(is)..., 0)
xs = circshift(xp, shift_idx)

@floop for i CartesianIndices(StepRange.(TI(0), blockSize, shape .- 1))
@views xᴸᴸᴿ[Threads.threadid()] .= reshape(xs[i.+block_idx, :], :, K)

ub = sqrt(norm(xᴸᴸᴿ[Threads.threadid()]' * xᴸᴸᴿ[Threads.threadid()], Inf)) #upper bound on singular values given by matrix infinity norm
if λ >= ub #save time by skipping the SVT as recommended by Ong/Lustig, IEEE 2016
xs[i.+block_idx, :] .= 0
else # threshold singular values
SVDec = svd!(xᴸᴸᴿ[Threads.threadid()])
prox!(L1Regularization, SVDec.S, λ)
xs[i.+block_idx, :] .= reshape(SVDec.U * Diagonal(SVDec.S) * SVDec.Vt, blockSize..., :)
end
end
proxLLRNonOverlapping!(reg, xs, λ)
x .+= circshift(xs, -1 .* shift_idx)[CartesianIndices(x)]
end
finally
Expand Down

0 comments on commit 724b93e

Please sign in to comment.