Skip to content

Commit

Permalink
- bugfix: complex TV along 1 dimensions was wrongly assumed equivalen…
Browse files Browse the repository at this point in the history
…t to real TV along real and imaginary part

- bugfix: iterationsTV are passed from constructor to prox function
- speedup in the iterative prox
  • Loading branch information
JakobAsslaender committed Nov 25, 2023
1 parent da8483c commit 810d0c8
Showing 1 changed file with 38 additions and 46 deletions.
84 changes: 38 additions & 46 deletions src/proximalMaps/ProxTV.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ A. Beck and T. Teboulle, "Fast Gradient-Based Algorithms for Constrained
Total Variation Image Denoising
and Deblurring Problems", IEEE Trans. Image Process. 18(11), 2009
Reference for the FGP algorithm:
A. Beck and T. Teboulle, "Fast Gradient-Based Algorithms for Constrained
Total Variation Image Denoising
and Deblurring Problems", IEEE Trans. Image Process. 18(11), 2009
# Arguments
* `λ::T` - regularization parameter
Expand All @@ -26,16 +21,16 @@ and Deblurring Problems", IEEE Trans. Image Process. 18(11), 2009
* `dims` - Dimension to perform the TV along. If `Integer`, the Condat algorithm is called, and the FDG algorithm otherwise.
* `iterationsTV=20` - number of FGP iterations
"""
struct TVRegularization{T, N, TI} <: AbstractParameterizedRegularization{T} where {N, TI<:Integer}
struct TVRegularization{T,N,TI} <: AbstractParameterizedRegularization{T} where {N,TI<:Integer}
λ::T
dims
shape::NTuple{N,TI}
iterationsTV::Int64
end
TVRegularization(λ; shape = (0,), dims = 1:length(shape), iterationsTV = 10, kargs...) = TVRegularization(λ, dims, shape, iterationsTV)
TVRegularization(λ; shape=(0,), dims=1:length(shape), iterationsTV=10, kargs...) = TVRegularization(λ, dims, shape, iterationsTV)


mutable struct TVParams{Tc, matT}
mutable struct TVParams{Tc,matT}
pq::Vector{Tc}
rs::Vector{Tc}
pqOld::Vector{Tc}
Expand All @@ -44,15 +39,15 @@ mutable struct TVParams{Tc, matT}
end

function TVParams(shape, T::Type=Float64; dims=1:length(shape))
return TVParams(Vector{T}(undef,prod(shape)); shape=shape, dims=dims)
return TVParams(Vector{T}(undef, prod(shape)); shape=shape, dims=dims)

Check warning on line 42 in src/proximalMaps/ProxTV.jl

View check run for this annotation

Codecov / codecov/patch

src/proximalMaps/ProxTV.jl#L42

Added line #L42 was not covered by tests
end

function TVParams(x::AbstractVector{Tc}; shape, dims=1:length(shape)) where {Tc}
= GradientOp(Tc; shape, dims)

# allocate storage
xTmp = similar(x)
pq = similar(x, size(∇,1))
pq = similar(x, size(∇, 1))
rs = similar(pq)
pqOld = similar(pq)

Expand All @@ -66,64 +61,59 @@ end
Proximal map for TV regularization. Calculated with the Condat algorithm if the TV is calculated only along one dimension and with the Fast Gradient Projection algorithm otherwise.
"""
prox!(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} = proxTV!(x, λ, shape = reg.shape, dims = reg.dims)
prox!(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T,Tc<:Union{T,Complex{T}}} = proxTV!(x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV)

function proxTV!(x, λ; shape, dims=1:length(shape), kwargs...) # use kwargs for shape and dims
return proxTV!(x, λ, shape, dims; kwargs...) # define shape and dims w/o kwargs to enable multiple dispatch on dims
end

function proxTV!(x::AbstractVector{T}, λ::T, shape, dims::Integer; kwargs...) where {T <: Real}
function proxTV!(x::AbstractVector{T}, λ::T, shape, dims::Integer; kwargs...) where {T<:Real}

Check warning on line 70 in src/proximalMaps/ProxTV.jl

View check run for this annotation

Codecov / codecov/patch

src/proximalMaps/ProxTV.jl#L70

Added line #L70 was not covered by tests
x_ = reshape(x, shape)
i = CartesianIndices((ones(Int, dims - 1)..., 0:shape[dims]-1, ones(Int, length(shape) - dims)...))

Check warning on line 72 in src/proximalMaps/ProxTV.jl

View check run for this annotation

Codecov / codecov/patch

src/proximalMaps/ProxTV.jl#L72

Added line #L72 was not covered by tests

if dims == 1
for j CartesianIndices(shape[dims+1:end])
@views @inbounds tv_denoise_1d_condat!(x_[:,j], shape[dims], λ)
end
elseif dims == length(shape)
for i CartesianIndices(shape[1:dims-1])
@views @inbounds tv_denoise_1d_condat!(x_[i,:], shape[dims], λ)
end
else
for j CartesianIndices(shape[dims+1:end]), i CartesianIndices(shape[1:dims-1])
@views @inbounds tv_denoise_1d_condat!(x_[i,:,j], shape[dims], λ)
end
Threads.@threads for j CartesianIndices((shape[1:dims-1]..., 1, shape[dims+1:end]...))
@views @inbounds tv_denoise_1d_condat!(x_[j.+i], shape[dims], λ)

