Skip to content

Commit

Permalink
clean up FISTA
Browse files Browse the repository at this point in the history
  • Loading branch information
migrosser committed Aug 15, 2019
1 parent 1d81120 commit 6c3ab0b
Showing 1 changed file with 47 additions and 64 deletions.
111 changes: 47 additions & 64 deletions src/FISTA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,31 +54,14 @@ Solve the problem: X = arg min_x 1/2*|| Ax-b||² + λ*g(X) where:
* `A` - system matrix
* `b::Vector{T}` - data vector (right-hand side)
* `reg::Regularization` - regularization object
* (`AHA=nothing`) - normal operator adjoint(A)*A
* (`startVector=nothing`) - start vector
* (`iterations::Int64=50`) - maximum number of iterations
* (`ρ::Float64=1.0`) - step size for gradient step
* (`t::Float64=1.0`) - step size for predictor-corrector step
* (`relTol::Float64=1.e-4`) - relative tolerance for stopping criterion
* (`solverInfo = nothing`) - `solverInfo` object used to store convergence metrics
"""
function fista(A,b::Vector{T}, reg::Regularization; AHA=nothing, kargs...) where T
if AHA==nothing
return fista1(A,b,reg;kargs...)
else
return fista2(A,b,reg;kargs...)
end
end

"""
This funtion implements the fista algorithm.
Solve the problem: X = arg min_x 1/2*|| Ax-b||² + λ*g(X) where:
x: variable (vector)
b: measured data
A: a general linear operator
g(X): a convex but not necessarily a smooth function
"""
function fista1(A, b::Vector{T}, reg::Regularization
function fista(A, b::Vector{T}, reg::Regularization
; startVector=nothing
, iterations::Int64=50
, ρ::Float64=1.0
Expand Down Expand Up @@ -128,49 +111,49 @@ end

# alternative implementation allowing for an optimized AHA
# does not contain a stopping condition
function fista2(A, b::Vector{T}, reg::Regularization
; AHA=nothing
, startVector=nothing
, iterations::Int64=50
, ρ::Float64=1.0
, t::Float64=1.0
, solverInfo = nothing
, kargs...) where T

if startVector == nothing
x = A' * b
else
x = startVector
end

# if AHA!=nothing
op = AHA
# else
# op = A'*A
# end

β = A'*b

xᵒˡᵈ = copy(x)

solverInfo != nothing && storeInfo(solverInfo,A,b,x;xᵒˡᵈ=xᵒˡᵈ,reg=[reg])

costFunc = 0.5*norm(res)^2+norm(reg,x)

for l=1:iterations
xᵒˡᵈ[:] = x[:]

x[:] = x[:] - ρ*(op*x-β)

reg.prox!(x, ρ*reg.λ)

tᵒˡᵈ = t

t = (1. + sqrt(1. + 4. * tᵒˡᵈ^2)) / 2.
x[:] = x + (tᵒˡᵈ-1)/t*(x-xᵒˡᵈ)

solverInfo != nothing && storeInfo(solverInfo,A,b,x;xᵒˡᵈ=xᵒˡᵈ,reg=[reg])
end

return x
end
# function fista2(A, b::Vector{T}, reg::Regularization
# ; AHA=nothing
# , startVector=nothing
# , iterations::Int64=50
# , ρ::Float64=1.0
# , t::Float64=1.0
# , solverInfo = nothing
# , kargs...) where T
#
# if startVector == nothing
# x = A' * b
# else
# x = startVector
# end
#
# # if AHA!=nothing
# op = AHA
# # else
# # op = A'*A
# # end
#
# β = A'*b
#
# xᵒˡᵈ = copy(x)
#
# solverInfo != nothing && storeInfo(solverInfo,A,b,x;xᵒˡᵈ=xᵒˡᵈ,reg=[reg])
#
# costFunc = 0.5*norm(res)^2+norm(reg,x)
#
# for l=1:iterations
# xᵒˡᵈ[:] = x[:]
#
# x[:] = x[:] - ρ*(op*x-β)
#
# reg.prox!(x, ρ*reg.λ)
#
# tᵒˡᵈ = t
#
# t = (1. + sqrt(1. + 4. * tᵒˡᵈ^2)) / 2.
# x[:] = x + (tᵒˡᵈ-1)/t*(x-xᵒˡᵈ)
#
# solverInfo != nothing && storeInfo(solverInfo,A,b,x;xᵒˡᵈ=xᵒˡᵈ,reg=[reg])
# end
#
# return x
# end

0 comments on commit 6c3ab0b

Please sign in to comment.