-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #37 from ptiede/ptiede-gpuexec
Make ComradeBase gpu compatible
- Loading branch information
Showing
11 changed files
with
213 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
module ComradeBaseAdaptExt | ||
|
||
using ComradeBase | ||
using DimensionalData | ||
using Adapt | ||
|
||
function Adapt.adapt_structure(to, A::IntensityMap) | ||
return IntensityMap(Adapt.adapt_structure(to, DimensionalData.data(A)), | ||
Adapt.adapt_structure(to, axisdims(A)), | ||
Adapt.adapt_structure(to, DimensionalData.refdims(A)), | ||
DimensionalData.Name(name(A))) | ||
end | ||
|
||
function Adapt.adapt_structure(to, A::UnstructuredMap) | ||
return UnstructuredMap(Adapt.adapt_structure(to, parent(A)), | ||
Adapt.adapt_structure(to, axisdims(A))) | ||
end | ||
|
||
function Adapt.adapt_structure(to, A::ComradeBase.AbstractSingleDomain) | ||
return rebuild(typeof(A), Adapt.adapt_structure(to, dims(A)), | ||
executor(A), | ||
header(A)) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
module ComradeBaseKernelAbstractionsExt | ||
|
||
using ComradeBase | ||
using KernelAbstractions: Backend, allocate | ||
using StructArrays | ||
|
||
function ComradeBase.allocate_map(::Type{<:AbstractArray{T}}, | ||
g::UnstructuredDomain{D,<:Backend}) where {T,D} | ||
return ComradeBase.UnstructuredMap(allocate(executor(g), T, size(g)), g) | ||
end | ||
|
||
function ComradeBase.allocate_map(::Type{<:StructArray{T}}, | ||
g::UnstructuredDomain{D,<:Backend}) where {T,D} | ||
exec = executor(g) | ||
arrs = StructArrays.buildfromschema(x -> allocate(exec, x, size(g)), T) | ||
return UnstructuredMap(arrs, g) | ||
end | ||
|
||
function ComradeBase.allocate_map(::Type{<:AbstractArray{T}}, | ||
g::ComradeBase.AbstractRectiGrid{D,<:Backend}) where {T,D} | ||
exec = executor(g) | ||
return IntensityMap(allocate(exec, T, size(g)), g) | ||
end | ||
|
||
function ComradeBase.allocate_map(::Type{<:StructArray{T}}, | ||
g::ComradeBase.AbstractRectiGrid{D,<:Backend}) where {T,D} | ||
exec = executor(g) | ||
arrs = StructArrays.buildfromschema(x -> allocate(exec, x, size(g)), T) | ||
return IntensityMap(arrs, g) | ||
end | ||
|
||
end |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
using Base.Broadcast: Broadcasted, BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, | ||
Style | ||
|
||
Base.BroadcastStyle(::Type{<:UnstructuredMap}) = Broadcast.ArrayStyle{UnstructuredMap}() | ||
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{UnstructuredMap}}, | ||
::Type{ElType}) where {ElType} | ||
# Scan inputs for the time and sites | ||
sarr = find_ustr(bc) | ||
return UnstructuredMap(similar(parent(sarr), ElType), axisdims(sarr)) | ||
end | ||
|
||
# function UnstructuredStyle(a::BroadcastStyle, b::BroadcastStyle) | ||
# inner_style = BroadcastStyle(a, b) | ||
# if inner_style isa Broadcast.Unknown | ||
# return Broadcast.Unknown() | ||
# else | ||
# return UnstructuredStyle(inner_style) | ||
# end | ||
# end | ||
|
||
# BroadcastStyle(::UnstructuredStyle, ::Base.Broadcast.Unknown) = Unknown() | ||
# BroadcastStyle(::Base.Broadcast.Unknown, ::UnstructuredStyle) = Unknown() | ||
# BroadcastStyle(::UnstructuredStyle{A}, ::UnstructuredStyle{B}) where {A, B} = UnstructuredStyle(A(), B()) | ||
# BroadcastStyle(::UnstructuredStyle{A}, b::Style) where {A} = UnstructuredStyle(A(), b) | ||
# BroadcastStyle(a::Style, ::UnstructuredStyle{B}) where {B} = UnstructuredStyle(a, B()) | ||
# BroadcastStyle(::UnstructuredStyle{A}, b::Style{Tuple}) where {A} = UnstructuredStyle(A(), b) | ||
# BroadcastStyle(a::Style{Tuple}, ::UnstructuredStyle{B}) where {B} = UnstructuredStyle(a, B()) | ||
# function Base.similar(bc::Broadcasted{UnstructuredStyle}, | ||
# ::Type{ElType}) where {ElType} | ||
# # Scan inputs for the time and sites | ||
# sarr = find_ustr(bc) | ||
# return UnstructuredMap(similar(parent(sarr), ElType), axisdims(sarr)) | ||
# end | ||
|
||
function Base.copyto!(dest::UnstructuredMap, bc::Broadcast.Broadcasted) | ||
copyto!(baseimage(dest), bc) | ||
return dest | ||
end | ||
|
||
find_ustr(bc::Broadcasted) = find_ustr(bc.args) | ||
find_ustr(args::Tuple) = find_ustr(find_ustr(args[1]), Base.tail(args)) | ||
find_ustr(x) = x | ||
find_ustr(::Tuple{}) = nothing | ||
find_ustr(x::UnstructuredMap, rest) = x | ||
find_ustr(::Any, rest) = find_ustr(rest) | ||
|
||
domainpoints(x::UnstructuredMap) = domainpoints(axisdims(x)) | ||
|
||
Base.propertynames(x::UnstructuredMap) = propertynames(axisdims(x)) | ||
Base.getproperty(x::UnstructuredMap, s::Symbol) = getproperty(axisdims(x), s) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
const DataNames = Union{<:NamedTuple{(:X, :Y, :T, :F)},<:NamedTuple{(:X, :Y, :F, :T)}, | ||
<:NamedTuple{(:X, :Y, :T)},<:NamedTuple{(:X, :Y, :F)}, | ||
<:NamedTuple{(:X, :Y)}} | ||
|
||
# TODO make this play nice with dimensional data | ||
struct UnstructuredDomain{D,E,H<:AbstractHeader} <: AbstractSingleDomain{D,E} | ||
dims::D | ||
executor::E | ||
header::H | ||
end | ||
|
||
EnzymeRules.inactive_type(::Type{<:UnstructuredDomain}) = true | ||
|
||
""" | ||
UnstructuredDomain(dims::NamedTuple; executor=Serial(), header=ComradeBase.NoHeader) | ||
Builds an unstructured grid (really a vector of points) from the dimensions `dims`. | ||
The `executor` is used controls how the grid is computed when calling | ||
`visibilitymap` or `intensitymap`. The default is `Serial` which mean regular CPU computations. | ||
For threaded execution use [`ThreadsEx()`](@ref) or load `OhMyThreads.jl` to uses their schedulers. | ||
Note that unlike `RectiGrid` which assigns dimensions to the grid points, `UnstructuredDomain` | ||
does not. This is becuase the grid is unstructured the points are a cloud in a space | ||
""" | ||
function UnstructuredDomain(nt::NamedTuple; executor=Serial(), header=NoHeader()) | ||
p = StructArray(nt) | ||
return UnstructuredDomain(p, executor, header) | ||
end | ||
|
||
Base.ndims(d::UnstructuredDomain) = ndims(dims(d)) | ||
Base.size(d::UnstructuredDomain) = size(dims(d)) | ||
Base.firstindex(d::UnstructuredDomain) = firstindex(dims(d)) | ||
Base.lastindex(d::UnstructuredDomain) = lastindex(dims(d)) | ||
#Make sure we actually get a tuple here | ||
# Base.front(d::UnstructuredDomain) = UnstructuredDomain(Base.front(StructArrays.components(dims(d))), executor=executor(d), header=header(d)) | ||
# Base.eltype(d::UnstructuredDomain) = Base.eltype(dims(d)) | ||
|
||
function DD.rebuild(::Type{<:UnstructuredDomain}, g, executor=Serial(), | ||
header=ComradeBase.NoHeader()) | ||
return UnstructuredDomain(g, executor, header) | ||
end | ||
|
||
Base.propertynames(g::UnstructuredDomain) = propertynames(domainpoints(g)) | ||
Base.getproperty(g::UnstructuredDomain, p::Symbol) = getproperty(domainpoints(g), p) | ||
Base.keys(g::UnstructuredDomain) = propertynames(g) | ||
named_dims(g::UnstructuredDomain) = StructArrays.components(dims(g)) | ||
|
||
function domainpoints(d::UnstructuredDomain) | ||
return getfield(d, :dims) | ||
end | ||
|
||
#This function helps us to lookup UnstructuredDomain at a particular Ti or Fr | ||
#visdomain[Ti=T,Fr=F] or visdomain[Ti=T] or visdomain[Fr=F] calls work. | ||
function Base.getindex(domain::UnstructuredDomain; Ti=nothing, Fr=nothing) | ||
points = domainpoints(domain) | ||
indices = if Ti !== nothing && Fr !== nothing | ||
findall(p -> (p.Ti == Ti) && (p.Fr == Fr), points) | ||
elseif Ti !== nothing | ||
findall(p -> (p.Ti == Ti), points) | ||
else | ||
findall(p -> (p.Fr == Fr), points) | ||
end | ||
return UnstructuredDomain(points[indices], executor(domain), header(domain)) | ||
end | ||
|
||
function Base.summary(io::IO, g::UnstructuredDomain) | ||
n = propertynames(domainpoints(g)) | ||
printstyled(io, "│ "; color=:light_black) | ||
return print(io, "UnstructuredDomain with dims: $n") | ||
end | ||
|
||
function Base.show(io::IO, mime::MIME"text/plain", x::UnstructuredDomain) | ||
println(io, "UnstructredDomain(") | ||
println(io, "executor: $(executor(x))") | ||
println(io, "Dimensions: ") | ||
show(io, mime, dims(x)) | ||
return print(io, "\n)") | ||
end | ||
|
||
create_map(array, g::UnstructuredDomain) = UnstructuredMap(array, g) | ||
function allocate_map(M::Type{<:AbstractArray}, g::UnstructuredDomain) | ||
return UnstructuredMap(similar(M, size(g)), g) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters