Skip to content

Commit

Permalink
WIP: Neighborhood evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Jan 3, 2025
1 parent 5f4f6c2 commit aababf3
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 28 deletions.
4 changes: 2 additions & 2 deletions examples/neighbour_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rlst::{
operator::interface::DistributedArrayVectorSpace, rlst_dynamic_array2, AsApply, Element,
LinearSpace, OperatorBase, RawAccess,
LinearSpace, RawAccess,
};

fn main() {
Expand Down Expand Up @@ -56,5 +56,5 @@ fn main() {
.local_mut()
.fill_from_equally_distributed(&mut rng);

//let res = evaluator.apply(&x);
let res = evaluator.apply(&x);
}
217 changes: 191 additions & 26 deletions src/evaluator_tools.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
//! Various helper functions to support evaluators.
use std::marker::PhantomData;
use std::{
marker::PhantomData,
sync::{Arc, Mutex},
};

use green_kernels::traits::Kernel;
use green_kernels::{traits::Kernel, types::GreenKernelEvalType};
use itertools::{izip, Itertools};
use mpi::traits::{Communicator, Equivalence};
use ndelement::traits::FiniteElement;
use ndgrid::{
traits::{Entity, GeometryMap, Grid, ParallelGrid},
traits::{Entity, GeometryMap, Grid, ParallelGrid, Topology},
types::Ownership,
};

use rayon::prelude::*;
use rayon::{prelude::*, range};

Check failure on line 17 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

unused import: `range`

Check failure on line 17 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Run Rust tests (stable, openmpi, --features "strict")

unused import: `range`
use rlst::{
operator::interface::{
distributed_sparse_operator::DistributedCsrMatrixOperator, DistributedArrayVectorSpace,
Expand Down Expand Up @@ -172,9 +175,118 @@ where
}

/// A linear operator that evaluates kernel interactions for a nonsingular quadrature rule on neighbouring elements.
///
/// This creates a linear operator which for a given kernel evaluates all the interactions between neighbouring elements on a set
/// of local points per triangle.
pub fn neighbour_operator<
'a,
C: Communicator,
T: RlstScalar + Equivalence,
K: Kernel<T = T>,
DomainLayout: IndexLayout<Comm = C>,
RangeLayout: IndexLayout<Comm = C>,
GridImpl: ParallelGrid<C, T = T::Real>,
>(
grid: &'a GridImpl,
kernel: K,

Check failure on line 188 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

unused variable: `kernel`

Check failure on line 188 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Run Rust tests (stable, openmpi, --features "strict")

unused variable: `kernel`
eval_type: GreenKernelEvalType,

Check failure on line 189 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

unused variable: `eval_type`

Check failure on line 189 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Run Rust tests (stable, openmpi, --features "strict")

unused variable: `eval_type`
eval_points: &[T::Real],
domain_space: &'a DistributedArrayVectorSpace<'a, DomainLayout, T>,
range_space: &'a DistributedArrayVectorSpace<'a, RangeLayout, T>,
) where
GridImpl::LocalGrid: Sync,
for<'b> <GridImpl::LocalGrid as Grid>::GeometryMap<'b>: Sync,
{
// Check that the domain space and range space are compatible with the grid.
// Topological dimension of the grid.
let tdim = grid.topology_dim();

// Also need the geometric dimension.
let gdim = grid.geometry_dim();

// The method is currently restricted to single element type grids.
// So let's make sure that this is the case.

assert_eq!(grid.entity_types(tdim).len(), 1);

let reference_cell = grid.entity_types(tdim)[0];

// Get the number of points
assert_eq!(eval_points.len() % tdim, 0);
let n_points = eval_points.len() / tdim;

// The active cells are those that we need to iterate over.
// At the moment these are simply all owned cells in the grid.
// We sort the active cells by global index. This is important so that in the evaluation
// we can just iterate the output vector through in chunks and know from the chunk which
// active cell it is associated with.

let active_cells: Vec<usize> = grid
.entity_iter(tdim)
.filter(|e| matches!(e.ownership(), Ownership::Owned))
.sorted_by_key(|e| e.global_index())
.map(|e| e.local_index())
.collect_vec();

let n_cells = active_cells.len();

// Check that domain space and function space are compatible with the grid.

assert_eq!(
domain_space.index_layout().number_of_local_indices(),
n_cells * n_points
);

assert_eq!(
range_space.index_layout().number_of_local_indices(),
n_cells * n_points
);

// We need to figure out how many data entries the sparse matrix will have.
// For this we need to iterate through all cells and figure out how many neighbours
// each cell has.

let mut entries = 0;
for &cell in &active_cells {
let n_neighbours = grid
.entity(tdim, cell)
.unwrap()
.topology()
.connected_entity_iter(tdim)
.count();
// For each cell have (1 + nneighbours) * n_points * n_points interactions.
// The + comes from the fact that we also need to consider self interactions for each cell.
entries += (1 + n_neighbours) * n_points * n_points;
}

// Now initiate the aij arrays for the sparse matrix.

let rows = Arc::new(Mutex::new(Vec::<usize>::with_capacity(entries)));

Check failure on line 261 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

unused variable: `rows`

Check failure on line 261 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Run Rust tests (stable, openmpi, --features "strict")

unused variable: `rows`
let cols = Arc::new(Mutex::new(Vec::<usize>::with_capacity(entries)));

Check failure on line 262 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

unused variable: `cols`

Check failure on line 262 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Run Rust tests (stable, openmpi, --features "strict")

unused variable: `cols`
let data = Arc::new(Mutex::new(Vec::<T>::with_capacity(entries)));

Check failure on line 263 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

unused variable: `data`

Check failure on line 263 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Run Rust tests (stable, openmpi, --features "strict")

unused variable: `data`

let local_grid = grid.local_grid();
let geometry_map = local_grid.geometry_map(reference_cell, eval_points);

// We now use a parallel iterator to go through each cell and evaluate the interactions.
active_cells.par_iter().for_each(|&cell| {
let cell_entity = local_grid.entity(tdim, cell).unwrap();
let cell_global_index = cell_entity.global_index();
let mut target_points = rlst_dynamic_array2!(T::Real, [gdim, n_points]);
geometry_map.points(cell, target_points.data_mut());
// We also need the corresponding indices.
let target_indices =

Check failure on line 275 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

unused variable: `target_indices`

Check failure on line 275 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Run Rust tests (stable, openmpi, --features "strict")

unused variable: `target_indices`
(n_points * cell_global_index..n_points * (1 + cell_global_index)).collect_vec();
// Let's get the neighbouring cells. These are the sources.
let mut source_cells = cell_entity
.topology()
.connected_entity_iter(tdim)
.collect_vec();
// We push the cell itself for the self interactions.
source_cells.push(cell);
// The following matrix is going to hold the result of the kernel calculation.
let mut kernel_matrix = rlst_dynamic_array2!(T, [n_points, source_cells.len() * n_points]);

Check failure on line 285 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

unused variable: `kernel_matrix`

Check failure on line 285 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

variable does not need to be mutable

Check failure on line 285 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Run Rust tests (stable, openmpi, --features "strict")

unused variable: `kernel_matrix`

Check failure on line 285 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Run Rust tests (stable, openmpi, --features "strict")

variable does not need to be mutable
// We also want the global indices of the source points.
})
}

pub struct NeighbourEvaluator<

Check failure on line 290 in src/evaluator_tools.rs

View workflow job for this annotation

GitHub Actions / Build docs

missing documentation for a struct
'a,
T: RlstScalar + Equivalence,
Expand Down Expand Up @@ -305,14 +417,55 @@ impl<
}
}

