From 810d0c831cbc4b1f1d64b5f654e9d1b373ced4ca Mon Sep 17 00:00:00 2001 From: Jakob Asslaender Date: Sat, 25 Nov 2023 12:42:58 -0500 Subject: [PATCH] - bugfix: complex TV along 1 dimensions was wrongly assumed equivalent to real TV along real and imaginary part - bugfix: iterationsTV are passed from constructor to prox function - speedup in the iterative prox --- src/proximalMaps/ProxTV.jl | 84 +++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 46 deletions(-) diff --git a/src/proximalMaps/ProxTV.jl b/src/proximalMaps/ProxTV.jl index 8c4ebaef..fb78c274 100644 --- a/src/proximalMaps/ProxTV.jl +++ b/src/proximalMaps/ProxTV.jl @@ -13,11 +13,6 @@ A. Beck and T. Teboulle, "Fast Gradient-Based Algorithms for Constrained Total Variation Image Denoising and Deblurring Problems", IEEE Trans. Image Process. 18(11), 2009 -Reference for the FGP algorithm: -A. Beck and T. Teboulle, "Fast Gradient-Based Algorithms for Constrained -Total Variation Image Denoising -and Deblurring Problems", IEEE Trans. Image Process. 18(11), 2009 - # Arguments * `λ::T` - regularization parameter @@ -26,16 +21,16 @@ and Deblurring Problems", IEEE Trans. Image Process. 18(11), 2009 * `dims` - Dimension to perform the TV along. If `Integer`, the Condat algorithm is called, and the FDG algorithm otherwise. * `iterationsTV=20` - number of FGP iterations """ -struct TVRegularization{T, N, TI} <: AbstractParameterizedRegularization{T} where {N, TI<:Integer} +struct TVRegularization{T,N,TI} <: AbstractParameterizedRegularization{T} where {N,TI<:Integer} λ::T dims shape::NTuple{N,TI} iterationsTV::Int64 end -TVRegularization(λ; shape = (0,), dims = 1:length(shape), iterationsTV = 10, kargs...) = TVRegularization(λ, dims, shape, iterationsTV) +TVRegularization(λ; shape=(0,), dims=1:length(shape), iterationsTV=10, kargs...) = TVRegularization(λ, dims, shape, iterationsTV) -mutable struct TVParams{Tc, matT} +mutable struct TVParams{Tc,matT} pq::Vector{Tc} rs::Vector{Tc} pqOld::Vector{Tc} @@ -44,7 +39,7 @@ mutable struct TVParams{Tc, matT} end function TVParams(shape, T::Type=Float64; dims=1:length(shape)) - return TVParams(Vector{T}(undef,prod(shape)); shape=shape, dims=dims) + return TVParams(Vector{T}(undef, prod(shape)); shape=shape, dims=dims) end function TVParams(x::AbstractVector{Tc}; shape, dims=1:length(shape)) where {Tc} @@ -52,7 +47,7 @@ function TVParams(x::AbstractVector{Tc}; shape, dims=1:length(shape)) where {Tc} # allocate storage xTmp = similar(x) - pq = similar(x, size(∇,1)) + pq = similar(x, size(∇, 1)) rs = similar(pq) pqOld = similar(pq) @@ -66,64 +61,59 @@ end Proximal map for TV regularization. Calculated with the Condat algorithm if the TV is calculated only along one dimension and with the Fast Gradient Projection algorithm otherwise. """ -prox!(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} = proxTV!(x, λ, shape = reg.shape, dims = reg.dims) +prox!(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T,Tc<:Union{T,Complex{T}}} = proxTV!(x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV) function proxTV!(x, λ; shape, dims=1:length(shape), kwargs...) # use kwargs for shape and dims return proxTV!(x, λ, shape, dims; kwargs...) # define shape and dims w/o kwargs to enable multiple dispatch on dims end -function proxTV!(x::AbstractVector{T}, λ::T, shape, dims::Integer; kwargs...) where {T <: Real} +function proxTV!(x::AbstractVector{T}, λ::T, shape, dims::Integer; kwargs...) where {T<:Real} x_ = reshape(x, shape) + i = CartesianIndices((ones(Int, dims - 1)..., 0:shape[dims]-1, ones(Int, length(shape) - dims)...)) - if dims == 1 - for j ∈ CartesianIndices(shape[dims+1:end]) - @views @inbounds tv_denoise_1d_condat!(x_[:,j], shape[dims], λ) - end - elseif dims == length(shape) - for i ∈ CartesianIndices(shape[1:dims-1]) - @views @inbounds tv_denoise_1d_condat!(x_[i,:], shape[dims], λ) - end - else - for j ∈ CartesianIndices(shape[dims+1:end]), i ∈ CartesianIndices(shape[1:dims-1]) - @views @inbounds tv_denoise_1d_condat!(x_[i,:,j], shape[dims], λ) - end + Threads.@threads for j ∈ CartesianIndices((shape[1:dims-1]..., 1, shape[dims+1:end]...)) + @views @inbounds tv_denoise_1d_condat!(x_[j.+i], shape[dims], λ) end return x end -# reinterpret complex-valued vector as 2xN matrix and change the shape etc accordingly -function proxTV!(x::AbstractVector{Tc}, λ::T, shape, dims::Integer; kwargs...) where {T <: Real, Tc <: Complex{T}} - proxTV!(vec(reinterpret(reshape, T, x)), λ, shape=(2, shape...), dims=(dims+1) ) - return x -end - -function proxTV!(x::AbstractVector{Tc}, λ::T, shape, dims; iterationsTV=20, tvpar=TVParams(x; shape=shape, dims=dims), kwargs...) where {T <: Real,Tc <: Union{T, Complex{T}}} - return proxTV!(x,λ,tvpar; iterationsTV=iterationsTV) +function proxTV!(x::AbstractVector{Tc}, λ::T, shape, dims; iterationsTV=10, tvpar=TVParams(x; shape=shape, dims=dims), kwargs...) where {T<:Real,Tc<:Union{T,Complex{T}}} + return proxTV!(x, λ, tvpar; iterationsTV=iterationsTV) end -function proxTV!(x::AbstractVector{Tc}, λ::T, p::TVParams{Tc}; iterationsTV=20, kwargs...) where {T <: Real, Tc <: Union{T, Complex{T}}} +function proxTV!(x::AbstractVector{Tc}, λ::T, p::TVParams{Tc}; iterationsTV=10, kwargs...) where {T<:Real,Tc<:Union{T,Complex{T}}} + @assert length(p.xTmp) == length(x) # initialize dual variables - p.xTmp .= 0 - p.pq .= 0 - p.rs .= 0 + p.xTmp .= 0 + p.pq .= 0 + p.rs .= 0 p.pqOld .= 0 t = one(T) for _ = 1:iterationsTV - p.pqOld .= p.pq + pqTmp = p.pqOld + p.pqOld = p.pq + p.pq = p.rs # gradient projection step for dual variables - p.xTmp .= x + Threads.@threads for i ∈ eachindex(p.xTmp, x) + @inbounds p.xTmp[i] = x[i] + end mul!(p.xTmp, transpose(p.∇), p.rs, -λ, 1) # xtmp = x-λ*transpose(∇)*rs - p.pq .= p.rs - mul!(p.pq, p.∇, p.xTmp, 1/(8λ), 1) # rs = ∇*xTmp/(8λ) + mul!(p.pq, p.∇, p.xTmp, 1 / (8λ), 1) # rs = ∇*xTmp/(8λ) + restrictMagnitude!(p.pq) # form linear combination of old and new estimates tOld = t - t = (1 + sqrt(1+4*tOld^2)) / 2 - t2 = ((tOld-1)/t) - p.rs .= (1+t2) .* p.pq .- t2 .* p.pqOld + t = (1 + sqrt(1 + 4 * tOld^2)) / 2 + t2 = ((tOld - 1) / t) + t3 = 1 + t2 + + p.rs = pqTmp + Threads.@threads for i ∈ eachindex(p.rs, p.pq, p.pqOld) + @inbounds p.rs[i] = t3 * p.pq[i] - t2 * p.pqOld[i] + end end mul!(x, transpose(p.∇), p.pq, -λ, one(Tc)) # x .-= λ*transpose(∇)*pq @@ -132,7 +122,9 @@ end # restrict x to a number smaller then one function restrictMagnitude!(x) - x ./= max.(1, abs.(x)) + Threads.@threads for i in eachindex(x) + @inbounds x[i] /= max(1, abs(x[i])) + end end """ @@ -140,7 +132,7 @@ end returns the value of the TV-regularization term. """ -function norm(reg::TVRegularization, x::Vector{Tc},λ::T) where {T <: Real, Tc <: Union{T, Complex{T}}} +function norm(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T<:Real,Tc<:Union{T,Complex{T}}} ∇ = GradientOp(Tc; shape=reg.shape, dims=reg.dims) - return λ * norm(∇*x, 1) + return λ * norm(∇ * x, 1) end