From e8c45834d389b1e7bdd854ffbff5dcc7b0b91e94 Mon Sep 17 00:00:00 2001 From: Jack Grogan Date: Fri, 1 Nov 2024 16:35:16 +0000 Subject: [PATCH 1/5] Implementation of raytracing in implicitBVH.jl --- benchmark/bvh_rays.jl | 76 +++++++++++++++ prototype/Project.toml | 1 - prototype/raytracing.jl | 49 ++++++++++ src/ImplicitBVH.jl | 4 +- src/rays.jl | 81 ++++++++++++++++ src/raytrace/raytrace.jl | 141 ++++++++++++++++++++++++++++ src/raytrace/raytrace_cpu.jl | 177 +++++++++++++++++++++++++++++++++++ src/raytrace/raytrace_gpu.jl | 172 ++++++++++++++++++++++++++++++++++ test/runtests.jl | 110 ++++++++++++++++++++++ 9 files changed, 809 insertions(+), 2 deletions(-) create mode 100644 benchmark/bvh_rays.jl create mode 100644 prototype/raytracing.jl create mode 100644 src/rays.jl create mode 100644 src/raytrace/raytrace.jl create mode 100644 src/raytrace/raytrace_cpu.jl create mode 100644 src/raytrace/raytrace_gpu.jl diff --git a/benchmark/bvh_rays.jl b/benchmark/bvh_rays.jl new file mode 100644 index 0000000..af5f0c8 --- /dev/null +++ b/benchmark/bvh_rays.jl @@ -0,0 +1,76 @@ +# File : bvh_build.jl +# License: MIT +# Author : Andrei Leonard Nicusan +# Date : 15.12.2022 + + +using ImplicitBVH +using ImplicitBVH: BSphere, BBox + +using Random +Random.seed!(42) + +using MeshIO +using FileIO + +using BenchmarkTools +using Profile +using PProf + + +# Types used +const LeafType = BBox{Float32} + + + +num_bvs = 1_000_000 + +bvs = [LeafType(rand(3, 3)) for _ in 1:num_bvs] +points = rand(Float32, 3, num_bvs) +directions = rand(Float32, 3, num_bvs) + + +# Example usage +function check_intersections!(intersections, bvs, points, directions) + @inbounds for i in eachindex(intersections) + intersections[i] = isintersection(bvs[i], @view(points[:, i]), @view(directions[:, i])) + end + nothing +end + +intersections = Vector{Bool}(undef, length(bvs)) +check_intersections!(intersections, bvs, points, directions) + + +println("Benchmarking $(length(bvs)) intersections:") +display(@benchmark check_intersections!(intersections, bvs, points, directions)) + + +# # Collect a pprof profile of the complete build +# Profile.clear() +# @profile begin +# for _ in 1:1000 +# check_intersections!(intersections, bvs, points, directions) +# end +# end +# +# +# # Export pprof profile and open interactive profiling web interface. +# pprof(; out="bvh_rays.pb.gz") + + +# Test for some coding mistakes +using Test +Test.detect_unbound_args(ImplicitBVH, recursive = true) +Test.detect_ambiguities(ImplicitBVH, recursive = true) + + +# More complete report on type stabilities +using JET +JET.@report_opt check_intersections!(intersections, bvs, points, directions) + + +# using Profile +# BVH(bounding_spheres, NodeType, MortonType) +# Profile.clear_malloc_data() +# BVH(bounding_spheres, NodeType, MortonType) diff --git a/prototype/Project.toml b/prototype/Project.toml index 21775d9..f075798 100644 --- a/prototype/Project.toml +++ b/prototype/Project.toml @@ -4,7 +4,6 @@ Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f" ImplicitBVH = "932a18dc-bb55-4cd5-bdd6-1368ec9cea29" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -Metal = "dde4c033-4e86-420c-a63e-0dd931031962" PProf = "e4faabce-9ead-11e9-39d9-4379958e3056" Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/prototype/raytracing.jl b/prototype/raytracing.jl new file mode 100644 index 0000000..cc746b7 --- /dev/null +++ b/prototype/raytracing.jl @@ -0,0 +1,49 @@ + +using Random + +using BenchmarkTools + +using ImplicitBVH +using ImplicitBVH: BSphere, BBox + +using Profile +using PProf + + +Random.seed!(0) + +num_bvs = 100_000 +bvs = map(BSphere{Float32}, [6 * rand(3) .+ rand(3, 3) for _ in 1:num_bvs]) + +options = BVHOptions(block_size=128, num_threads=8) +bvh = BVH(bvs, BBox{Float32}, UInt32, 1, options=options) + + +num_rays = 100_000 +points = 100 * rand(3, num_rays) .+ rand(Float32, 3, num_rays) +directions = rand(Float32, 3, num_rays) + + +bvtt1, bvtt2, num_bvtt = ImplicitBVH.initial_bvtt( + bvh, + points, + directions, + 2, + nothing, + options, +) + +bvt = traverse_rays(bvh, points, directions) + +for (ibv, iray) in bvt.contacts + @assert ImplicitBVH.isintersection(bvs[ibv], points[:, iray], directions[:, iray]) +end + + +# TODO Brute force tests for comparison +# TODO document rays and integrate them in bounding_volumes.jl +# TODO add example to README +# TODO include in docs/ +# TODO maybe a pretty image / render? +# TODO benchmark against a standard raytracer? Check Chitalu's paper; the Stanford bunny example + diff --git a/src/ImplicitBVH.jl b/src/ImplicitBVH.jl index ef86412..a6412be 100644 --- a/src/ImplicitBVH.jl +++ b/src/ImplicitBVH.jl @@ -7,7 +7,7 @@ module ImplicitBVH # Functionality exported by this package by default -export BVH, BVHTraversal, BVHOptions, traverse, default_start_level +export BVH, BVHTraversal, BVHOptions, traverse, traverse_rays, default_start_level export ImplicitTree, memory_index, level_indices, isvirtual @@ -28,8 +28,10 @@ include("utils.jl") include("morton.jl") include("implicit_tree.jl") include("bounding_volumes.jl") +include("rays.jl") include("build.jl") include("traverse/traverse.jl") +include("raytrace/raytrace.jl") include("utils_post.jl") end # module ImplicitBVH diff --git a/src/rays.jl b/src/rays.jl new file mode 100644 index 0000000..5bf683b --- /dev/null +++ b/src/rays.jl @@ -0,0 +1,81 @@ +""" + isintersection(b::BBox, p::Type{3, T}, d::Type{3, T}) + isintersection(s::BSphere, p::Type{3, T}, d::Type{3, T}) + +Return True if ray intersects a sphere or box +""" + +# will go into bounding volumes +function isintersection end + + +@inline function isintersection(b::BBox, p::AbstractVector, d::AbstractVector) + + @boundscheck begin + @assert length(p) == 3 + @assert length(d) == 3 + end + + T = eltype(d) + + @inbounds begin + inv_d = (one(T) / d[1], one(T) / d[2], one(T) / d[3]) + + # Set x bounds + t_bound_x1 = (b.lo[1] - p[1]) * inv_d[1] + t_bound_x2 = (b.up[1] - p[1]) * inv_d[1] + + tmin = minimum2(t_bound_x1, t_bound_x2) + tmax = maximum2(t_bound_x1, t_bound_x2) + + # Set y bounds + t_bound_y1 = (b.lo[2] - p[2]) * inv_d[2] + t_bound_y2 = (b.up[2] - p[2]) * inv_d[2] + + tmin = maximum2(tmin, minimum2(t_bound_y1, t_bound_y2)) + tmax = minimum2(tmax, maximum2(t_bound_y1, t_bound_y2)) + + # Set z bounds + t_bound_z1 = (b.lo[3] - p[3]) * inv_d[3] + t_bound_z2 = (b.up[3] - p[3]) * inv_d[3] + + tmin = maximum2(tmin, minimum2(t_bound_z1, t_bound_z2)) + tmax = minimum2(tmax, maximum2(t_bound_z1, t_bound_z2)) + end + + # If condition satisfied ray intersects box. tmax >= 0 + # ensure only forwards intersections are counted + (tmin <= tmax) && (tmax >= 0) +end + + +@inline function isintersection(s::BSphere, p::AbstractVector, d::AbstractVector) + + @boundscheck begin + @assert length(p) == 3 + @assert length(d) == 3 + end + + @inbounds begin + a = dot3(d, d) + b = 2 * ( + (p[1] - s.x[1]) * d[1] + + (p[2] - s.x[2]) * d[2] + + (p[3] - s.x[3]) * d[3] + ) + c = ( + (p[1] - s.x[1]) * (p[1] - s.x[1]) + + (p[2] - s.x[2]) * (p[2] - s.x[2]) + + (p[3] - s.x[3]) * (p[3] - s.x[3]) + ) - s.r * s.r + end + + discriminant = b * b - 4 * a * c + + if discriminant >= 0 + # Ensure only forwards intersections are counted + return 0 >= b * c + else + return false + end +end \ No newline at end of file diff --git a/src/raytrace/raytrace.jl b/src/raytrace/raytrace.jl new file mode 100644 index 0000000..19986cf --- /dev/null +++ b/src/raytrace/raytrace.jl @@ -0,0 +1,141 @@ +""" + +""" +function traverse_rays( + bvh::BVH, + points::AbstractArray, + directions::AbstractArray, + start_level::Int=1, + cache::Union{Nothing, BVHTraversal}=nothing; + options=BVHOptions(), +) + # Correctness checks + @boundscheck begin + # TODO check the container type of bvh leaves / nodes is the same as for points and directions + @argcheck bvh.tree.levels >= start_level >= bvh.built_level + @argcheck size(points, 1) == size(directions, 1) == 3 + @argcheck size(points, 2) == size(directions, 2) + end + + num_rays = size(points, 2) + + # Get index type from exemplar + I = get_index_type(options) + + # No intersections for no rays + if num_rays == 0 + return BVHTraversal(start_level, 0, 0, + similar(bvh.nodes, IndexPair{I}, 0), + similar(bvh.nodes, IndexPair{I}, 0)) + end + + # Allocate and add all possible BVTT contact pairs to start with + bvtt1, bvtt2, num_bvtt = initial_bvtt(bvh, points, directions, start_level, cache, options) + num_checks = num_bvtt + + # For GPUs we need an additional global offset to coordinate writing results + dst_offsets = if bvtt1 isa AbstractGPUVector + backend = get_backend(bvtt1) + KernelAbstractions.zeros(backend, I, Int(bvh.tree.levels)) + else + nothing + end + + level = start_level + while level < bvh.tree.levels + # We can have maximum 2 new checks per BV-ray-pair; resize destination BVTT accordingly + length(bvtt2) < 2 * num_bvtt && resize!(bvtt2, 2 * num_bvtt) + + # Check intersections in bvtt1 and add future checks in bvtt2 + num_bvtt = traverse_rays_nodes!(bvh, points, directions, + bvtt1, bvtt2, num_bvtt, + dst_offsets, level, options) + num_checks += num_bvtt + + # Swap source and destination buffers for next iteration + bvtt1, bvtt2 = bvtt2, bvtt1 + level += 1 + end + + # Arrived at final leaf level, now populating contact list + length(bvtt2) < num_bvtt && resize!(bvtt2, num_bvtt) + num_bvtt = traverse_rays_leaves!(bvh, points, directions, + bvtt1, bvtt2, num_bvtt, + dst_offsets, options) + + # Return contact list and the other buffer as possible cache + BVHTraversal(start_level, num_checks, num_bvtt, bvtt2, bvtt1) +end + + +function initial_bvtt( + bvh::BVH, + points::AbstractArray, + directions::AbstractArray, + start_level, + cache, + options, +) + num_rays = size(points, 2) + + # Get index type from exemplar + index_type = typeof(options.index_exemplar) + + # Generate all possible contact checks between all nodes at the given start_level and all rays + level_nodes = pow2(start_level - 1) + + # Number of real nodes at the given start_level and number of checks we'll do + num_real = level_nodes - bvh.tree.virtual_leaves >> (bvh.tree.levels - start_level) + level_checks = num_real * num_rays + + # If we're not at leaf-level, allocate enough memory for next BVTT expansion + if start_level == bvh.tree.levels + initial_number = level_checks + else + initial_number = 2 * level_checks + end + + # Reuse cache if given + if isnothing(cache) + bvtt1 = similar(bvh.nodes, IndexPair{index_type}, initial_number) + bvtt2 = similar(bvh.nodes, IndexPair{index_type}, initial_number) + else + @argcheck eltype(cache.cache1) === IndexPair{index_type} + @argcheck eltype(cache.cache2) === IndexPair{index_type} + + bvtt1 = cache.cache1 + bvtt2 = cache.cache2 + + length(bvtt1) < initial_number && resize!(bvtt1, initial_number) + length(bvtt2) < initial_number && resize!(bvtt2, initial_number) + end + + # Insert all checks to do at this level + backend = get_backend(bvtt1) + if backend isa GPU + # GPU version with the two for loops (see CPU) linearised + AK.foreachindex(1:num_real * num_rays, backend, block_size=options.block_size) do i + irow, icol = divrem(i - 1, num_rays) + bvtt1[i] = (irow + level_nodes, icol + 1) + end + else + # CPU initial checks; this uses such simple instructions that single threading is fastest + num_bvtt = 0 + @inbounds for i in level_nodes:level_nodes + num_real - 1 + # Node-node pair checks + for j in 1:num_rays + num_bvtt += 1 + bvtt1[num_bvtt] = (i, j) + end + end + end + + bvtt1, bvtt2, level_checks +end + + +# Traversal implementations +include("raytrace_cpu.jl") +include("raytrace_gpu.jl") + + diff --git a/src/raytrace/raytrace_cpu.jl b/src/raytrace/raytrace_cpu.jl new file mode 100644 index 0000000..1d990ba --- /dev/null +++ b/src/raytrace/raytrace_cpu.jl @@ -0,0 +1,177 @@ +function traverse_rays_nodes!(bvh, points, directions, src, dst, num_src, ::Nothing, level, options) + # Traverse nodes when level is above leaves + + # Compute number of virtual elements before this level to skip when computing the memory index + virtual_nodes_level = bvh.tree.virtual_leaves >> (bvh.tree.levels - (level - 1)) + virtual_nodes_before = 2 * virtual_nodes_level - count_ones(virtual_nodes_level) + + # Split computation into contiguous ranges of minimum 100 elements each; if only single thread + # is needed, inline call + tp = TaskPartitioner(num_src, options.num_threads, options.min_traversals_per_thread) + if tp.num_tasks == 1 + num_dst = traverse_rays_nodes_range!( + bvh, points, directions, + src, dst, nothing, + virtual_nodes_before, + (1, num_src), + ) + else + # Keep track of tasks launched and number of elements written by each task in their unique + # memory region. The unique region is equal to 2 dst elements per src element + tasks = Vector{Task}(undef, tp.num_tasks) + num_written = Vector{Int}(undef, tp.num_tasks) + @inbounds for i in 1:tp.num_tasks + istart, iend = tp[i] + tasks[i] = Threads.@spawn traverse_rays_nodes_range!( + bvh, points, directions, + src, view(dst, 2istart - 1:2iend), view(num_written, i), + virtual_nodes_before, + (istart, iend), + ) + end + + # As tasks finish sequentially, move the new written intersctions into contiguous region + num_dst = 0 + @inbounds for i in 1:tp.num_tasks + wait(tasks[i]) + task_num_written = num_written[i] + + # Repack written contacts by the second, third thread, etc. + if i > 1 + istart, iend = tp[i] + for j in 1:task_num_written + dst[num_dst + j] = dst[2istart - 1 + j - 1] + end + end + num_dst += task_num_written + end + end + + num_dst +end + + +function traverse_rays_nodes_range!( + bvh, points, directions, src, dst, num_written, num_skips, irange, +) + # Check src[irange[1]:irange[2]] and write to dst[1:num_dst]; dst should be given as a view + num_dst = 0 + + # For each BVTT node-ray pair, check for intersection + @inbounds for i in irange[1]:irange[2] + # Extract implicit indices of BVH nodes to test + implicit, iray = src[i] + node = bvh.nodes[implicit - num_skips] + + # Extract ray + p = @view points[:, iray] + d = @view directions[:, iray] + + # If the node and ray is touching, expand BVTT with new possible contacts - i.e. pair + if isintersection(node, p, d) + # If a node's right child is virtual, don't add that check. Guaranteed to always have + # at least one real child + + # BVH node's right child is virtual + if isvirtual(bvh.tree, 2 * implicit + 1) + dst[num_dst + 1] = (implicit * 2, iray) + num_dst += 1 + else + dst[num_dst + 1] = (implicit * 2, iray) + dst[num_dst + 2] = (implicit * 2 + 1, iray) + num_dst += 2 + end + end + end + + # Known at compile-time; no return if called in multithreaded context + if isnothing(num_written) + return num_dst + else + num_written[] = num_dst + return nothing + end +end + + +function traverse_rays_leaves!(bvh, points, directions, src, intersections, num_src, ::Nothing, options) + # Traverse final level, only doing ray-leaf checks + + # Split computation into contiguous ranges of minimum 100 elements each; if only single thread + # is needed, inline call + tp = TaskPartitioner(num_src, options.num_threads, options.min_traversals_per_thread) + if tp.num_tasks == 1 + num_intersections = traverse_rays_leaves_range!( + bvh, points, directions, + src, intersections, nothing, + (1, num_src), + ) + else + num_intersections = 0 + + # Keep track of tasks launched and number of elements written by each task in their unique + # memory region. The unique region is equal to 1 dst elements per src element + tasks = Vector{Task}(undef, tp.num_tasks) + num_written = Vector{Int}(undef, tp.num_tasks) + @inbounds for i in 1:tp.num_tasks + istart, iend = tp[i] + tasks[i] = Threads.@spawn traverse_rays_leaves_range!( + bvh, points, directions, + src, view(intersections, istart:iend), view(num_written, i), + (istart, iend), + ) + end + @inbounds for i in 1:tp.num_tasks + wait(tasks[i]) + task_num_written = num_written[i] + + # Repack written contacts by the second, third thread, etc. + if i > 1 + istart, iend = tp[i] + for j in 1:task_num_written + intersections[num_intersections + j] = intersections[istart + j - 1] + end + end + num_intersections += task_num_written + end + end + + num_intersections +end + + +function traverse_rays_leaves_range!( + bvh, points, directions, src, intersections, num_written, irange +) + # Check src[irange[1]:irange[2]] and write to dst[1:num_dst]; dst should be given as a view + num_dst = 0 + + # Number of implicit indices above leaf-level + num_above = pow2(bvh.tree.levels - 1) - 1 + + # For each BVTT node-ray pair, check for intersection + @inbounds for i in irange[1]:irange[2] + # Extract implicit indices of BVH leaves to test + implicit, iray = src[i] + + iorder = bvh.order[implicit - num_above] + leaf = bvh.leaves[iorder] + + p = @view points[:, iray] + d = @view directions[:, iray] + + # If leaf-ray intersection, save in intersections + if isintersection(leaf, p, d) + intersections[num_dst + 1] = (iorder, iray) + num_dst += 1 + end + end + + # Known at compile-time; no return if called in multithreaded context + if isnothing(num_written) + return num_dst + else + num_written[] = num_dst + return nothing + end +end \ No newline at end of file diff --git a/src/raytrace/raytrace_gpu.jl b/src/raytrace/raytrace_gpu.jl new file mode 100644 index 0000000..09253f1 --- /dev/null +++ b/src/raytrace/raytrace_gpu.jl @@ -0,0 +1,172 @@ +function traverse_rays_nodes!(bvh, points, directions, src, dst, num_src, dst_offsets, level, options) + # Traverse nodes when level is above leaves + + # Compute number of virtual elements before this level to skip when computing the memory index + virtual_nodes_level = bvh.tree.virtual_leaves >> (bvh.tree.levels - (level - 1)) + virtual_nodes_before = 2 * virtual_nodes_level - count_ones(virtual_nodes_level) + + block_size = options.block_size + backend = get_backend(src) + + kernel! = _traverse_rays_nodes_gpu!(backend, block_size) + kernel!( + bvh.tree, bvh.nodes, points, directions, + src, dst, num_src, dst_offsets, + virtual_nodes_before, + ) + +end + +@kernel cpu=false inbounds=true function _traverse_rays_nodes_gpu!( + tree, nodes, points, directions, + src, dst, num_src, dst_offsets, + num_skips, +) + # Group (block) and local (thread) indices + iblock = @index(Group, Linear) + ithread = @index(Local, Linear) + + block_size = @groupsize()[1] + + # At most 2N sprouted checks from N src + temp = @localmem eltype(dst) (2 * block_size,) + temp_offset = @localmem typeof(iblock) (1,) + block_dst_offset = @localmem typeof(iblock) (1,) + + # Write the initial offset for this block as zero. This will be atomically incremented as new + # pairs are written to temp + if ithread == 1 + temp_offset[1] = 0 + end + @synchronize() + + index = ithread + (iblock - 1) * block_size + if index <= num_src + + # Extract implicit indices of BVH nodes and rays to test + implicit, iray = src[index] + + node = nodes[implicit - num_skips] + p = @view points[:, iray] + d = @view directions[:, iray] + + # If a ray and node are touching, expand BVTT with new possible intersections - i.e. pair + # the nodes' children with the ray + + if isintersection(node, p, d) + # If a node's right child is virtual, don't add that check. Guaranteed to always have + # at least one real child + + # BVH node's right child is virtual + if unsafe_isvirtual(tree, 2 * implicit + 1) + new_temp_offset = @atomic temp_offset[1] += 1 + temp[new_temp_offset - 1 + 1] = (implicit * 2, iray) + # BVH node's right child is real + else + new_temp_offset = @atomic temp_offset[1] += 2 + temp[new_temp_offset - 2 + 1] = (implicit * 2, iray) + temp[new_temp_offset - 2 + 2] = (implicit * 2 + 1, iray) + end + end + end + @synchronize() + + # Now we have to move the indices from temp to dst in chunks, at offsets reserved via atomic + # incrementing of dst_offset + num_temp = temp_offset[1] # Number of indices to write + if ithread == 1 + block_dst_offset[1] = @atomic dst_offsets[end] += num_temp + end + + @synchronize() + offset = block_dst_offset[1] - num_temp + if i <= num_temp + dst[offset + ithread] = temp[ithread] + end +end + + +function traverse_rays_leaves!( + bvh, points, directions, + src::AbstractGPUVector, intersections::AbstractGPUVector, + num_src, dst_offsets, options +) + + # Traverse final level, only doing ray-leaf checks + num_above = pow2(bvh.tree.levels - 1) - 1 + + block_size = options.block_size + num_blocks = (num_src + block_size - 1) ÷ block_size + backend = get_backend(src) + + kernel! = _traverse_rays_leaves!(backend, block_size) + kernel!( + bvh.leaves, bvh.order, points, directions, + src, intersections, num_src, dst_offsets, + num_above, ndrange=num_blocks * block_size, + ) + + # We need to know how many checks we have written into dst + synchronize(backend) + @allowscalar dst_offsets[end] +end + + +@kernel cpu=false inbounds=true function _traverse_rays_leaves_gpu!( + leaves, order, points, directions, + src, dst, + num_src, dst_offsets, + num_above, +) + # Group (block) and local (thread) indices + iblock = @index(Group, Linear) + ithread = @index(Local, Linear) + + block_size = @groupsize()[1] + + # At most N sprouted checks from N src + temp = @localmem eltype(dst) (block_size,) + temp_offset = @localmem typeof(iblock) (1,) + block_dst_offset = @localmem typeof(iblock) (1,) + + # Write the initial offset for this block as zero. This will be atomically incremented as new + # pairs are written to temp + if ithread == 1 + temp_offset[1] = 0 + end + @synchronize() + + # For each BVTT node-ray pair, check for intersection + index = ithread + (iblock - 1) * block_size + if index <= num_src + + # Extract implicit indices of BVH leaves to test + implicit, iray = src[index] + iorder = order[implicit - num_above] + + leaf = leaves[iorder] + p = @view points[:, iray] + d = @view directions[:, iray] + + # If leaf-ray intersects save the intersection + if isintersection(leaf, p, d) + new_temp_offset = @atomic temp_offset[1] += 1 + temp[new_temp_offset - 1 + 1] = (iorder, iray) + end + end + @synchronize() + + # Now we have to move the indices from temp to dst in chunks, at offsets reserved via atomic + # incrementing of dst_offset + num_temp = temp_offset[1] # Number of indices to write + + if ithread == 1 + block_dst_offset[1] = @atomic dst_offsets[end] += num_temp + end + @synchronize() + + offset = block_dst_offset[1] - num_temp + if ithread <= num_temp + dst[offset + ithread] = temp[ithread] + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index dcc3cb9..6990d45 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -238,6 +238,116 @@ end end +@testset "ray-box isintersection" begin + + # Below box and ray going through corner + box = BBox((0., 0., 0.), (1., 1., 1.)) + point = [-1., -1., -1.] + direction = [1., 1., 1.] + @test isintersection(box, point, direction) == true + + # Below box and ray going through corner ray direction flipped case + box = BBox((0., 0., 0.), (1., 1., 1.)) + point = [-1., -1., -1.] + direction = [-1., -1., -1.] + @test isintersection(box, point, direction) == false + + # Below box ray going up and through face + box = BBox((0., 0., 0.), (1., 1., 1.)) + point = [-1., -.5, 0.] + direction = [5., 3., 1.5] + @test isintersection(box, point, direction) == true + + # Below box ray going up and through face + box = BBox((0., 0., 0.), (1., 1., 1.)) + point = [0.5, -0.5, 0.5] + direction = [0., 1., 0.] + @test isintersection(box, point, direction) == true + + # Below box ray going up and through face ray direction flipped case + box = BBox((0., 0., 0.), (1., 1., 1.)) + point = [-1., -.5, 0.] + direction = [-5., -3., -1.5] + @test isintersection(box, point, direction) == false + + # Inside box going through upper corner case + box = BBox((0., 0., 0.), (1., 1., 1.)) + point = [.5, .5, .5] + direction = [1., 1., 1.] + @test isintersection(box, point, direction) == true + + # Inside box going through bottom corner (direction flipped case) + box = BBox((0., 0., 0.), (1., 1., 1.)) + point = [.5, .5, .5] + direction = [-1., -1., -1.] + @test isintersection(box, point, direction) == true + + # Inside box going along face surface + box = BBox((0., 0., 0.), (1., 1., 1.)) + point = [1e-8, 0, 0.5] + direction = [0, 1., 0] + @test isintersection(box, point, direction) == true + + # Outside box going along edge + box = BBox((0., 0., 0.), (1., 1., 1.)) + point = [1e-8, -1., 1e-8] + direction = [0, 1., 0] + @test isintersection(box, point, direction) == true +end + + +@testset "ray-sphere isintersection" begin + + # ray above sphere passing down and through + sphere = BSphere((0., 0., 0.), 0.5) + point = [.5, .5, .5] + direction = [-1., -1., -1.] + @test isintersection(sphere, point, direction) == true + + # ray above sphere passing up and not intersecting direction flipped + sphere = BSphere((0., 0., 0.), 0.5) + point = [.5, .5, .5] + direction = [1., 1., 1.] + @test isintersection(sphere, point, direction) == false + + # ray below sphere passing up and intersecting + sphere = BSphere((0., 0., 0.), 0.5) + point = [0., 0., -1.] + direction = [0., 0., 1.] + @test isintersection(sphere, point, direction) == true + + # ray below sphere passing and not intersecting + sphere = BSphere((0., 0., 0.), 0.5) + point = [0., 0., -1.] + direction = [0., 0., -1.] + @test isintersection(sphere, point, direction) == false + + # ray below sphere passing up and tangent to sphere + sphere = BSphere((0., 0., 0.), 0.5) + point = [0., 0.5, -1.] + direction = [0., 0., 1.] + @test isintersection(sphere, point, direction) == true + + # ray to the side of sphere and passing tangent to sphere + sphere = BSphere((0., 0., 0.), 0.5) + point = [0., -1, 0.5] + direction = [0., 1., 0.] + @test isintersection(sphere, point, direction) == true + + # ray inside sphere passing up and out + sphere = BSphere((0., 0., 0.), 0.5) + point = [0., 0., 0.] + direction = [0., 0., 1.] + @test isintersection(sphere, point, direction) == true + + # ray inside sphere passing down and out flipped direction + sphere = BSphere((0., 0., 0.), 0.5) + point = [0., 0., 0.] + direction = [0., 0., -1.] + @test isintersection(sphere, point, direction) == true + +end + @testset "test_morton" begin From 0193c62bb68f880d249de0288171937318d0972a Mon Sep 17 00:00:00 2001 From: Jack Grogan Date: Fri, 1 Nov 2024 16:49:39 +0000 Subject: [PATCH 2/5] Implementation of raytracing in implicitBVH.jl --- test/runtests.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 6990d45..6eceb4f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -240,6 +240,8 @@ end @testset "ray-box isintersection" begin + using ImplicitBVH: isintersection + # Below box and ray going through corner box = BBox((0., 0., 0.), (1., 1., 1.)) point = [-1., -1., -1.] From 79bb36a71b7cf0208f5ad8b506423fcea177bbb8 Mon Sep 17 00:00:00 2001 From: Jack Grogan Date: Fri, 1 Nov 2024 17:06:45 +0000 Subject: [PATCH 3/5] Implementation of raytracing in implicitBVH.jl --- docs/src/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/index.md b/docs/src/index.md index 5a2bb97..27f3c4b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -5,6 +5,7 @@ ```@docs BVH traverse +traverse_rays BVHTraversal default_start_level ImplicitBVH.IndexPair From 943477590f9786b0e621b39319276afd8bbc2954 Mon Sep 17 00:00:00 2001 From: Jack Grogan Date: Mon, 4 Nov 2024 15:51:25 +0000 Subject: [PATCH 4/5] Implementation of raytracing in implicitBVH.jl --- src/ImplicitBVH.jl | 1 - src/bounding_volumes.jl | 126 +++++++++++++++++++++++++++++++++++ src/rays.jl | 81 ---------------------- src/raytrace/raytrace_gpu.jl | 8 ++- 4 files changed, 132 insertions(+), 84 deletions(-) delete mode 100644 src/rays.jl diff --git a/src/ImplicitBVH.jl b/src/ImplicitBVH.jl index a6412be..432d82a 100644 --- a/src/ImplicitBVH.jl +++ b/src/ImplicitBVH.jl @@ -28,7 +28,6 @@ include("utils.jl") include("morton.jl") include("implicit_tree.jl") include("bounding_volumes.jl") -include("rays.jl") include("build.jl") include("traverse/traverse.jl") include("raytrace/raytrace.jl") diff --git a/src/bounding_volumes.jl b/src/bounding_volumes.jl index 9ab5874..571a462 100644 --- a/src/bounding_volumes.jl +++ b/src/bounding_volumes.jl @@ -8,12 +8,66 @@ Check if two bounding volumes are touching or inter-penetrating. """ function iscontact end +""" + isintersection(b::BBox, p::Type{3, T}, d::Type{3, T}) + isintersection(s::BSphere, p::Type{3, T}, d::Type{3, T}) + +Return True if ray intersects a sphere or box +""" + +# will go into bounding volumes +function isintersection end + +""" +# Examples + +Simple ray bounding box intersection example: + +```jldoctest +using ImplicitBVH +using ImplicitBVH: BSphere, BBox, isintersection + +# Generate a simple bounding box + +bounding_box = BBox((0., 0., 0.), (1., 1., 1.)) + +# Generate a ray passing up and through the bottom face of the bounding box + +point = [.5, .5, -10] +direction = [0, 0, 1] +isintersection(bounding_box, point, direction) + +# output +true +``` + +Simple ray bounding sphere intersection example: + +```jldoctest +using ImplicitBVH +using ImplicitBVH: BSphere, BBox, isintersection + +# Generate a simple bounding sphere + +bounding_sphere = BSphere((0., 0., 0.), 0.5) + +# Generate a ray passing up and through the bounding sphere + +point = [0, 0, -10] +direction = [0, 0, 1] +isintersection(bounding_sphere, point, direction) + +# output +true +``` +""" """ center(b::BSphere) center(b::BBox{T}) where T Get the coordinates of a bounding volume's centre, as a NTuple{3, T}. + """ function center end @@ -408,3 +462,75 @@ end function iscontact(a::BBox, b::BSphere) iscontact(b, a) end + + +@inline function isintersection(b::BBox, p::AbstractVector, d::AbstractVector) + + @boundscheck begin + @assert length(p) == 3 + @assert length(d) == 3 + end + + T = eltype(d) + + @inbounds begin + inv_d = (one(T) / d[1], one(T) / d[2], one(T) / d[3]) + + # Set x bounds + t_bound_x1 = (b.lo[1] - p[1]) * inv_d[1] + t_bound_x2 = (b.up[1] - p[1]) * inv_d[1] + + tmin = minimum2(t_bound_x1, t_bound_x2) + tmax = maximum2(t_bound_x1, t_bound_x2) + + # Set y bounds + t_bound_y1 = (b.lo[2] - p[2]) * inv_d[2] + t_bound_y2 = (b.up[2] - p[2]) * inv_d[2] + + tmin = maximum2(tmin, minimum2(t_bound_y1, t_bound_y2)) + tmax = minimum2(tmax, maximum2(t_bound_y1, t_bound_y2)) + + # Set z bounds + t_bound_z1 = (b.lo[3] - p[3]) * inv_d[3] + t_bound_z2 = (b.up[3] - p[3]) * inv_d[3] + + tmin = maximum2(tmin, minimum2(t_bound_z1, t_bound_z2)) + tmax = minimum2(tmax, maximum2(t_bound_z1, t_bound_z2)) + end + + # If condition satisfied ray intersects box. tmax >= 0 + # ensure only forwards intersections are counted + (tmin <= tmax) && (tmax >= 0) +end + + +@inline function isintersection(s::BSphere, p::AbstractVector, d::AbstractVector) + + @boundscheck begin + @assert length(p) == 3 + @assert length(d) == 3 + end + + @inbounds begin + a = dot3(d, d) + b = 2 * ( + (p[1] - s.x[1]) * d[1] + + (p[2] - s.x[2]) * d[2] + + (p[3] - s.x[3]) * d[3] + ) + c = ( + (p[1] - s.x[1]) * (p[1] - s.x[1]) + + (p[2] - s.x[2]) * (p[2] - s.x[2]) + + (p[3] - s.x[3]) * (p[3] - s.x[3]) + ) - s.r * s.r + end + + discriminant = b * b - 4 * a * c + + if discriminant >= 0 + # Ensure only forwards intersections are counted + return 0 >= b * c + else + return false + end +end \ No newline at end of file diff --git a/src/rays.jl b/src/rays.jl deleted file mode 100644 index 5bf683b..0000000 --- a/src/rays.jl +++ /dev/null @@ -1,81 +0,0 @@ -""" - isintersection(b::BBox, p::Type{3, T}, d::Type{3, T}) - isintersection(s::BSphere, p::Type{3, T}, d::Type{3, T}) - -Return True if ray intersects a sphere or box -""" - -# will go into bounding volumes -function isintersection end - - -@inline function isintersection(b::BBox, p::AbstractVector, d::AbstractVector) - - @boundscheck begin - @assert length(p) == 3 - @assert length(d) == 3 - end - - T = eltype(d) - - @inbounds begin - inv_d = (one(T) / d[1], one(T) / d[2], one(T) / d[3]) - - # Set x bounds - t_bound_x1 = (b.lo[1] - p[1]) * inv_d[1] - t_bound_x2 = (b.up[1] - p[1]) * inv_d[1] - - tmin = minimum2(t_bound_x1, t_bound_x2) - tmax = maximum2(t_bound_x1, t_bound_x2) - - # Set y bounds - t_bound_y1 = (b.lo[2] - p[2]) * inv_d[2] - t_bound_y2 = (b.up[2] - p[2]) * inv_d[2] - - tmin = maximum2(tmin, minimum2(t_bound_y1, t_bound_y2)) - tmax = minimum2(tmax, maximum2(t_bound_y1, t_bound_y2)) - - # Set z bounds - t_bound_z1 = (b.lo[3] - p[3]) * inv_d[3] - t_bound_z2 = (b.up[3] - p[3]) * inv_d[3] - - tmin = maximum2(tmin, minimum2(t_bound_z1, t_bound_z2)) - tmax = minimum2(tmax, maximum2(t_bound_z1, t_bound_z2)) - end - - # If condition satisfied ray intersects box. tmax >= 0 - # ensure only forwards intersections are counted - (tmin <= tmax) && (tmax >= 0) -end - - -@inline function isintersection(s::BSphere, p::AbstractVector, d::AbstractVector) - - @boundscheck begin - @assert length(p) == 3 - @assert length(d) == 3 - end - - @inbounds begin - a = dot3(d, d) - b = 2 * ( - (p[1] - s.x[1]) * d[1] + - (p[2] - s.x[2]) * d[2] + - (p[3] - s.x[3]) * d[3] - ) - c = ( - (p[1] - s.x[1]) * (p[1] - s.x[1]) + - (p[2] - s.x[2]) * (p[2] - s.x[2]) + - (p[3] - s.x[3]) * (p[3] - s.x[3]) - ) - s.r * s.r - end - - discriminant = b * b - 4 * a * c - - if discriminant >= 0 - # Ensure only forwards intersections are counted - return 0 >= b * c - else - return false - end -end \ No newline at end of file diff --git a/src/raytrace/raytrace_gpu.jl b/src/raytrace/raytrace_gpu.jl index 09253f1..55fa5a1 100644 --- a/src/raytrace/raytrace_gpu.jl +++ b/src/raytrace/raytrace_gpu.jl @@ -6,6 +6,7 @@ function traverse_rays_nodes!(bvh, points, directions, src, dst, num_src, dst_of virtual_nodes_before = 2 * virtual_nodes_level - count_ones(virtual_nodes_level) block_size = options.block_size + num_blocks = (num_src + block_size - 1) ÷ block_size backend = get_backend(src) kernel! = _traverse_rays_nodes_gpu!(backend, block_size) @@ -13,10 +14,14 @@ function traverse_rays_nodes!(bvh, points, directions, src, dst, num_src, dst_of bvh.tree, bvh.nodes, points, directions, src, dst, num_src, dst_offsets, virtual_nodes_before, + ndrange=num_blocks * block_size, ) - + + # We need to know how many checks we have written into dst + @allowscalar dst_offsets[level] end + @kernel cpu=false inbounds=true function _traverse_rays_nodes_gpu!( tree, nodes, points, directions, src, dst, num_src, dst_offsets, @@ -107,7 +112,6 @@ function traverse_rays_leaves!( ) # We need to know how many checks we have written into dst - synchronize(backend) @allowscalar dst_offsets[end] end From bbcce31f35f6e6c5edd0bee3a9f835add81d60cf Mon Sep 17 00:00:00 2001 From: Jack Grogan Date: Mon, 4 Nov 2024 18:40:08 +0000 Subject: [PATCH 5/5] Implementation of raytracing in implicitBVH.jl --- README.md | 49 ++- docs/src/bounding_volumes.md | 1 + prototype/raytracing.jl | 3 - src/ImplicitBVH.jl | 3 +- src/bounding_volumes.jl | 536 ----------------------- src/bounding_volumes/bbox.jl | 113 +++++ src/bounding_volumes/bounding_volumes.jl | 81 ++++ src/bounding_volumes/bsphere.jl | 146 ++++++ src/bounding_volumes/iscontact.jl | 28 ++ src/bounding_volumes/isintersection.jl | 72 +++ src/bounding_volumes/merge.jl | 85 ++++ 11 files changed, 575 insertions(+), 542 deletions(-) delete mode 100644 src/bounding_volumes.jl create mode 100644 src/bounding_volumes/bbox.jl create mode 100644 src/bounding_volumes/bounding_volumes.jl create mode 100644 src/bounding_volumes/bsphere.jl create mode 100644 src/bounding_volumes/iscontact.jl create mode 100644 src/bounding_volumes/isintersection.jl create mode 100644 src/bounding_volumes/merge.jl diff --git a/README.md b/README.md index 85eb4cc..6972a23 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,8 @@ detection downwards from this level. ## Examples +### Multithreaded Contact Detection + Simple usage with bounding spheres and default 64-bit types: ```julia @@ -127,7 +129,7 @@ traversal = traverse( Check out the `benchmark` folder for an example traversing an STL model. -# GPU Bounding Volume Hierarchy Building and Traversal +### GPU-Accelerated Contact Detection Simply use a GPU array for the bounding volumes; the interface remains the same, and all operations - Morton encoding, sorting, BVH building and traversal for contact finding - will run on the right backend: @@ -155,6 +157,51 @@ traversal = traverse(bvh) ``` +### Multithreaded Ray Tracing + +Using `BSphere{Float32}` for leaves, `BBox{Float32}` for merged nodes above, and `UInt32` Morton codes: + +```julia +using ImplicitBVH +using ImplicitBVH: BBox, BSphere + +# Load mesh and compute bounding spheres for each triangle. Can download mesh from: +# https://github.com/alecjacobson/common-3d-test-models/blob/master/data/xyzrgb_dragon.obj +using MeshIO +using FileIO + +mesh = load("xyzrgb_dragon.obj") + +# Generate bounding spheres around each triangle in the mesh +bounding_spheres = [BSphere{Float32}(tri) for tri in mesh] + +# Build BVH +bvh = BVH(bounding_spheres, BBox{Float32}, UInt32) + +# Generate some rays +points = rand(Float32, 3, 1000) +directions = rand(Float32, 3, 1000) + +# Traverse BVH to get indices of rays intersecting the bounding spheres +traversal = traverse_rays(bvh, points, directions) +@show traversal.contacts + +# output +traversal.contacts = Tuple{Int32, Int32}[...] +``` + +The bounding spheres around each triangle can be computed in parallel (including on GPUs) using [AcceleratedKernels.jl](https://github.com/anicusan/AcceleratedKernels.jl): + +```julia +import AcceleratedKernels as AK + +bounding_spheres = Vector{BSphere{Float32}}(undef, length(mesh)) +AK.map!(BSphere{Float32}, bounding_spheres, mesh) +``` + +For GPUs simply swap `Vector` with `ROCVector`, `MtlVector`, `oneVector` or `CuVector`, and AcceleratedKernels will automatically run the code on the right GPU backend (from `AMDGPU`, `Metal`, `oneAPI`, `CUDA`). + + # Implicit Bounding Volume Hierarchy The main idea behind the ImplicitBVH is the use of an implicit perfect binary tree constructed from some diff --git a/docs/src/bounding_volumes.md b/docs/src/bounding_volumes.md index 602d593..dddf628 100644 --- a/docs/src/bounding_volumes.md +++ b/docs/src/bounding_volumes.md @@ -9,6 +9,7 @@ ImplicitBVH.BSphere ```@docs ImplicitBVH.iscontact +ImplicitBVH.isintersection ImplicitBVH.center ``` diff --git a/prototype/raytracing.jl b/prototype/raytracing.jl index cc746b7..d674de7 100644 --- a/prototype/raytracing.jl +++ b/prototype/raytracing.jl @@ -40,10 +40,7 @@ for (ibv, iray) in bvt.contacts end -# TODO Brute force tests for comparison -# TODO document rays and integrate them in bounding_volumes.jl # TODO add example to README -# TODO include in docs/ # TODO maybe a pretty image / render? # TODO benchmark against a standard raytracer? Check Chitalu's paper; the Stanford bunny example diff --git a/src/ImplicitBVH.jl b/src/ImplicitBVH.jl index 432d82a..9827154 100644 --- a/src/ImplicitBVH.jl +++ b/src/ImplicitBVH.jl @@ -23,11 +23,10 @@ using GPUArraysCore: AbstractGPUVector, @allowscalar import AcceleratedKernels as AK -# Include code from other files include("utils.jl") include("morton.jl") include("implicit_tree.jl") -include("bounding_volumes.jl") +include("bounding_volumes/bounding_volumes.jl") include("build.jl") include("traverse/traverse.jl") include("raytrace/raytrace.jl") diff --git a/src/bounding_volumes.jl b/src/bounding_volumes.jl deleted file mode 100644 index 571a462..0000000 --- a/src/bounding_volumes.jl +++ /dev/null @@ -1,536 +0,0 @@ -""" - iscontact(a::BSphere, b::BSphere) - iscontact(a::BBox, b::BBox) - iscontact(a::BSphere, b::BBox) - iscontact(a::BBox, b::BSphere) - -Check if two bounding volumes are touching or inter-penetrating. -""" -function iscontact end - -""" - isintersection(b::BBox, p::Type{3, T}, d::Type{3, T}) - isintersection(s::BSphere, p::Type{3, T}, d::Type{3, T}) - -Return True if ray intersects a sphere or box -""" - -# will go into bounding volumes -function isintersection end - -""" -# Examples - -Simple ray bounding box intersection example: - -```jldoctest -using ImplicitBVH -using ImplicitBVH: BSphere, BBox, isintersection - -# Generate a simple bounding box - -bounding_box = BBox((0., 0., 0.), (1., 1., 1.)) - -# Generate a ray passing up and through the bottom face of the bounding box - -point = [.5, .5, -10] -direction = [0, 0, 1] -isintersection(bounding_box, point, direction) - -# output -true -``` - -Simple ray bounding sphere intersection example: - -```jldoctest -using ImplicitBVH -using ImplicitBVH: BSphere, BBox, isintersection - -# Generate a simple bounding sphere - -bounding_sphere = BSphere((0., 0., 0.), 0.5) - -# Generate a ray passing up and through the bounding sphere - -point = [0, 0, -10] -direction = [0, 0, 1] -isintersection(bounding_sphere, point, direction) - -# output -true -``` -""" - -""" - center(b::BSphere) - center(b::BBox{T}) where T - -Get the coordinates of a bounding volume's centre, as a NTuple{3, T}. - -""" -function center end - - -""" - translate(b::BSphere{T}, dx) where T - translate(b::BBox{T}, dx) where T - -Get a new bounding volume translated by dx; dx can be any iterable with 3 elements. -""" -function translate end - - -""" - $(TYPEDEF) - -Bounding sphere, highly optimised for computing bounding volumes for triangles and merging into -larger bounding volumes. - -# Methods - # Convenience constructors - BSphere(x::NTuple{3, T}, r) - BSphere{T}(x::AbstractVector, r) where T - BSphere(x::AbstractVector, r) - - # Construct from triangle vertices - BSphere{T}(p1, p2, p3) where T - BSphere(p1, p2, p3) - BSphere{T}(vertices::AbstractMatrix) where T - BSphere(vertices::AbstractMatrix) - BSphere{T}(triangle) where T - BSphere(triangle) - - # Merging bounding volumes - BSphere{T}(a::BSphere, b::BSphere) where T - BSphere(a::BSphere{T}, b::BSphere{T}) where T - Base.:+(a::BSphere, b::BSphere) -""" -struct BSphere{T} - x::NTuple{3, T} - r::T -end - -Base.eltype(::BSphere{T}) where T = T -Base.eltype(::Type{BSphere{T}}) where T = T - - -# Convenience constructors, with and without type parameter -BSphere{T}(x::AbstractVector, r) where T = BSphere(NTuple{3, T}(x), T(r)) -BSphere(x::AbstractVector, r) = BSphere{eltype(x)}(x, r) - - -# Constructors from triangles -function BSphere{T}(p1, p2, p3) where T - - # Adapted from https://realtimecollisiondetection.net/blog/?p=20 - a = (T(p1[1]), T(p1[2]), T(p1[3])) - b = (T(p2[1]), T(p2[2]), T(p2[3])) - c = (T(p3[1]), T(p3[2]), T(p3[3])) - - # Unrolled dot(b - a, b - a) - abab = (b[1] - a[1]) * (b[1] - a[1]) + - (b[2] - a[2]) * (b[2] - a[2]) + - (b[3] - a[3]) * (b[3] - a[3]) - - # Unrolled dot(b - a, c - a) - abac = (b[1] - a[1]) * (c[1] - a[1]) + - (b[2] - a[2]) * (c[2] - a[2]) + - (b[3] - a[3]) * (c[3] - a[3]) - - # Unrolled dot(c - a, c - a) - acac = (c[1] - a[1]) * (c[1] - a[1]) + - (c[2] - a[2]) * (c[2] - a[2]) + - (c[3] - a[3]) * (c[3] - a[3]) - - d = T(2.) * (abab * acac - abac * abac) - - if abs(d) <= eps(T) - # a, b, c lie on a line. Find line centre and radius - lower = (minimum3(a[1], b[1], c[1]), - minimum3(a[2], b[2], c[2]), - minimum3(a[3], b[3], c[3])) - - upper = (maximum3(a[1], b[1], c[1]), - maximum3(a[2], b[2], c[2]), - maximum3(a[3], b[3], c[3])) - - centre = (T(0.5) * (lower[1] + upper[1]), - T(0.5) * (lower[2] + upper[2]), - T(0.5) * (lower[3] + upper[3])) - radius = dist3(centre, upper) - else - s = (abab * acac - acac * abac) / d - t = (acac * abab - abab * abac) / d - - if s <= zero(T) - centre = (T(0.5) * (a[1] + c[1]), - T(0.5) * (a[2] + c[2]), - T(0.5) * (a[3] + c[3])) - radius = dist3(centre, a) - elseif t <= zero(T) - centre = (T(0.5) * (a[1] + b[1]), - T(0.5) * (a[2] + b[2]), - T(0.5) * (a[3] + b[3])) - radius = dist3(centre, a) - elseif s + t >= one(T) - centre = (T(0.5) * (b[1] + c[1]), - T(0.5) * (b[2] + c[2]), - T(0.5) * (b[3] + c[3])) - radius = dist3(centre, b) - else - centre = (a[1] + s * (b[1] - a[1]) + t * (c[1] - a[1]), - a[2] + s * (b[2] - a[2]) + t * (c[2] - a[2]), - a[3] + s * (b[3] - a[3]) + t * (c[3] - a[3])) - radius = dist3(centre, a) - end - end - - BSphere(centre, radius) -end - - -# Convenience constructors, with and without explicit type parameter -function BSphere(p1, p2, p3) - BSphere{eltype(p1)}(p1, p2, p3) -end - -function BSphere{T}(triangle) where T - # Decompose triangle into its 3 vertices. - # Works transparently with GeometryBasics.Triangle, Vector{SVector{3, T}}, etc. - p1, p2, p3 = triangle - BSphere{T}(p1, p2, p3) -end - -function BSphere(triangle) - p1, p2, p3 = triangle - BSphere{eltype(p1)}(p1, p2, p3) -end - -function BSphere{T}(vertices::AbstractMatrix) where T - BSphere{T}(@view(vertices[:, 1]), @view(vertices[:, 2]), @view(vertices[:, 3])) -end - -function BSphere(vertices::AbstractMatrix) - BSphere{eltype(vertices)}(@view(vertices[:, 1]), @view(vertices[:, 2]), @view(vertices[:, 3])) -end - - -# Overloaded center function -center(b::BSphere) = b.x - - -# Overloaded translate function -function translate(b::BSphere{T}, dx) where T - new_center = (b.x[1] + T(dx[1]), - b.x[2] + T(dx[2]), - b.x[3] + T(dx[3])) - BSphere{T}(new_center, b.r) -end - - -# Merge two bounding spheres -function BSphere{T}(a::BSphere, b::BSphere) where T - length = dist3(a.x, b.x) - - # a is enclosed within b - if length + a.r <= b.r - return BSphere{T}(b.x, b.r) - - # b is enclosed within a - elseif length + b.r <= a.r - return BSphere{T}(a.x, a.r) - - # Bounding spheres are not enclosed - else - frac = T(0.5) * ((b.r - a.r) / length + T(1)) - centre = (a.x[1] + frac * (b.x[1] - a.x[1]), - a.x[2] + frac * (b.x[2] - a.x[2]), - a.x[3] + frac * (b.x[3] - a.x[3])) - radius = T(0.5) * (length + a.r + b.r) - return BSphere{T}(centre, radius) - end -end - - -BSphere(a::BSphere{T}, b::BSphere{T}) where T = BSphere{T}(a, b) -Base.:+(a::BSphere, b::BSphere) = BSphere(a, b) - - -# Contact detection -function iscontact(a::BSphere, b::BSphere) - dist3sq(a.x, b.x) <= (a.r + b.r) * (a.r + b.r) -end - - - - -""" - $(TYPEDEF) - -Axis-aligned bounding box, highly optimised for computing bounding volumes for triangles and -merging into larger bounding volumes. - -Can also be constructed from two spheres to e.g. allow merging [`BSphere`](@ref) leaves into -[`BBox`](@ref) nodes. - -# Methods - # Convenience constructors - BBox(lo::NTuple{3, T}, up::NTuple{3, T}) where T - BBox{T}(lo::AbstractVector, up::AbstractVector) where T - BBox(lo::AbstractVector, up::AbstractVector) - - # Construct from triangle vertices - BBox{T}(p1, p2, p3) where T - BBox(p1, p2, p3) - BBox{T}(vertices::AbstractMatrix) where T - BBox(vertices::AbstractMatrix) - BBox{T}(triangle) where T - BBox(triangle) - - # Merging bounding boxes - BBox{T}(a::BBox, b::BBox) where T - BBox(a::BBox{T}, b::BBox{T}) where T - Base.:+(a::BBox, b::BBox) - - # Merging bounding spheres - BBox{T}(a::BSphere{T}) where T - BBox(a::BSphere{T}) where T - BBox{T}(a::BSphere{T}, b::BSphere{T}) where T - BBox(a::BSphere{T}, b::BSphere{T}) where T -""" -struct BBox{T} - lo::NTuple{3, T} - up::NTuple{3, T} -end - -Base.eltype(::BBox{T}) where T = T -Base.eltype(::Type{BBox{T}}) where T = T - - - -# Convenience constructors, with and without type parameter -function BBox{T}(lo::AbstractVector, up::AbstractVector) where T - BBox(NTuple{3, eltype(lo)}(lo), NTuple{3, eltype(up)}(up)) -end - -function BBox(lo::AbstractVector, up::AbstractVector) - BBox{eltype(lo)}(lo, up) -end - - - -# Constructors from triangles -function BBox{T}(p1, p2, p3) where T - - lower = (minimum3(p1[1], p2[1], p3[1]), - minimum3(p1[2], p2[2], p3[2]), - minimum3(p1[3], p2[3], p3[3])) - - upper = (maximum3(p1[1], p2[1], p3[1]), - maximum3(p1[2], p2[2], p3[2]), - maximum3(p1[3], p2[3], p3[3])) - - BBox{T}(lower, upper) -end - - -# Convenience constructors, with and without explicit type parameter -function BBox(p1, p2, p3) - BBox{eltype(p1)}(p1, p2, p3) -end - -function BBox{T}(triangle) where T - # Decompose triangle into its 3 vertices. - # Works transparently with GeometryBasics.Triangle, Vector{SVector{3, T}}, etc. - p1, p2, p3 = triangle - BBox{T}(p1, p2, p3) -end - -function BBox(triangle) - p1, p2, p3 = triangle - BBox{eltype(p1)}(p1, p2, p3) -end - -function BBox{T}(vertices::AbstractMatrix) where T - BBox{T}(@view(vertices[:, 1]), @view(vertices[:, 2]), @view(vertices[:, 3])) -end - -function BBox(vertices::AbstractMatrix) - BBox{eltype(vertices)}(@view(vertices[:, 1]), @view(vertices[:, 2]), @view(vertices[:, 3])) -end - - -# Overloaded center function -center(b::BBox{T}) where T = (T(0.5) * (b.lo[1] + b.up[1]), - T(0.5) * (b.lo[2] + b.up[2]), - T(0.5) * (b.lo[3] + b.up[3])) - - -# Overloaded translate function -function translate(b::BBox{T}, dx) where T - dx1, dx2, dx3 = T(dx[1]), T(dx[2]), T(dx[3]) - new_lo = (b.lo[1] + dx1, - b.lo[2] + dx2, - b.lo[3] + dx3) - new_up = (b.up[1] + dx1, - b.up[2] + dx2, - b.up[3] + dx3) - BBox{T}(new_lo, new_up) -end - - -# Merge two bounding boxes -function BBox{T}(a::BBox, b::BBox) where T - lower = (minimum2(a.lo[1], b.lo[1]), - minimum2(a.lo[2], b.lo[2]), - minimum2(a.lo[3], b.lo[3])) - - upper = (maximum2(a.up[1], b.up[1]), - maximum2(a.up[2], b.up[2]), - maximum2(a.up[3], b.up[3])) - - BBox{T}(lower, upper) -end - -BBox(a::BBox{T}, b::BBox{T}) where T = BBox{T}(a, b) -Base.:+(a::BBox, b::BBox) = BBox(a, b) - - -# Convert BSphere to BBox -function BBox{T}(a::BSphere{T}) where T - lower = (a.x[1] - a.r, a.x[2] - a.r, a.x[3] - a.r) - upper = (a.x[1] + a.r, a.x[2] + a.r, a.x[3] + a.r) - BBox(lower, upper) -end - -function BBox(a::BSphere{T}) where T - BBox{T}(a) -end - -# Merge two BSphere into enclosing BBox -function BBox{T}(a::BSphere{T}, b::BSphere{T}) where T - length = dist3(a.x, b.x) - - # a is enclosed within b - if length + a.r <= b.r - return BBox(b) - - # b is enclosed within a - elseif length + b.r <= a.r - return BBox(a) - - # Bounding spheres are not enclosed - else - lower = (minimum2(a.x[1] - a.r, b.x[1] - b.r), - minimum2(a.x[2] - a.r, b.x[2] - b.r), - minimum2(a.x[3] - a.r, b.x[3] - b.r)) - - upper = (maximum2(a.x[1] + a.r, b.x[1] + b.r), - maximum2(a.x[2] + a.r, b.x[2] + b.r), - maximum2(a.x[3] + a.r, b.x[3] + b.r)) - - return BBox(lower, upper) - end -end - -function BBox(a::BSphere{T}, b::BSphere{T}) where T - BBox{T}(a, b) -end - - -# Contact detection -function iscontact(a::BBox, b::BBox) - (a.up[1] >= b.lo[1] && a.lo[1] <= b.up[1]) && - (a.up[2] >= b.lo[2] && a.lo[2] <= b.up[2]) && - (a.up[3] >= b.lo[3] && a.lo[3] <= b.up[3]) -end - - -# Contact detection between heterogeneous BVs - only needed when one BVH has exactly one leaf -function iscontact(a::BSphere, b::BBox) - # This is an edge case, used for broad-phase collision detection, so we simply take the - # sphere's bounding box, as a full sphere-box contact detection is computationally heavy - ab = BBox( - (a.x[1] - a.r, a.x[2] - a.r, a.x[3] - a.r), - (a.x[1] + a.r, a.x[2] + a.r, a.x[3] + a.r), - ) - iscontact(ab, b) -end - - -function iscontact(a::BBox, b::BSphere) - iscontact(b, a) -end - - -@inline function isintersection(b::BBox, p::AbstractVector, d::AbstractVector) - - @boundscheck begin - @assert length(p) == 3 - @assert length(d) == 3 - end - - T = eltype(d) - - @inbounds begin - inv_d = (one(T) / d[1], one(T) / d[2], one(T) / d[3]) - - # Set x bounds - t_bound_x1 = (b.lo[1] - p[1]) * inv_d[1] - t_bound_x2 = (b.up[1] - p[1]) * inv_d[1] - - tmin = minimum2(t_bound_x1, t_bound_x2) - tmax = maximum2(t_bound_x1, t_bound_x2) - - # Set y bounds - t_bound_y1 = (b.lo[2] - p[2]) * inv_d[2] - t_bound_y2 = (b.up[2] - p[2]) * inv_d[2] - - tmin = maximum2(tmin, minimum2(t_bound_y1, t_bound_y2)) - tmax = minimum2(tmax, maximum2(t_bound_y1, t_bound_y2)) - - # Set z bounds - t_bound_z1 = (b.lo[3] - p[3]) * inv_d[3] - t_bound_z2 = (b.up[3] - p[3]) * inv_d[3] - - tmin = maximum2(tmin, minimum2(t_bound_z1, t_bound_z2)) - tmax = minimum2(tmax, maximum2(t_bound_z1, t_bound_z2)) - end - - # If condition satisfied ray intersects box. tmax >= 0 - # ensure only forwards intersections are counted - (tmin <= tmax) && (tmax >= 0) -end - - -@inline function isintersection(s::BSphere, p::AbstractVector, d::AbstractVector) - - @boundscheck begin - @assert length(p) == 3 - @assert length(d) == 3 - end - - @inbounds begin - a = dot3(d, d) - b = 2 * ( - (p[1] - s.x[1]) * d[1] + - (p[2] - s.x[2]) * d[2] + - (p[3] - s.x[3]) * d[3] - ) - c = ( - (p[1] - s.x[1]) * (p[1] - s.x[1]) + - (p[2] - s.x[2]) * (p[2] - s.x[2]) + - (p[3] - s.x[3]) * (p[3] - s.x[3]) - ) - s.r * s.r - end - - discriminant = b * b - 4 * a * c - - if discriminant >= 0 - # Ensure only forwards intersections are counted - return 0 >= b * c - else - return false - end -end \ No newline at end of file diff --git a/src/bounding_volumes/bbox.jl b/src/bounding_volumes/bbox.jl new file mode 100644 index 0000000..e3e27c1 --- /dev/null +++ b/src/bounding_volumes/bbox.jl @@ -0,0 +1,113 @@ +""" + $(TYPEDEF) + +Axis-aligned bounding box, highly optimised for computing bounding volumes for triangles and +merging into larger bounding volumes. + +Can also be constructed from two spheres to e.g. allow merging [`BSphere`](@ref) leaves into +[`BBox`](@ref) nodes. + +# Methods + # Convenience constructors + BBox(lo::NTuple{3, T}, up::NTuple{3, T}) where T + BBox{T}(lo::AbstractVector, up::AbstractVector) where T + BBox(lo::AbstractVector, up::AbstractVector) + + # Construct from triangle vertices + BBox{T}(p1, p2, p3) where T + BBox(p1, p2, p3) + BBox{T}(vertices::AbstractMatrix) where T + BBox(vertices::AbstractMatrix) + BBox{T}(triangle) where T + BBox(triangle) + + # Merging bounding boxes + BBox{T}(a::BBox, b::BBox) where T + BBox(a::BBox{T}, b::BBox{T}) where T + Base.:+(a::BBox, b::BBox) + + # Merging bounding spheres + BBox{T}(a::BSphere{T}) where T + BBox(a::BSphere{T}) where T + BBox{T}(a::BSphere{T}, b::BSphere{T}) where T + BBox(a::BSphere{T}, b::BSphere{T}) where T +""" +struct BBox{T} + lo::NTuple{3, T} + up::NTuple{3, T} +end + +Base.eltype(::BBox{T}) where T = T +Base.eltype(::Type{BBox{T}}) where T = T + + + +# Convenience constructors, with and without type parameter +function BBox{T}(lo::AbstractVector, up::AbstractVector) where T + BBox(NTuple{3, eltype(lo)}(lo), NTuple{3, eltype(up)}(up)) +end + +function BBox(lo::AbstractVector, up::AbstractVector) + BBox{eltype(lo)}(lo, up) +end + + + +# Constructors from triangles +function BBox{T}(p1, p2, p3) where T + + lower = (minimum3(p1[1], p2[1], p3[1]), + minimum3(p1[2], p2[2], p3[2]), + minimum3(p1[3], p2[3], p3[3])) + + upper = (maximum3(p1[1], p2[1], p3[1]), + maximum3(p1[2], p2[2], p3[2]), + maximum3(p1[3], p2[3], p3[3])) + + BBox{T}(lower, upper) +end + + +# Convenience constructors, with and without explicit type parameter +function BBox(p1, p2, p3) + BBox{eltype(p1)}(p1, p2, p3) +end + +function BBox{T}(triangle) where T + # Decompose triangle into its 3 vertices. + # Works transparently with GeometryBasics.Triangle, Vector{SVector{3, T}}, etc. + p1, p2, p3 = triangle + BBox{T}(p1, p2, p3) +end + +function BBox(triangle) + p1, p2, p3 = triangle + BBox{eltype(p1)}(p1, p2, p3) +end + +function BBox{T}(vertices::AbstractMatrix) where T + BBox{T}(@view(vertices[:, 1]), @view(vertices[:, 2]), @view(vertices[:, 3])) +end + +function BBox(vertices::AbstractMatrix) + BBox{eltype(vertices)}(@view(vertices[:, 1]), @view(vertices[:, 2]), @view(vertices[:, 3])) +end + + +# Overloaded center function +center(b::BBox{T}) where T = (T(0.5) * (b.lo[1] + b.up[1]), + T(0.5) * (b.lo[2] + b.up[2]), + T(0.5) * (b.lo[3] + b.up[3])) + + +# Overloaded translate function +function translate(b::BBox{T}, dx) where T + dx1, dx2, dx3 = T(dx[1]), T(dx[2]), T(dx[3]) + new_lo = (b.lo[1] + dx1, + b.lo[2] + dx2, + b.lo[3] + dx3) + new_up = (b.up[1] + dx1, + b.up[2] + dx2, + b.up[3] + dx3) + BBox{T}(new_lo, new_up) +end diff --git a/src/bounding_volumes/bounding_volumes.jl b/src/bounding_volumes/bounding_volumes.jl new file mode 100644 index 0000000..de717dc --- /dev/null +++ b/src/bounding_volumes/bounding_volumes.jl @@ -0,0 +1,81 @@ +""" + iscontact(a::BSphere, b::BSphere) + iscontact(a::BBox, b::BBox) + iscontact(a::BSphere, b::BBox) + iscontact(a::BBox, b::BSphere) + +Check if two bounding volumes are touching or inter-penetrating. +""" +function iscontact end + + +""" + isintersection(b::BBox, p::AbstractVector, d::AbstractVector) + isintersection(s::BSphere, p::AbstractVector, d::AbstractVector) + +Check if a forward ray, defined by a point `p` and a direction `d` intersects a bounding volume; +`p` and `d` can be any iterables with 3 numbers (e.g. `Vector{Float64}`). + +# Examples +Simple ray bounding box intersection example: + +```jldoctest +using ImplicitBVH +using ImplicitBVH: BSphere, BBox, isintersection + +# Generate a simple bounding box +bounding_box = BBox((0., 0., 0.), (1., 1., 1.)) + +# Generate a ray passing up and through the bottom face of the bounding box +point = [.5, .5, -10] +direction = [0, 0, 1] +isintersection(bounding_box, point, direction) + +# output +true +``` + +Simple ray bounding sphere intersection example: +```jldoctest +using ImplicitBVH +using ImplicitBVH: BSphere, BBox, isintersection + +# Generate a simple bounding sphere +bounding_sphere = BSphere((0., 0., 0.), 0.5) + +# Generate a ray passing up and through the bounding sphere +point = [0, 0, -10] +direction = [0, 0, 1] +isintersection(bounding_sphere, point, direction) + +# output +true +``` +""" +function isintersection end + + +""" + center(b::BSphere) + center(b::BBox{T}) where T + +Get the coordinates of a bounding volume's centre, as a NTuple{3, T}. +""" +function center end + + +""" + translate(b::BSphere{T}, dx) where T + translate(b::BBox{T}, dx) where T + +Get a new bounding volume translated by dx; dx can be any iterable with 3 elements. +""" +function translate end + + +# Sub-includes +include("bsphere.jl") +include("bbox.jl") +include("merge.jl") +include("iscontact.jl") +include("isintersection.jl") diff --git a/src/bounding_volumes/bsphere.jl b/src/bounding_volumes/bsphere.jl new file mode 100644 index 0000000..ed51fae --- /dev/null +++ b/src/bounding_volumes/bsphere.jl @@ -0,0 +1,146 @@ +""" + $(TYPEDEF) + +Bounding sphere, highly optimised for computing bounding volumes for triangles and merging into +larger bounding volumes. + +# Methods + # Convenience constructors + BSphere(x::NTuple{3, T}, r) + BSphere{T}(x::AbstractVector, r) where T + BSphere(x::AbstractVector, r) + + # Construct from triangle vertices + BSphere{T}(p1, p2, p3) where T + BSphere(p1, p2, p3) + BSphere{T}(vertices::AbstractMatrix) where T + BSphere(vertices::AbstractMatrix) + BSphere{T}(triangle) where T + BSphere(triangle) + + # Merging bounding volumes + BSphere{T}(a::BSphere, b::BSphere) where T + BSphere(a::BSphere{T}, b::BSphere{T}) where T + Base.:+(a::BSphere, b::BSphere) +""" +struct BSphere{T} + x::NTuple{3, T} + r::T +end + +Base.eltype(::BSphere{T}) where T = T +Base.eltype(::Type{BSphere{T}}) where T = T + + +# Convenience constructors, with and without type parameter +BSphere{T}(x::AbstractVector, r) where T = BSphere(NTuple{3, T}(x), T(r)) +BSphere(x::AbstractVector, r) = BSphere{eltype(x)}(x, r) + + +# Constructors from triangles +function BSphere{T}(p1, p2, p3) where T + + # Adapted from https://realtimecollisiondetection.net/blog/?p=20 + a = (T(p1[1]), T(p1[2]), T(p1[3])) + b = (T(p2[1]), T(p2[2]), T(p2[3])) + c = (T(p3[1]), T(p3[2]), T(p3[3])) + + # Unrolled dot(b - a, b - a) + abab = (b[1] - a[1]) * (b[1] - a[1]) + + (b[2] - a[2]) * (b[2] - a[2]) + + (b[3] - a[3]) * (b[3] - a[3]) + + # Unrolled dot(b - a, c - a) + abac = (b[1] - a[1]) * (c[1] - a[1]) + + (b[2] - a[2]) * (c[2] - a[2]) + + (b[3] - a[3]) * (c[3] - a[3]) + + # Unrolled dot(c - a, c - a) + acac = (c[1] - a[1]) * (c[1] - a[1]) + + (c[2] - a[2]) * (c[2] - a[2]) + + (c[3] - a[3]) * (c[3] - a[3]) + + d = T(2.) * (abab * acac - abac * abac) + + if abs(d) <= eps(T) + # a, b, c lie on a line. Find line centre and radius + lower = (minimum3(a[1], b[1], c[1]), + minimum3(a[2], b[2], c[2]), + minimum3(a[3], b[3], c[3])) + + upper = (maximum3(a[1], b[1], c[1]), + maximum3(a[2], b[2], c[2]), + maximum3(a[3], b[3], c[3])) + + centre = (T(0.5) * (lower[1] + upper[1]), + T(0.5) * (lower[2] + upper[2]), + T(0.5) * (lower[3] + upper[3])) + radius = dist3(centre, upper) + else + s = (abab * acac - acac * abac) / d + t = (acac * abab - abab * abac) / d + + if s <= zero(T) + centre = (T(0.5) * (a[1] + c[1]), + T(0.5) * (a[2] + c[2]), + T(0.5) * (a[3] + c[3])) + radius = dist3(centre, a) + elseif t <= zero(T) + centre = (T(0.5) * (a[1] + b[1]), + T(0.5) * (a[2] + b[2]), + T(0.5) * (a[3] + b[3])) + radius = dist3(centre, a) + elseif s + t >= one(T) + centre = (T(0.5) * (b[1] + c[1]), + T(0.5) * (b[2] + c[2]), + T(0.5) * (b[3] + c[3])) + radius = dist3(centre, b) + else + centre = (a[1] + s * (b[1] - a[1]) + t * (c[1] - a[1]), + a[2] + s * (b[2] - a[2]) + t * (c[2] - a[2]), + a[3] + s * (b[3] - a[3]) + t * (c[3] - a[3])) + radius = dist3(centre, a) + end + end + + BSphere(centre, radius) +end + + +# Convenience constructors, with and without explicit type parameter +function BSphere(p1, p2, p3) + BSphere{eltype(p1)}(p1, p2, p3) +end + +function BSphere{T}(triangle) where T + # Decompose triangle into its 3 vertices. + # Works transparently with GeometryBasics.Triangle, Vector{SVector{3, T}}, etc. + p1, p2, p3 = triangle + BSphere{T}(p1, p2, p3) +end + +function BSphere(triangle) + p1, p2, p3 = triangle + BSphere{eltype(p1)}(p1, p2, p3) +end + +function BSphere{T}(vertices::AbstractMatrix) where T + BSphere{T}(@view(vertices[:, 1]), @view(vertices[:, 2]), @view(vertices[:, 3])) +end + +function BSphere(vertices::AbstractMatrix) + BSphere{eltype(vertices)}(@view(vertices[:, 1]), @view(vertices[:, 2]), @view(vertices[:, 3])) +end + + +# Overloaded center function +center(b::BSphere) = b.x + + +# Overloaded translate function +function translate(b::BSphere{T}, dx) where T + new_center = (b.x[1] + T(dx[1]), + b.x[2] + T(dx[2]), + b.x[3] + T(dx[3])) + BSphere{T}(new_center, b.r) +end diff --git a/src/bounding_volumes/iscontact.jl b/src/bounding_volumes/iscontact.jl new file mode 100644 index 0000000..ca9fe98 --- /dev/null +++ b/src/bounding_volumes/iscontact.jl @@ -0,0 +1,28 @@ +# Contact detection +function iscontact(a::BSphere, b::BSphere) + dist3sq(a.x, b.x) <= (a.r + b.r) * (a.r + b.r) +end + + +function iscontact(a::BBox, b::BBox) + (a.up[1] >= b.lo[1] && a.lo[1] <= b.up[1]) && + (a.up[2] >= b.lo[2] && a.lo[2] <= b.up[2]) && + (a.up[3] >= b.lo[3] && a.lo[3] <= b.up[3]) +end + + +# Contact detection between heterogeneous BVs - only needed when one BVH has exactly one leaf +function iscontact(a::BSphere, b::BBox) + # This is an edge case, used for broad-phase collision detection, so we simply take the + # sphere's bounding box, as a full sphere-box contact detection is computationally heavy + ab = BBox( + (a.x[1] - a.r, a.x[2] - a.r, a.x[3] - a.r), + (a.x[1] + a.r, a.x[2] + a.r, a.x[3] + a.r), + ) + iscontact(ab, b) +end + + +function iscontact(a::BBox, b::BSphere) + iscontact(b, a) +end diff --git a/src/bounding_volumes/isintersection.jl b/src/bounding_volumes/isintersection.jl new file mode 100644 index 0000000..734459d --- /dev/null +++ b/src/bounding_volumes/isintersection.jl @@ -0,0 +1,72 @@ +@inline function isintersection(b::BBox, p::AbstractVector, d::AbstractVector) + + @boundscheck begin + @assert length(p) == 3 + @assert length(d) == 3 + end + + T = eltype(d) + + @inbounds begin + inv_d = (one(T) / d[1], one(T) / d[2], one(T) / d[3]) + + # Set x bounds + t_bound_x1 = (b.lo[1] - p[1]) * inv_d[1] + t_bound_x2 = (b.up[1] - p[1]) * inv_d[1] + + tmin = minimum2(t_bound_x1, t_bound_x2) + tmax = maximum2(t_bound_x1, t_bound_x2) + + # Set y bounds + t_bound_y1 = (b.lo[2] - p[2]) * inv_d[2] + t_bound_y2 = (b.up[2] - p[2]) * inv_d[2] + + tmin = maximum2(tmin, minimum2(t_bound_y1, t_bound_y2)) + tmax = minimum2(tmax, maximum2(t_bound_y1, t_bound_y2)) + + # Set z bounds + t_bound_z1 = (b.lo[3] - p[3]) * inv_d[3] + t_bound_z2 = (b.up[3] - p[3]) * inv_d[3] + + tmin = maximum2(tmin, minimum2(t_bound_z1, t_bound_z2)) + tmax = minimum2(tmax, maximum2(t_bound_z1, t_bound_z2)) + end + + # If condition satisfied ray intersects box. tmax >= 0 + # ensure only forwards intersections are counted + (tmin <= tmax) && (tmax >= 0) +end + + +@inline function isintersection(s::BSphere, p::AbstractVector, d::AbstractVector) + + @boundscheck begin + @assert length(p) == 3 + @assert length(d) == 3 + end + + T = eltype(d) + + @inbounds begin + a = dot3(d, d) + b = T(2) * ( + (p[1] - s.x[1]) * d[1] + + (p[2] - s.x[2]) * d[2] + + (p[3] - s.x[3]) * d[3] + ) + c = ( + (p[1] - s.x[1]) * (p[1] - s.x[1]) + + (p[2] - s.x[2]) * (p[2] - s.x[2]) + + (p[3] - s.x[3]) * (p[3] - s.x[3]) + ) - s.r * s.r + end + + discriminant = b * b - T(4) * a * c + + if discriminant >= T(0) + # Ensure only forwards intersections are counted + return T(0) >= b * c + else + return false + end +end \ No newline at end of file diff --git a/src/bounding_volumes/merge.jl b/src/bounding_volumes/merge.jl new file mode 100644 index 0000000..6b6a507 --- /dev/null +++ b/src/bounding_volumes/merge.jl @@ -0,0 +1,85 @@ +# Merge two bounding spheres +function BSphere{T}(a::BSphere, b::BSphere) where T + length = dist3(a.x, b.x) + + # a is enclosed within b + if length + a.r <= b.r + return BSphere{T}(b.x, b.r) + + # b is enclosed within a + elseif length + b.r <= a.r + return BSphere{T}(a.x, a.r) + + # Bounding spheres are not enclosed + else + frac = T(0.5) * ((b.r - a.r) / length + T(1)) + centre = (a.x[1] + frac * (b.x[1] - a.x[1]), + a.x[2] + frac * (b.x[2] - a.x[2]), + a.x[3] + frac * (b.x[3] - a.x[3])) + radius = T(0.5) * (length + a.r + b.r) + return BSphere{T}(centre, radius) + end +end + + +BSphere(a::BSphere{T}, b::BSphere{T}) where T = BSphere{T}(a, b) +Base.:+(a::BSphere, b::BSphere) = BSphere(a, b) + + +# Merge two bounding boxes +function BBox{T}(a::BBox, b::BBox) where T + lower = (minimum2(a.lo[1], b.lo[1]), + minimum2(a.lo[2], b.lo[2]), + minimum2(a.lo[3], b.lo[3])) + + upper = (maximum2(a.up[1], b.up[1]), + maximum2(a.up[2], b.up[2]), + maximum2(a.up[3], b.up[3])) + + BBox{T}(lower, upper) +end + +BBox(a::BBox{T}, b::BBox{T}) where T = BBox{T}(a, b) +Base.:+(a::BBox, b::BBox) = BBox(a, b) + + +# Convert BSphere to BBox +function BBox{T}(a::BSphere{T}) where T + lower = (a.x[1] - a.r, a.x[2] - a.r, a.x[3] - a.r) + upper = (a.x[1] + a.r, a.x[2] + a.r, a.x[3] + a.r) + BBox(lower, upper) +end + +function BBox(a::BSphere{T}) where T + BBox{T}(a) +end + +# Merge two BSphere into enclosing BBox +function BBox{T}(a::BSphere{T}, b::BSphere{T}) where T + length = dist3(a.x, b.x) + + # a is enclosed within b + if length + a.r <= b.r + return BBox(b) + + # b is enclosed within a + elseif length + b.r <= a.r + return BBox(a) + + # Bounding spheres are not enclosed + else + lower = (minimum2(a.x[1] - a.r, b.x[1] - b.r), + minimum2(a.x[2] - a.r, b.x[2] - b.r), + minimum2(a.x[3] - a.r, b.x[3] - b.r)) + + upper = (maximum2(a.x[1] + a.r, b.x[1] + b.r), + maximum2(a.x[2] + a.r, b.x[2] + b.r), + maximum2(a.x[3] + a.r, b.x[3] + b.r)) + + return BBox(lower, upper) + end +end + +function BBox(a::BSphere{T}, b::BSphere{T}) where T + BBox{T}(a, b) +end