// struct SyncedGridRef<
// 'a,
// C: Communicator,
// T: RlstScalar + Equivalence,
// GridImpl: ParallelGrid<C, T = T::Real>,
// > {
// grid: &'a GridImpl::LocalGrid,
// _marker: PhantomData<C>,
// _marker2: PhantomData<T>,
// }

// unsafe impl<'a, C: Communicator, T: RlstScalar + Equivalence, GridImpl: ParallelGrid<C, T = T::Real>>
// Sync for SyncedGridRef<'a, C, T, GridImpl>
// {
// }

// impl<'a, C: Communicator, T: RlstScalar + Equivalence, GridImpl: ParallelGrid<C, T = T::Real>>
// SyncedGridRef<'a, C, T, GridImpl>
// {
// fn new(grid: &'a GridImpl::LocalGrid) -> Self {
// Self {
// grid,
// _marker: PhantomData,
// _marker2: PhantomData,
// }
// }
// }

// impl<'a, C: Communicator, T: RlstScalar + Equivalence, GridImpl: ParallelGrid<C, T = T::Real>>
// std::ops::Deref for SyncedGridRef<'a, C, T, GridImpl>
// {
// type Target = GridImpl::LocalGrid;

// fn deref(&self) -> &Self::Target {
// self.grid
// }
// }

