
Commit

Merge pull request #62 from JuliaImageRecon/FixTV
Bug fixes and speedup in the TV
tknopp authored Nov 25, 2023
2 parents da8483c + 810d0c8 commit b28532e
Showing 1 changed file with 38 additions and 46 deletions.
84 changes: 38 additions & 46 deletions src/proximalMaps/ProxTV.jl
@@ -13,11 +13,6 @@ A. Beck and T. Teboulle, "Fast Gradient-Based Algorithms for Constrained
Total Variation Image Denoising
and Deblurring Problems", IEEE Trans. Image Process. 18(11), 2009
Reference for the FGP algorithm:
A. Beck and T. Teboulle, "Fast Gradient-Based Algorithms for Constrained
Total Variation Image Denoising
and Deblurring Problems", IEEE Trans. Image Process. 18(11), 2009
# Arguments
* `λ::T` - regularization parameter
@@ -26,16 +21,16 @@ and Deblurring Problems", IEEE Trans. Image Process. 18(11), 2009
* `dims` - Dimension to perform the TV along. If `Integer`, the Condat algorithm is called, and the FGP algorithm otherwise.
* `iterationsTV=20` - number of FGP iterations
"""
struct TVRegularization{T, N, TI} <: AbstractParameterizedRegularization{T} where {N, TI<:Integer}
struct TVRegularization{T,N,TI} <: AbstractParameterizedRegularization{T} where {N,TI<:Integer}
λ::T
dims
shape::NTuple{N,TI}
iterationsTV::Int64
end
TVRegularization(λ; shape = (0,), dims = 1:length(shape), iterationsTV = 10, kargs...) = TVRegularization(λ, dims, shape, iterationsTV)
TVRegularization(λ; shape=(0,), dims=1:length(shape), iterationsTV=10, kargs...) = TVRegularization(λ, dims, shape, iterationsTV)


mutable struct TVParams{Tc, matT}
mutable struct TVParams{Tc,matT}
pq::Vector{Tc}
rs::Vector{Tc}
pqOld::Vector{Tc}
@@ -44,15 +39,15 @@ mutable struct TVParams{Tc, matT}
end

function TVParams(shape, T::Type=Float64; dims=1:length(shape))
return TVParams(Vector{T}(undef,prod(shape)); shape=shape, dims=dims)
return TVParams(Vector{T}(undef, prod(shape)); shape=shape, dims=dims)
end

function TVParams(x::AbstractVector{Tc}; shape, dims=1:length(shape)) where {Tc}
∇ = GradientOp(Tc; shape, dims)

# allocate storage
xTmp = similar(x)
pq = similar(x, size(∇,1))
pq = similar(x, size(∇, 1))
rs = similar(pq)
pqOld = similar(pq)

@@ -66,64 +61,59 @@ end
Proximal map for TV regularization. Calculated with the Condat algorithm if the TV is calculated only along one dimension and with the Fast Gradient Projection algorithm otherwise.
"""
prox!(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} = proxTV!(x, λ, shape = reg.shape, dims = reg.dims)
prox!(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T,Tc<:Union{T,Complex{T}}} = proxTV!(x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV)

function proxTV!(x, λ; shape, dims=1:length(shape), kwargs...) # use kwargs for shape and dims
return proxTV!(x, λ, shape, dims; kwargs...) # define shape and dims w/o kwargs to enable multiple dispatch on dims
end

function proxTV!(x::AbstractVector{T}, λ::T, shape, dims::Integer; kwargs...) where {T <: Real}
function proxTV!(x::AbstractVector{T}, λ::T, shape, dims::Integer; kwargs...) where {T<:Real}
x_ = reshape(x, shape)
i = CartesianIndices((ones(Int, dims - 1)..., 0:shape[dims]-1, ones(Int, length(shape) - dims)...))

if dims == 1
for j ∈ CartesianIndices(shape[dims+1:end])
@views @inbounds tv_denoise_1d_condat!(x_[:,j], shape[dims], λ)
end
elseif dims == length(shape)
for i ∈ CartesianIndices(shape[1:dims-1])
@views @inbounds tv_denoise_1d_condat!(x_[i,:], shape[dims], λ)
end
else
for j ∈ CartesianIndices(shape[dims+1:end]), i ∈ CartesianIndices(shape[1:dims-1])
@views @inbounds tv_denoise_1d_condat!(x_[i,:,j], shape[dims], λ)
end
Threads.@threads for j ∈ CartesianIndices((shape[1:dims-1]..., 1, shape[dims+1:end]...))
@views @inbounds tv_denoise_1d_condat!(x_[j.+i], shape[dims], λ)
end
return x
end

# reinterpret complex-valued vector as 2xN matrix and change the shape etc accordingly
function proxTV!(x::AbstractVector{Tc}, λ::T, shape, dims::Integer; kwargs...) where {T <: Real, Tc <: Complex{T}}
proxTV!(vec(reinterpret(reshape, T, x)), λ, shape=(2, shape...), dims=(dims+1) )
return x
end

function proxTV!(x::AbstractVector{Tc}, λ::T, shape, dims; iterationsTV=20, tvpar=TVParams(x; shape=shape, dims=dims), kwargs...) where {T <: Real,Tc <: Union{T, Complex{T}}}
return proxTV!(x,λ,tvpar; iterationsTV=iterationsTV)
function proxTV!(x::AbstractVector{Tc}, λ::T, shape, dims; iterationsTV=10, tvpar=TVParams(x; shape=shape, dims=dims), kwargs...) where {T<:Real,Tc<:Union{T,Complex{T}}}
return proxTV!(x, λ, tvpar; iterationsTV=iterationsTV)
end

function proxTV!(x::AbstractVector{Tc}, λ::T, p::TVParams{Tc}; iterationsTV=20, kwargs...) where {T <: Real, Tc <: Union{T, Complex{T}}}
function proxTV!(x::AbstractVector{Tc}, λ::T, p::TVParams{Tc}; iterationsTV=10, kwargs...) where {T<:Real,Tc<:Union{T,Complex{T}}}
@assert length(p.xTmp) == length(x)
# initialize dual variables
p.xTmp .= 0
p.pq .= 0
p.rs .= 0
p.xTmp .= 0
p.pq .= 0
p.rs .= 0
p.pqOld .= 0

t = one(T)
for _ = 1:iterationsTV
p.pqOld .= p.pq
pqTmp = p.pqOld
p.pqOld = p.pq
p.pq = p.rs

# gradient projection step for dual variables
p.xTmp .= x
Threads.@threads for i ∈ eachindex(p.xTmp, x)
@inbounds p.xTmp[i] = x[i]
end
mul!(p.xTmp, transpose(p.∇), p.rs, -λ, 1) # xtmp = x-λ*transpose(∇)*rs
p.pq .= p.rs
mul!(p.pq, p.∇, p.xTmp, 1/(8λ), 1) # rs = ∇*xTmp/(8λ)
mul!(p.pq, p.∇, p.xTmp, 1 / (8λ), 1) # rs = ∇*xTmp/(8λ)

restrictMagnitude!(p.pq)

# form linear combination of old and new estimates
tOld = t
t = (1 + sqrt(1+4*tOld^2)) / 2
t2 = ((tOld-1)/t)
p.rs .= (1+t2) .* p.pq .- t2 .* p.pqOld
t = (1 + sqrt(1 + 4 * tOld^2)) / 2
t2 = ((tOld - 1) / t)
t3 = 1 + t2

p.rs = pqTmp
Threads.@threads for i ∈ eachindex(p.rs, p.pq, p.pqOld)
@inbounds p.rs[i] = t3 * p.pq[i] - t2 * p.pqOld[i]
end
end

mul!(x, transpose(p.∇), p.pq, -λ, one(Tc)) # x .-= λ*transpose(∇)*pq
@@ -132,15 +122,17 @@

# restrict x to a number smaller than one
function restrictMagnitude!(x)
x ./= max.(1, abs.(x))
Threads.@threads for i in eachindex(x)
@inbounds x[i] /= max(1, abs(x[i]))
end
end

"""
norm(reg::TVRegularization, x, λ)
returns the value of the TV-regularization term.
"""
function norm(reg::TVRegularization, x::Vector{Tc}::T) where {T <: Real, Tc <: Union{T, Complex{T}}}
function norm(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T<:Real,Tc<:Union{T,Complex{T}}}
∇ = GradientOp(Tc; shape=reg.shape, dims=reg.dims)
return λ * norm(∇*x, 1)
return λ * norm(∇ * x, 1)
end
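For orientation, here is a minimal usage sketch based on the signatures visible in the diff above. It assumes this file belongs to the RegularizedLeastSquares.jl package and that `TVRegularization` and `prox!` are exported; treat it as an illustration of the dispatch described in the docstring, not as documented API.

using RegularizedLeastSquares   # package name assumed from the repository context

shape = (64, 64)                 # 2D image size
x = randn(Float64, prod(shape))  # image flattened to a vector, as prox! expects

# dims::Integer -> slice-wise denoising with the Condat algorithm along that dimension
reg1d = TVRegularization(0.01; shape=shape, dims=1)
prox!(reg1d, x, 0.01)

# dims given as a range -> FGP iterations over the full gradient operator
regNd = TVRegularization(0.01; shape=shape, dims=1:2, iterationsTV=10)
prox!(regNd, x, 0.01)

# TV value λ*norm(∇*x, 1) via the norm method above (module-qualified in case it is not exported)
RegularizedLeastSquares.norm(regNd, x, 0.01)

For complex input, the single-dimension path goes through the `reinterpret` method shown above.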

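As a reading aid for the loop in `proxTV!(x, λ, p::TVParams; iterationsTV, ...)`: each iteration performs a projected gradient step on the dual variables (the two `mul!` calls followed by `restrictMagnitude!`) and then the FISTA-style extrapolation from the referenced Beck and Teboulle paper; the `pqTmp` swap appears to recycle the buffer previously bound to `p.pqOld` so that `p.rs` can be refilled without allocating. A standalone scalar sketch of the extrapolation weights (illustrative only, not package code):

let t = 1.0
    for k in 1:5
        tOld = t
        t = (1 + sqrt(1 + 4 * tOld^2)) / 2   # momentum sequence t_{k+1} = (1 + sqrt(1 + 4t_k^2)) / 2
        t2 = (tOld - 1) / t                  # extrapolation weight
        # the loop above applies: rs = (1 + t2) .* pq .- t2 .* pqOld
        println((k, round(t, digits=3), round(t2, digits=3)))
    end
end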
2 comments on commit b28532e

@nHackel
Member


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/95994

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR; if TagBot is installed it will also be added to the
release that TagBot creates. For example:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.11.2 -m "<description of version>" b28532eb2fa28aedda2319d3d9f22bc283755517
git push origin v0.11.2
