Skip to content

Commit

Permalink
Merge pull request #81 from JuliaImageRecon/gpuStates
Browse files Browse the repository at this point in the history
Add GPU support based on (exchangable) structs for solver state
  • Loading branch information
nHackel authored Jul 4, 2024
2 parents af272fb + c8fabac commit 1741afd
Show file tree
Hide file tree
Showing 35 changed files with 1,157 additions and 675 deletions.
37 changes: 37 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
steps:
- label: "Nvidia GPUs -- RegularizedLeastSquares.jl"
plugins:
- JuliaCI/julia#v1:
version: "1.10"
agents:
queue: "juliagpu"
cuda: "*"
command: |
julia --color=yes --project -e '
using Pkg
Pkg.add("TestEnv")
using TestEnv
TestEnv.activate();
Pkg.add("CUDA")
Pkg.instantiate()
include("test/gpu/cuda.jl")'
timeout_in_minutes: 30

- label: "AMD GPUs -- RegularizedLeastSquares.jl"
plugins:
- JuliaCI/julia#v1:
version: "1.10"
agents:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
command: |
julia --color=yes --project -e '
using Pkg
Pkg.add("TestEnv")
using TestEnv
TestEnv.activate();
Pkg.add("AMDGPU")
Pkg.instantiate()
include("test/gpu/rocm.jl")'
timeout_in_minutes: 30
15 changes: 12 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,29 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[weakdeps]
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"

[compat]
IterativeSolvers = "0.9"
julia = "1.9"
StatsBase = "0.33, 0.34"
VectorizationBase = "0.19, 0.21"
FFTW = "1.0"
LinearOperatorCollection = "1.2"
LinearOperators = "2.3.3"
FLoops = "0.2"
GPUArrays = "8, 9, 10"
JLArrays = "0.1.2"
LinearOperatorCollection = "2"
LinearOperators = "2.3.3"

[targets]
test = ["Test", "Random", "FFTW"]
test = ["Test", "Random", "FFTW", "JLArrays"]

[extensions]
RegularizedLeastSquaresGPUArraysExt = "GPUArrays"
15 changes: 15 additions & 0 deletions ext/RegularizedLeastSquaresGPUArraysExt/Kaczmarz.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
function RegularizedLeastSquares.iterate_row_index(solver::Kaczmarz, state::RegularizedLeastSquares.KaczmarzState{T, vecT}, A, row, index) where {T, vecT <: AbstractGPUArray}
state.τl = RegularizedLeastSquares.dot_with_matrix_row(A,state.x,row)
@allowscalar state.αl = solver.denom[index]*(state.u[row]-state.τl-state.ɛw*state.vl[row])
RegularizedLeastSquares.kaczmarz_update!(A,state.x,row,state.αl)
@allowscalar state.vl[row] += state.αl*state.ɛw
end

function RegularizedLeastSquares.kaczmarz_update!(A, x::vecT, row, beta) where {T, vecT <: AbstractGPUVector{T}}
x[:] .= x .+ beta * conj.(view(A, row, :))
end

function RegularizedLeastSquares.kaczmarz_update!(B::Transpose{T, S}, x::vecT, row, beta) where {T, S <: AbstractGPUArray{T}, vecT <: AbstractGPUVector{T}}
A = B.parent
x[:] .= x .+ beta * conj.(view(A, :, row))
end
11 changes: 11 additions & 0 deletions ext/RegularizedLeastSquaresGPUArraysExt/ProxL21.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function RegularizedLeastSquares.proxL21!(x::vecT, λ::T, slices::Int64) where {T, vecT <: Union{AbstractGPUVector{T}, AbstractGPUVector{Complex{T}}}}
sliceLength = div(length(x),slices)
groupNorm = copyto!(similar(x, Float32, sliceLength), [Float32(norm(x[i:sliceLength:end])) for i=1:sliceLength])

