Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Aug 17, 2024
2 parents 1ec556a + 924908a commit 92ff502
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 2 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ComradeBase"
uuid = "6d8c423b-a35f-4ef1-850c-862fe21f82c4"
authors = ["Paul Tiede <[email protected]> and contributors"]
version = "0.8.0"
version = "0.8.1"

[deps]
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Expand All @@ -26,7 +26,7 @@ DimensionalData = "0.27"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.12"
EnzymeCore = "0.6, 0.7"
OhMyThreads = "0.5"
OhMyThreads = "0.5, 0.6"
PolarizedTypes = "0.1"
PrecompileTools = "1"
Reexport = "1"
Expand Down
13 changes: 13 additions & 0 deletions src/domain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,19 @@ function domainpoints(d::UnstructuredDomain)
return getfield(d, :dims)
end

#This function helps us to lookup UnstructuredDomain at a particular Ti or Fr
#visdomain[Ti=T,Fr=F] or visdomain[Ti=T] or visdomain[Fr=F] calls work.
function Base.getindex(domain::UnstructuredDomain; Ti=nothing, Fr=nothing)
points = domainpoints(domain)
indices = if Ti !== nothing && Fr !== nothing
findall(p -> (p.Ti == Ti) && (p.Fr == Fr), points)
elseif Ti !== nothing
findall(p -> (p.Ti == Ti), points)
else
findall(p -> (p.Fr == Fr), points)
end
return UnstructuredDomain(points[indices], executor(domain), header(domain))
end

function Base.summary(io::IO, g::UnstructuredDomain)
n = propertynames(domainpoints(g))
Expand Down
118 changes: 118 additions & 0 deletions test/multidomain.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
function domain4d(N, Nt, Nf)
U_vals = range(-10e9, 10e9, length=N)
U_vals = U_vals' .* ones(N)
V_vals = range(-10e9, 10e9, length=N)
V_vals = V_vals' .* ones(N)
U_vals = U_vals'

# Flatten the U and V grids
U_final = vec(U_vals)
V_final = vec(V_vals)

ti = 10 * rand(Nt) |> sort
fr = 1e11 * rand(Nf) |> sort

# Repeat U and V to match Ti dimensions
U_repeated = repeat(U_final, outer=(length(ti)))
V_repeated = repeat(V_final, outer=(length(ti)))
Ti_repeated = repeat(ti, inner=(Int(length(U_final))))

# Repeat U and V and Ti to match Fr dimensions
U_repeated = repeat(U_repeated, outer=(length(fr)))
V_repeated = repeat(V_repeated, outer=(length(fr)))
Ti_repeated = repeat(Ti_repeated, outer=(length(fr)))
Fr_repeated = repeat(fr, inner=(Int(length(U_repeated)/length(fr))))
visdomain = UnstructuredDomain((;U=U_repeated, V=V_repeated, Ti=Ti_repeated, Fr=Fr_repeated))

C1 = true
for ti_point in ti
for fr_point in fr
f = visdomain[Ti=ti_point, Fr=fr_point]
g = UnstructuredDomain((;U=U_final, V=V_final, Ti=vcat(fill(ti_point, length(U_final))), Fr=vcat(fill(fr_point, length(U_final)))))
C1 = C1 && (domainpoints(f) == domainpoints(g))
end
end

#Switch Ti and Fr order
C2 = true
for ti_point in ti
for fr_point in fr
f = visdomain[Fr=fr_point, Ti=ti_point]
g = UnstructuredDomain((;U=U_final, V=V_final, Ti=vcat(fill(ti_point, length(U_final))), Fr=vcat(fill(fr_point, length(U_final)))))
C2 = C2 && (domainpoints(f) == domainpoints(g))
end
end

return C1, C2
end

function domain3df(N, Nf)
U_vals = range(-10e9, 10e9, length=N)
U_vals = U_vals' .* ones(N)
V_vals = range(-10e9, 10e9, length=N)
V_vals = V_vals' .* ones(N)
U_vals = U_vals'

# Flatten the U and V grids
U_final = vec(U_vals)
V_final = vec(V_vals)

fr = 1e11 * rand(Nf) |> sort

# Repeat U and V to match Fr dimensions
U_repeated = repeat(U_final, outer=(length(fr)))
V_repeated = repeat(V_final, outer=(length(fr)))
Fr_repeated = repeat(fr, inner=(Int(length(U_final))))

visdomain = UnstructuredDomain((;U=U_repeated, V=V_repeated, Fr=Fr_repeated))

C = true
for fr_point in fr
f = visdomain[Fr=fr_point]
g = UnstructuredDomain((;U=U_final, V=V_final, Fr=vcat(fill(fr_point, length(U_final)))))
C = C && (domainpoints(f) == domainpoints(g))
end

return C
end

function domain3dt(N, Nt)
U_vals = range(-10e9, 10e9, length=N)
U_vals = U_vals' .* ones(N)
V_vals = range(-10e9, 10e9, length=N)
V_vals = V_vals' .* ones(N)
U_vals = U_vals'

# Flatten the U and V grids
U_final = vec(U_vals)
V_final = vec(V_vals)

ti = 10 * rand(Nt) |> sort

# Repeat U and V to match Ti dimensions
U_repeated = repeat(U_final, outer=(length(ti)))
V_repeated = repeat(V_final, outer=(length(ti)))
Ti_repeated = repeat(ti, inner=(Int(length(U_final))))

visdomain = UnstructuredDomain((;U=U_repeated, V=V_repeated, Ti=Ti_repeated))

C = true
for ti_point in ti
f = visdomain[Ti=ti_point]
g = UnstructuredDomain((;U=U_final, V=V_final, Ti=vcat(fill(ti_point, length(U_final)))))
C = C && (domainpoints(f) == domainpoints(g))
end

return C
end


@testset "Test getindex for visdomain" begin
C1, C2 = domain4d(64,10,4)
@test C1
@test C2
C3 = domain3dt(64,10)
@test C3
C4 = domain3df(64,4)
@test C4
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ import DimensionalData as DD
include(joinpath(@__DIR__, "images.jl"))
include(joinpath(@__DIR__, "visibilities.jl"))
include(joinpath(@__DIR__, "executors.jl"))
include(joinpath(@__DIR__, "multidomain.jl"))
end

0 comments on commit 92ff502

Please sign in to comment.