impl<
T: RlstScalar + Equivalence,
K: Kernel<T = T>,
DomainLayout: IndexLayout<Comm = C> + Sync,
RangeLayout: IndexLayout<Comm = C> + Sync,
GridImpl: ParallelGrid<C, T = T::Real> + Sync,
DomainLayout: IndexLayout<Comm = C>,
RangeLayout: IndexLayout<Comm = C>,
GridImpl: ParallelGrid<C, T = T::Real>,
C: Communicator,
> AsApply for NeighbourEvaluator<'_, T, K, DomainLayout, RangeLayout, GridImpl, C>
where
GridImpl::LocalGrid: Sync,
for<'b> <GridImpl::LocalGrid as Grid>::GeometryMap<'b>: Sync,
{
fn apply_extended(
&self,
Expand All @@ -324,30 +477,42 @@ impl<
// We need to iterate through the elements.

let tdim = self.grid.topology_dim();
let gdim = self.grid.geometry_dim();

// In the new function we already made sure that eval_points is a multiple of tdim.
let n_points = self.eval_points.len() / tdim;

// We need the reference cell
let reference_cell = self.grid.entity_types(tdim)[0];

// We go through groups of target dofs in chunks of lenth n_points.
// This corresponds to iteration in active cells since we ordered those
// already by global index.

let raw_ptr = y.view_mut().local_mut().data_mut().as_mut_ptr();

// y.view_mut()
// .local_mut()
// .data_mut()
// .par_chunks_mut(n_points)
// .enumerate()
// .for_each(|(chunk_index, chunk)| {
// let cell = self.active_cells[chunk_index];
// let cell_entity = self.grid.entity(tdim, cell).unwrap();

// // Get the geometry map for the cell.
// let geometry_map = self
// .grid
// .geometry_map(cell_entity.entity_type(), &self.eval_points);
// });
let local_grid = self.grid.local_grid();
let active_cells = self.active_cells.as_slice();
let eval_points = self.eval_points.as_slice();
let geometry_map = local_grid.geometry_map(reference_cell, eval_points);

y.view_mut()
.local_mut()
.data_mut()
.par_chunks_mut(n_points)
.zip(active_cells)
.enumerate()
.for_each(|(chunk_index, (chunk, &active_cell_index))| {
let cell_entity = local_grid.entity(tdim, active_cell_index).unwrap();
let mut physical_points = rlst_dynamic_array2![T::Real, [gdim, n_points]];

// Get the geometry map for the cell.
geometry_map.points(active_cell_index, physical_points.data_mut());

// We now have to iterate through all neighbouring entities of the cell.

for other_cell in cell_entity.topology().connected_entity_iter(tdim) {
// Get the points of the other cell.
}
});

Ok(())
}
Expand Down

0 comments on commit aababf3

Please sign in to comment.