Skip to content

Commit

Permalink
Merge pull request #37 from ptiede/ptiede-gpuexec
Browse files Browse the repository at this point in the history
Make ComradeBase gpu compatible
  • Loading branch information
ptiede authored Aug 25, 2024
2 parents 48798ff + 40496cc commit e65308d
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 108 deletions.
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,23 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"

[extensions]
ComradeBaseAdaptExt = "Adapt"
ComradeBaseEnzymeExt = "Enzyme"
ComradeBaseKernelAbstractionsExt = "KernelAbstractions"
ComradeBaseOhMyThreadsExt = "OhMyThreads"

[compat]
DimensionalData = "0.27"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.12"
EnzymeCore = "0.6, 0.7"
KernelAbstractions = "0.9"
OhMyThreads = "0.5, 0.6"
PolarizedTypes = "0.1"
PrecompileTools = "1"
Expand All @@ -35,15 +40,17 @@ StructArrays = "0.6"
julia = "1.6"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ChainRulesTestUtils", "Enzyme", "FiniteDifferences", "JET", "OhMyThreads", "Pyehtim", "StaticArrays", "StructArrays", "Test"]
test = ["Adapt", "ChainRulesTestUtils", "Enzyme", "FiniteDifferences", "KernelAbstractions", "JET", "OhMyThreads", "Pyehtim", "StaticArrays", "StructArrays", "Test"]
25 changes: 25 additions & 0 deletions ext/ComradeBaseAdaptExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module ComradeBaseAdaptExt

using ComradeBase
using DimensionalData
using Adapt

function Adapt.adapt_structure(to, A::IntensityMap)
return IntensityMap(Adapt.adapt_structure(to, DimensionalData.data(A)),
Adapt.adapt_structure(to, axisdims(A)),
Adapt.adapt_structure(to, DimensionalData.refdims(A)),
DimensionalData.Name(name(A)))
end

function Adapt.adapt_structure(to, A::UnstructuredMap)
return UnstructuredMap(Adapt.adapt_structure(to, parent(A)),
Adapt.adapt_structure(to, axisdims(A)))
end

function Adapt.adapt_structure(to, A::ComradeBase.AbstractSingleDomain)
return rebuild(typeof(A), Adapt.adapt_structure(to, dims(A)),
executor(A),
header(A))
end

end
32 changes: 32 additions & 0 deletions ext/ComradeBaseKernelAbstractionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module ComradeBaseKernelAbstractionsExt

using ComradeBase
using KernelAbstractions: Backend, allocate
using StructArrays

function ComradeBase.allocate_map(::Type{<:AbstractArray{T}},
g::UnstructuredDomain{D,<:Backend}) where {T,D}
return ComradeBase.UnstructuredMap(allocate(executor(g), T, size(g)), g)
end

function ComradeBase.allocate_map(::Type{<:StructArray{T}},
g::UnstructuredDomain{D,<:Backend}) where {T,D}
exec = executor(g)
arrs = StructArrays.buildfromschema(x -> allocate(exec, x, size(g)), T)
return UnstructuredMap(arrs, g)
end

function ComradeBase.allocate_map(::Type{<:AbstractArray{T}},
g::ComradeBase.AbstractRectiGrid{D,<:Backend}) where {T,D}
exec = executor(g)
return IntensityMap(allocate(exec, T, size(g)), g)
end

function ComradeBase.allocate_map(::Type{<:StructArray{T}},
g::ComradeBase.AbstractRectiGrid{D,<:Backend}) where {T,D}
exec = executor(g)
arrs = StructArrays.buildfromschema(x -> allocate(exec, x, size(g)), T)
return IntensityMap(arrs, g)
end

end
File renamed without changes.
3 changes: 2 additions & 1 deletion src/ComradeBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ export visibility,

include("interface.jl")
include("domain.jl")
include("unstructured_map.jl")
include("unstructured/domain.jl")
include("unstructured/map.jl")
include("images/images.jl")