Check warning on line 75 in src/proximalMaps/ProxTV.jl

View check run for this annotation

Codecov / codecov/patch

src/proximalMaps/ProxTV.jl#L74-L75

Added lines #L74 - L75 were not covered by tests
end
return x
end

# reinterpret complex-valued vector as 2xN matrix and change the shape etc accordingly
function proxTV!(x::AbstractVector{Tc}, λ::T, shape, dims::Integer; kwargs...) where {T <: Real, Tc <: Complex{T}}
proxTV!(vec(reinterpret(reshape, T, x)), λ, shape=(2, shape...), dims=(dims+1) )
return x
end

function proxTV!(x::AbstractVector{Tc}, λ::T, shape, dims; iterationsTV=20, tvpar=TVParams(x; shape=shape, dims=dims), kwargs...) where {T <: Real,Tc <: Union{T, Complex{T}}}
return proxTV!(x,λ,tvpar; iterationsTV=iterationsTV)
function proxTV!(x::AbstractVector{Tc}, λ::T, shape, dims; iterationsTV=10, tvpar=TVParams(x; shape=shape, dims=dims), kwargs...) where {T<:Real,Tc<:Union{T,Complex{T}}}
return proxTV!(x, λ, tvpar; iterationsTV=iterationsTV)
end

function proxTV!(x::AbstractVector{Tc}, λ::T, p::TVParams{Tc}; iterationsTV=20, kwargs...) where {T <: Real, Tc <: Union{T, Complex{T}}}
function proxTV!(x::AbstractVector{Tc}, λ::T, p::TVParams{Tc}; iterationsTV=10, kwargs...) where {T<:Real,Tc<:Union{T,Complex{T}}}
@assert length(p.xTmp) == length(x)
# initialize dual variables
p.xTmp .= 0
p.pq .= 0
p.rs .= 0
p.xTmp .= 0
p.pq .= 0
p.rs .= 0
p.pqOld .= 0

t = one(T)
for _ = 1:iterationsTV
p.pqOld .= p.pq
pqTmp = p.pqOld
p.pqOld = p.pq
p.pq = p.rs

# gradient projection step for dual variables
p.xTmp .= x
Threads.@threads for i eachindex(p.xTmp, x)
@inbounds p.xTmp[i] = x[i]
end

Check warning on line 101 in src/proximalMaps/ProxTV.jl

View check run for this annotation

Codecov / codecov/patch

src/proximalMaps/ProxTV.jl#L101

Added line #L101 was not covered by tests
mul!(p.xTmp, transpose(p.∇), p.rs, -λ, 1) # xtmp = x-λ*transpose(∇)*rs
p.pq .= p.rs
mul!(p.pq, p.∇, p.xTmp, 1/(8λ), 1) # rs = ∇*xTmp/(8λ)
mul!(p.pq, p.∇, p.xTmp, 1 / (8λ), 1) # rs = ∇*xTmp/(8λ)

restrictMagnitude!(p.pq)

# form linear combination of old and new estimates
tOld = t
t = (1 + sqrt(1+4*tOld^2)) / 2
t2 = ((tOld-1)/t)
p.rs .= (1+t2) .* p.pq .- t2 .* p.pqOld
t = (1 + sqrt(1 + 4 * tOld^2)) / 2
t2 = ((tOld - 1) / t)
t3 = 1 + t2

p.rs = pqTmp
Threads.@threads for i eachindex(p.rs, p.pq, p.pqOld)
@inbounds p.rs[i] = t3 * p.pq[i] - t2 * p.pqOld[i]
end

Check warning on line 116 in src/proximalMaps/ProxTV.jl

View check run for this annotation

Codecov / codecov/patch

src/proximalMaps/ProxTV.jl#L116

Added line #L116 was not covered by tests
end

mul!(x, transpose(p.∇), p.pq, -λ, one(Tc)) # x .-= λ*transpose(∇)*pq
Expand All @@ -132,15 +122,17 @@ end

# restrict x to a number smaller then one
function restrictMagnitude!(x)
x ./= max.(1, abs.(x))
Threads.@threads for i in eachindex(x)
@inbounds x[i] /= max(1, abs(x[i]))
end

Check warning on line 127 in src/proximalMaps/ProxTV.jl

View check run for this annotation

Codecov / codecov/patch

src/proximalMaps/ProxTV.jl#L127

Added line #L127 was not covered by tests
end

"""
norm(reg::TVRegularization, x, λ)
returns the value of the TV-regularization term.
"""
function norm(reg::TVRegularization, x::Vector{Tc}::T) where {T <: Real, Tc <: Union{T, Complex{T}}}
function norm(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T<:Real,Tc<:Union{T,Complex{T}}}
= GradientOp(Tc; shape=reg.shape, dims=reg.dims)
return λ * norm(∇*x, 1)
return λ * norm(∇ * x, 1)
end

0 comments on commit 810d0c8

Please sign in to comment.