gpu_call(x, λ, groupNorm, sliceLength) do ctx, x_, λ_, groupNorm_, sliceLength_
i = @linearidx(x_)
@inbounds x_[i] = x_[i]*max( (groupNorm_[mod1(i,sliceLength_)]-λ_)/groupNorm_[mod1(i,sliceLength_)],0)
return nothing
end
return x
end
15 changes: 15 additions & 0 deletions ext/RegularizedLeastSquaresGPUArraysExt/ProxTV.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
function RegularizedLeastSquares.tv_restrictMagnitude!(x::vecT) where {T, vecT <: AbstractGPUVector{T}}
gpu_call(x) do ctx, x_
i = @linearidx(x_)
@inbounds x_[i] /= max(1, abs(x_[i]))
return nothing
end
end

function RegularizedLeastSquares.tv_linearcomb!(rs::vecT, t3, pq::vecT, t2, pqOld::vecT) where {T, vecT <: AbstractGPUVector{T}}
gpu_call(rs, t3, pq, t2, pqOld) do ctx, rs_, t3_, pq_, t2_, pqOld_
i = @linearidx(rs_)
@inbounds rs_[i] = t3_ * pq_[i] - t2_ * pqOld_[i]
return nothing
end
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module RegularizedLeastSquaresGPUArraysExt

using RegularizedLeastSquares, RegularizedLeastSquares.LinearAlgebra, GPUArrays

include("Utils.jl")
include("ProxTV.jl")
include("ProxL21.jl")
include("Kaczmarz.jl")

end
47 changes: 47 additions & 0 deletions ext/RegularizedLeastSquaresGPUArraysExt/Utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
This function enforces the constraint of a real solution.
"""
function RegularizedLeastSquares.enfReal!(x::arrT) where {N, T<:Complex, arrGPUT <: AbstractGPUArray{T}, arrT <: Union{arrGPUT, SubArray{T, N, arrGPUT}}}
#Returns x as complex vector with imaginary part set to zero
gpu_call(x) do ctx, x_
i = @linearidx(x_)
@inbounds (x_[i] = complex(x_[i].re))
return nothing
end
end

"""
This function enforces the constraint of a real solution.
"""
RegularizedLeastSquares.enfReal!(x::arrT) where {N, T<:Real, arrGPUT <: AbstractGPUArray{T}, arrT <: Union{arrGPUT, SubArray{T, N, arrGPUT}}} = nothing

"""
This function enforces positivity constraints on its input.
"""
function RegularizedLeastSquares.enfPos!(x::arrT) where {N, T<:Complex, arrGPUT <: AbstractGPUArray{T}, arrT <: Union{arrGPUT, SubArray{T, N, arrGPUT}}}
#Return x as complex vector with negative parts projected onto 0
gpu_call(x) do ctx, x_
i = @linearidx(x_)
@inbounds (x_[i].re < 0) && (x_[i] = im*x_[i].im)
return nothing
end
end

"""
This function enforces positivity constraints on its input.
"""
function RegularizedLeastSquares.enfPos!(x::arrT) where {T<:Real, arrT <: AbstractGPUArray{T}}
#Return x as complex vector with negative parts projected onto 0
gpu_call(x) do ctx, x_
i = @linearidx(x_)
@inbounds (x_[i] < 0) && (x_[i] = zero(T))
return nothing
end
end

RegularizedLeastSquares.rownorm²(A::AbstractGPUMatrix,row::Int64) = sum(map(abs2, @view A[row, :]))
RegularizedLeastSquares.rownorm²(B::Transpose{T,S},row::Int64) where {T,S<:AbstractGPUArray} = sum(map(abs2, @view B.parent[:, row]))


RegularizedLeastSquares.dot_with_matrix_row(A::AbstractGPUMatrix{T}, x::AbstractGPUVector{T}, k::Int64) where {T} = reduce(+, x .* view(A, k, :))
RegularizedLeastSquares.dot_with_matrix_row(B::Transpose{T,S}, x::AbstractGPUVector{T}, k::Int64) where {T,S<:AbstractGPUArray} = reduce(+, x .* view(B.parent, :, k))
Loading

0 comments on commit 1741afd

Please sign in to comment.