Skip to content

Commit

Permalink
chunk x3flux in x1 dim
Browse files Browse the repository at this point in the history
  • Loading branch information
pgrete committed Sep 27, 2023
1 parent e2d1699 commit 44165f6
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions src/hydro/hydro.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -900,32 +900,44 @@ TaskStatus CalculateFluxes(std::shared_ptr<MeshData<Real>> &md) {
if (pmb->pmy_mesh->ndim >= 3 && recon != Reconstruction::plm) {
// set the loop limits
il = ib.s - 1, iu = ib.e + 1, jl = jb.s - 1, ju = jb.e + 1;
const int chunck_size = 32;
int num_chunks = (iu - il) / chunck_size;
num_chunks += (iu - il) % chunck_size != 0 ? 1 : 0;
scratch_size_in_bytes =
parthenon::ScratchPad2D<Real>::shmem_size(num_scratch_vars, chunck_size) * 3;

parthenon::par_for_outer(
DEFAULT_OUTER_LOOP_PATTERN, "x3 flux", DevExecSpace(), scratch_size_in_bytes,
scratch_level, 0, cons_in.GetDim(5) - 1, jl, ju,
KOKKOS_LAMBDA(parthenon::team_mbr_t member, const int b, const int j) {
scratch_level, 0, cons_in.GetDim(5) - 1, jl, ju, 0, num_chunks - 1,
KOKKOS_LAMBDA(parthenon::team_mbr_t member, const int b, const int j,
const int ii) {
const auto &prim = prim_in(b);
auto &cons = cons_in(b);
parthenon::ScratchPad2D<Real> wl(member.team_scratch(scratch_level),
num_scratch_vars, nx1);
num_scratch_vars, chunck_size);
parthenon::ScratchPad2D<Real> wr(member.team_scratch(scratch_level),
num_scratch_vars, nx1);
num_scratch_vars, chunck_size);
parthenon::ScratchPad2D<Real> wlb(member.team_scratch(scratch_level),
num_scratch_vars, nx1);
num_scratch_vars, chunck_size);
const int il_ = il + ii * chunck_size;
int iu_ = il_ + chunck_size - 1;
if (iu_ > iu) {
iu_ = iu;
}

for (int k = kb.s - 1; k <= kb.e + 1; ++k) {
// reconstruct L/R states at k
Reconstruct<recon, X3DIR>(member, k, j, il, iu, prim, wlb, wr);
Reconstruct<recon, X3DIR>(member, k, j, il_, iu_, prim, wlb, wr);
// Sync all threads in the team so that scratch memory is consistent
member.team_barrier();

if (k > kb.s - 1) {
riemann.Solve(member, k, j, il, iu, IV3, wl, wr, cons, eos, c_h);
riemann.Solve(member, k, j, il_, iu_, IV3, wl, wr, cons, eos, c_h);
member.team_barrier();

// Passive scalar fluxes
for (auto n = nhydro; n < nhydro + nscalars; ++n) {
parthenon::par_for_inner(member, il, iu, [&](const int i) {
parthenon::par_for_inner(member, il_, iu_, [&](const int i) {
if (cons.flux(IV3, IDN, k, j, i) >= 0.0) {
cons.flux(IV3, n, k, j, i) = cons.flux(IV3, IDN, k, j, i) * wl(n, i);
} else {
Expand Down

0 comments on commit 44165f6

Please sign in to comment.