const FluxMap2{T,N,E} = Union{IntensityMap{T,N,<:Any,E},
Expand Down
83 changes: 0 additions & 83 deletions src/domain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,86 +335,3 @@ function DD.rebuild(::Type{<:RectiGrid}, g, executor=Serial(),
end

# Define some helpful names for ease typing
const DataNames = Union{<:NamedTuple{(:X, :Y, :T, :F)},<:NamedTuple{(:X, :Y, :F, :T)},
<:NamedTuple{(:X, :Y, :T)},<:NamedTuple{(:X, :Y, :F)},
<:NamedTuple{(:X, :Y)}}

# TODO make this play nice with dimensional data
struct UnstructuredDomain{D,E,H<:AbstractHeader} <: AbstractSingleDomain{D,E}
dims::D
executor::E
header::H
end

EnzymeRules.inactive_type(::Type{<:UnstructuredDomain}) = true

"""
UnstructuredDomain(dims::NamedTuple; executor=Serial(), header=ComradeBase.NoHeader)
Builds an unstructured grid (really a vector of points) from the dimensions `dims`.
The `executor` is used controls how the grid is computed when calling
`visibilitymap` or `intensitymap`. The default is `Serial` which mean regular CPU computations.
For threaded execution use [`ThreadsEx()`](@ref) or load `OhMyThreads.jl` to uses their schedulers.
Note that unlike `RectiGrid` which assigns dimensions to the grid points, `UnstructuredDomain`
does not. This is becuase the grid is unstructured the points are a cloud in a space
"""
function UnstructuredDomain(nt::NamedTuple; executor=Serial(), header=NoHeader())
p = StructArray(nt)
return UnstructuredDomain(p, executor, header)
end

Base.ndims(d::UnstructuredDomain) = ndims(dims(d))
Base.size(d::UnstructuredDomain) = size(dims(d))
Base.firstindex(d::UnstructuredDomain) = firstindex(dims(d))
Base.lastindex(d::UnstructuredDomain) = lastindex(dims(d))
#Make sure we actually get a tuple here
# Base.front(d::UnstructuredDomain) = UnstructuredDomain(Base.front(StructArrays.components(dims(d))), executor=executor(d), header=header(d))
# Base.eltype(d::UnstructuredDomain) = Base.eltype(dims(d))

function DD.rebuild(::Type{<:UnstructuredDomain}, g, executor=Serial(),
header=ComradeBase.NoHeader())
return UnstructuredDomain(g, executor, header)
end

Base.propertynames(g::UnstructuredDomain) = propertynames(domainpoints(g))
Base.getproperty(g::UnstructuredDomain, p::Symbol) = getproperty(domainpoints(g), p)
Base.keys(g::UnstructuredDomain) = propertynames(g)
named_dims(g::UnstructuredDomain) = StructArrays.components(dims(g))

function domainpoints(d::UnstructuredDomain)
return getfield(d, :dims)
end

#This function helps us to lookup UnstructuredDomain at a particular Ti or Fr
#visdomain[Ti=T,Fr=F] or visdomain[Ti=T] or visdomain[Fr=F] calls work.
function Base.getindex(domain::UnstructuredDomain; Ti=nothing, Fr=nothing)
points = domainpoints(domain)
indices = if Ti !== nothing && Fr !== nothing
findall(p -> (p.Ti == Ti) && (p.Fr == Fr), points)
elseif Ti !== nothing
findall(p -> (p.Ti == Ti), points)
else
findall(p -> (p.Fr == Fr), points)
end
return UnstructuredDomain(points[indices], executor(domain), header(domain))
end

function Base.summary(io::IO, g::UnstructuredDomain)
n = propertynames(domainpoints(g))
printstyled(io, ""; color=:light_black)
return print(io, "UnstructuredDomain with dims: $n")
end

function Base.show(io::IO, mime::MIME"text/plain", x::UnstructuredDomain)
println(io, "UnstructredDomain(")
println(io, "executor: $(executor(x))")
println(io, "Dimensions: ")
show(io, mime, dims(x))
return print(io, "\n)")
end

create_map(array, g::UnstructuredDomain) = UnstructuredMap(array, g)
function allocate_map(M::Type{<:A}, g::UnstructuredDomain) where {A<:AbstractArray}
return UnstructuredMap(similar(M, size(g)), g)
end
10 changes: 9 additions & 1 deletion src/images/dim_image.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,15 @@ DD.metadata(img::IntensityMap) = header(axisdims(img))
executor(img::IntensityMap) = executor(axisdims(img))

# TODO add this to DimensionalData directly
EnzymeRules.inactive(::typeof(DD.comparedims), args...) = nothing
EnzymeRules.inactive(::typeof(DD._broadcasted_dims), args...; kwargs...) = nothing
EnzymeRules.inactive(::typeof(DD.Dimensions.comparedims), args...; kwargs...) = nothing
EnzymeRules.inactive(::typeof(DD.Dimensions._comparedims), args...; kwargs...) = nothing

# We need this to make sure IntensityMap works correctly on the GPU
function Base.copyto!(dest::DimensionalData.AbstractDimArray, bc::Broadcast.Broadcasted)
copyto!(baseimage(dest), bc)
return dest
end

# For the `IntensityMap` nothing is AD-able except the data so
# let's tell Enzyme this
Expand Down
4 changes: 2 additions & 2 deletions src/images/methods.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
domainpoints(k::IntensityMap)
Returns the grid the `IntensityMap` is defined as. Note that this is unallocating
since it lazily computes the grid. The grid is an example of a DimArray and works similarly.
Returns the grid the `IntensityMap` is defined as. Note that this is nonallocating
since it lazily computes the grid.
This is useful for broadcasting a model across an abritrary grid.
"""
domainpoints(img::IntensityMap) = domainpoints(axisdims(img))
Expand Down
50 changes: 50 additions & 0 deletions src/unstructured/broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Base.Broadcast: Broadcasted, BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle,
Style

Base.BroadcastStyle(::Type{<:UnstructuredMap}) = Broadcast.ArrayStyle{UnstructuredMap}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{UnstructuredMap}},
::Type{ElType}) where {ElType}
# Scan inputs for the time and sites
sarr = find_ustr(bc)
return UnstructuredMap(similar(parent(sarr), ElType), axisdims(sarr))
end

# function UnstructuredStyle(a::BroadcastStyle, b::BroadcastStyle)
# inner_style = BroadcastStyle(a, b)
# if inner_style isa Broadcast.Unknown
# return Broadcast.Unknown()
# else
# return UnstructuredStyle(inner_style)
# end
# end

# BroadcastStyle(::UnstructuredStyle, ::Base.Broadcast.Unknown) = Unknown()
# BroadcastStyle(::Base.Broadcast.Unknown, ::UnstructuredStyle) = Unknown()
# BroadcastStyle(::UnstructuredStyle{A}, ::UnstructuredStyle{B}) where {A, B} = UnstructuredStyle(A(), B())
# BroadcastStyle(::UnstructuredStyle{A}, b::Style) where {A} = UnstructuredStyle(A(), b)
# BroadcastStyle(a::Style, ::UnstructuredStyle{B}) where {B} = UnstructuredStyle(a, B())
# BroadcastStyle(::UnstructuredStyle{A}, b::Style{Tuple}) where {A} = UnstructuredStyle(A(), b)
# BroadcastStyle(a::Style{Tuple}, ::UnstructuredStyle{B}) where {B} = UnstructuredStyle(a, B())
# function Base.similar(bc::Broadcasted{UnstructuredStyle},
# ::Type{ElType}) where {ElType}
# # Scan inputs for the time and sites
# sarr = find_ustr(bc)
# return UnstructuredMap(similar(parent(sarr), ElType), axisdims(sarr))
# end

function Base.copyto!(dest::UnstructuredMap, bc::Broadcast.Broadcasted)
copyto!(baseimage(dest), bc)
return dest
end

find_ustr(bc::Broadcasted) = find_ustr(bc.args)
find_ustr(args::Tuple) = find_ustr(find_ustr(args[1]), Base.tail(args))
find_ustr(x) = x
find_ustr(::Tuple{}) = nothing
find_ustr(x::UnstructuredMap, rest) = x
find_ustr(::Any, rest) = find_ustr(rest)

domainpoints(x::UnstructuredMap) = domainpoints(axisdims(x))

Base.propertynames(x::UnstructuredMap) = propertynames(axisdims(x))
Base.getproperty(x::UnstructuredMap, s::Symbol) = getproperty(axisdims(x), s)
83 changes: 83 additions & 0 deletions src/unstructured/domain.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
const DataNames = Union{<:NamedTuple{(:X, :Y, :T, :F)},<:NamedTuple{(:X, :Y, :F, :T)},
<:NamedTuple{(:X, :Y, :T)},<:NamedTuple{(:X, :Y, :F)},
<:NamedTuple{(:X, :Y)}}

# TODO make this play nice with dimensional data
struct UnstructuredDomain{D,E,H<:AbstractHeader} <: AbstractSingleDomain{D,E}
dims::D
executor::E
header::H
end

EnzymeRules.inactive_type(::Type{<:UnstructuredDomain}) = true

"""
UnstructuredDomain(dims::NamedTuple; executor=Serial(), header=ComradeBase.NoHeader)
Builds an unstructured grid (really a vector of points) from the dimensions `dims`.
The `executor` is used controls how the grid is computed when calling
`visibilitymap` or `intensitymap`. The default is `Serial` which mean regular CPU computations.
For threaded execution use [`ThreadsEx()`](@ref) or load `OhMyThreads.jl` to uses their schedulers.
Note that unlike `RectiGrid` which assigns dimensions to the grid points, `UnstructuredDomain`
does not. This is becuase the grid is unstructured the points are a cloud in a space
"""
function UnstructuredDomain(nt::NamedTuple; executor=Serial(), header=NoHeader())
p = StructArray(nt)
return UnstructuredDomain(p, executor, header)
end

Base.ndims(d::UnstructuredDomain) = ndims(dims(d))
Base.size(d::UnstructuredDomain) = size(dims(d))
Base.firstindex(d::UnstructuredDomain) = firstindex(dims(d))
Base.lastindex(d::UnstructuredDomain) = lastindex(dims(d))
#Make sure we actually get a tuple here
# Base.front(d::UnstructuredDomain) = UnstructuredDomain(Base.front(StructArrays.components(dims(d))), executor=executor(d), header=header(d))
# Base.eltype(d::UnstructuredDomain) = Base.eltype(dims(d))

function DD.rebuild(::Type{<:UnstructuredDomain}, g, executor=Serial(),
header=ComradeBase.NoHeader())
return UnstructuredDomain(g, executor, header)
end

Base.propertynames(g::UnstructuredDomain) = propertynames(domainpoints(g))
Base.getproperty(g::UnstructuredDomain, p::Symbol) = getproperty(domainpoints(g), p)
Base.keys(g::UnstructuredDomain) = propertynames(g)
named_dims(g::UnstructuredDomain) = StructArrays.components(dims(g))

function domainpoints(d::UnstructuredDomain)
return getfield(d, :dims)
end

#This function helps us to lookup UnstructuredDomain at a particular Ti or Fr
#visdomain[Ti=T,Fr=F] or visdomain[Ti=T] or visdomain[Fr=F] calls work.
function Base.getindex(domain::UnstructuredDomain; Ti=nothing, Fr=nothing)
points = domainpoints(domain)
indices = if Ti !== nothing && Fr !== nothing
findall(p -> (p.Ti == Ti) && (p.Fr == Fr), points)
elseif Ti !== nothing
findall(p -> (p.Ti == Ti), points)
else
findall(p -> (p.Fr == Fr), points)
end
return UnstructuredDomain(points[indices], executor(domain), header(domain))
end

function Base.summary(io::IO, g::UnstructuredDomain)
n = propertynames(domainpoints(g))
printstyled(io, ""; color=:light_black)
return print(io, "UnstructuredDomain with dims: $n")
end

function Base.show(io::IO, mime::MIME"text/plain", x::UnstructuredDomain)
println(io, "UnstructredDomain(")
println(io, "executor: $(executor(x))")
println(io, "Dimensions: ")
show(io, mime, dims(x))
return print(io, "\n)")
end

create_map(array, g::UnstructuredDomain) = UnstructuredMap(array, g)
function allocate_map(M::Type{<:AbstractArray}, g::UnstructuredDomain)
return UnstructuredMap(similar(M, size(g)), g)
end
22 changes: 2 additions & 20 deletions src/unstructured_map.jl → src/unstructured/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,26 +39,6 @@ function Base.similar(m::UnstructuredMap, ::Type{S}) where {S}
return UnstructuredMap(similar(parent(m), S), axisdims(m))
end

Base.BroadcastStyle(::Type{<:UnstructuredMap}) = Broadcast.ArrayStyle{UnstructuredMap}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{UnstructuredMap}},
::Type{ElType}) where {ElType}
# Scan inputs for the time and sites
sarr = find_ustr(bc)
return UnstructuredMap(similar(parent(sarr), ElType), axisdims(sarr))
end

find_ustr(bc::Broadcast.Broadcasted) = find_ustr(bc.args)
find_ustr(args::Tuple) = find_ustr(find_ustr(args[1]), Base.tail(args))
find_ustr(x) = x
find_ustr(::Tuple{}) = nothing
find_ustr(x::UnstructuredMap, rest) = x
find_ustr(::Any, rest) = find_ustr(rest)

domainpoints(x::UnstructuredMap) = domainpoints(axisdims(x))

Base.propertynames(x::UnstructuredMap) = propertynames(axisdims(x))
Base.getproperty(x::UnstructuredMap, s::Symbol) = getproperty(axisdims(x), s)

function Base.view(x::UnstructuredMap, I)
dims = axisdims(x)
g = domainpoints(dims)
Expand Down Expand Up @@ -99,3 +79,5 @@ for s in schedulers
end
end
end

include("broadcast.jl")

0 comments on commit e65308d

Please sign in to comment.