Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Particle Container to Pure SoA #348

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/dependencies/ABLASTR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ set(ImpactX_openpmd_src ""
set(ImpactX_ablastr_repo "https://github.com/ECP-WarpX/WarpX.git"
CACHE STRING
"Repository URI to pull and build ABLASTR from if(ImpactX_ablastr_internal)")
set(ImpactX_ablastr_branch "24.02"
set(ImpactX_ablastr_branch "11aabdca56335c5ae1cbb2257b8abd6c8f04a67c"
CACHE STRING
"Repository branch for ImpactX_ablastr_repo if(ImpactX_ablastr_internal)")

Expand Down
2 changes: 1 addition & 1 deletion cmake/dependencies/pyAMReX.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ option(ImpactX_pyamrex_internal "Download & build pyAMReX" ON)
set(ImpactX_pyamrex_repo "https://github.com/AMReX-Codes/pyamrex.git"
CACHE STRING
"Repository URI to pull and build pyamrex from if(ImpactX_pyamrex_internal)")
set(ImpactX_pyamrex_branch "24.02"
set(ImpactX_pyamrex_branch "5aa700de18a61f933cb435adbe2299d74d794d6b"
CACHE STRING
"Repository branch for ImpactX_pyamrex_repo if(ImpactX_pyamrex_internal)")

Expand Down
2 changes: 1 addition & 1 deletion examples/epac2004_benchmarks/input_fodo_rf_SC.in
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,4 @@ geometry.prob_relative = 4.0
###############################################################################
# Diagnostics
###############################################################################
diag.slice_step_diagnostics = true
diag.slice_step_diagnostics = false
22 changes: 11 additions & 11 deletions examples/fodo/run_fodo_programmable.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ def my_drift(pge, pti, refpart):

else:
array = np.array
# access AoS data such as positions and cpu/id
aos = pti.aos()
aos_arr = array(aos, copy=False)

# access SoA data such as momentum
# access particle attributes
soa = pti.soa()
real_arrays = soa.GetRealData()
px = array(real_arrays[0], copy=False)
py = array(real_arrays[1], copy=False)
pt = array(real_arrays[2], copy=False)
real_arrays = soa.get_real_data()
x = array(real_arrays[0], copy=False)
y = array(real_arrays[1], copy=False)
t = array(real_arrays[2], copy=False)
px = array(real_arrays[3], copy=False)
py = array(real_arrays[4], copy=False)
pt = array(real_arrays[5], copy=False)

# length of the current slice
slice_ds = pge.ds / pge.nslice
Expand All @@ -96,9 +96,9 @@ def my_drift(pge, pti, refpart):
betgam2 = pt_ref**2 - 1.0

# advance position and momentum (drift)
aos_arr[:]["x"] += slice_ds * px[:]
aos_arr[:]["y"] += slice_ds * py[:]
aos_arr[:]["z"] += (slice_ds / betgam2) * pt[:]
x[:] += slice_ds * px[:]
y[:] += slice_ds * py[:]
t[:] += (slice_ds / betgam2) * pt[:]


def my_ref_drift(pge, refpart):
Expand Down
102 changes: 73 additions & 29 deletions examples/pytorch_surrogate_model/run_ml_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
from urllib import request

import numpy as np

try:
import cupy as cp

cupy_available = True
except ImportError:
cupy_available = False

from surrogate_model_definitions import surrogate_model

try:
Expand All @@ -20,14 +28,34 @@
sys.exit(0)

from impactx import (
Config,
CoordSystem,
ImpactX,
ImpactXParIter,
TransformationDirection,
coordinate_transformation,
distribution,
elements,
)

# CPU/GPU logic
if Config.have_gpu:
if cupy_available:
array = cp.array
stack = cp.stack
device = torch.device("cuda")
else:
print("Warning: GPU found but cupy not available! Try managed...")
array = np.array
stack = np.stack
device = torch.device("cpu")
if Config.gpu_backend == "SYCL":
print("Warning: SYCL GPU backend not yet implemented for Python")

else:
array = np.array
stack = np.stack
device = torch.device("cpu")


def download_and_unzip(url, data_dir):
request.urlretrieve(url, data_dir)
Expand All @@ -50,6 +78,7 @@ def download_and_unzip(url, data_dir):
surrogate_model(
dataset_dir + f"dataset_beam_stage_{i}.pt",
model_dir + f"beam_stage_{i}_model.pt",
device=device,
)
for i in range(N_stage)
]
Expand Down Expand Up @@ -78,47 +107,62 @@ def __init__(self, stage_i, surrogate_model, surrogate_length, stage_start):
self.ds = surrogate_length

def surrogate_push(self, pc, step):
array = np.array

ref_part = pc.ref_particle()
ref_z_i = ref_part.z
ref_z_i_LPA = ref_z_i - self.stage_start
ref_z_f = ref_z_i + self.surrogate_length

ref_part_tensor = torch.tensor(
[ref_part.x, ref_part.y, ref_z_i_LPA, ref_part.px, ref_part.py, ref_part.pz]
[
ref_part.x,
ref_part.y,
ref_z_i_LPA,
ref_part.px,
ref_part.py,
ref_part.pz,
],
dtype=torch.float64,
device=device,
)
ref_beta_gamma = np.sqrt(torch.sum(ref_part_tensor[3:] ** 2))
ref_beta_gamma = torch.sqrt(torch.sum(ref_part_tensor[3:] ** 2))

with torch.no_grad():
ref_part_model_final = self.surrogate_model(ref_part_tensor.float())
ref_part_model_final = self.surrogate_model(ref_part_tensor)
ref_uz_f = ref_part_model_final[5]
ref_beta_gamma_final = (
ref_uz_f # NOT np.sqrt(torch.sum(ref_part_model_final[3:]**2))
)
ref_part_final = torch.tensor([0, 0, ref_z_f, 0, 0, ref_uz_f])
ref_part_final = torch.tensor(
[0, 0, ref_z_f, 0, 0, ref_uz_f], dtype=torch.float64, device=device
)

# transform
coordinate_transformation(pc, TransformationDirection.to_fixed_t)
coordinate_transformation(pc, direction=CoordSystem.t)

for lvl in range(pc.finest_level + 1):
for pti in ImpactXParIter(pc, level=lvl):
aos = pti.aos()
aos_arr = array(aos, copy=False)

soa = pti.soa()
real_arrays = soa.GetRealData()
px = array(real_arrays[0], copy=False)
py = array(real_arrays[1], copy=False)
pt = array(real_arrays[2], copy=False)
data_arr = (
torch.tensor(
np.vstack(
[aos_arr["x"], aos_arr["y"], aos_arr["z"], real_arrays[:3]]
)
)
.float()
.T
real_arrays = soa.get_real_data()
x = array(real_arrays[0], copy=False)
y = array(real_arrays[1], copy=False)
t = array(real_arrays[2], copy=False)
px = array(real_arrays[3], copy=False)
py = array(real_arrays[4], copy=False)
pt = array(real_arrays[5], copy=False)
data_arr = torch.tensor(
stack(
Fixed Show fixed Hide fixed
[
x,
y,
t,
px,
py,
py,
],
axis=1,
),
dtype=torch.float64,
device=device,
Fixed Show fixed Hide fixed
)

data_arr[:, 0] += ref_part.x
Expand All @@ -135,7 +179,7 @@ def surrogate_push(self, pc, step):
# # assume for now it is

with torch.no_grad():
data_arr_post_model = self.surrogate_model(data_arr.float())
data_arr_post_model = self.surrogate_model(data_arr)

# need to add stage start to z
data_arr_post_model[:, 2] += self.stage_start
Expand All @@ -146,9 +190,9 @@ def surrogate_push(self, pc, step):
data_arr_post_model[:, 3 + ii] -= ref_part_final[3 + ii]
data_arr_post_model[:, 3 + ii] /= ref_beta_gamma_final

aos_arr["x"] = data_arr_post_model[:, 0]
aos_arr["y"] = data_arr_post_model[:, 1]
aos_arr["z"] = data_arr_post_model[:, 2]
x[:] = data_arr_post_model[:, 0]
y[:] = data_arr_post_model[:, 1]
t[:] = data_arr_post_model[:, 2]
px[:] = data_arr_post_model[:, 3]
py[:] = data_arr_post_model[:, 4]
pt[:] = data_arr_post_model[:, 5]
Expand All @@ -160,7 +204,7 @@ def surrogate_push(self, pc, step):
ref_part.x = ref_part_final[0]
ref_part.y = ref_part_final[1]
ref_part.z = ref_part_final[2]
ref_gamma = np.sqrt(1 + ref_beta_gamma_final**2)
ref_gamma = torch.sqrt(1 + ref_beta_gamma_final**2)
ref_part.px = ref_part_final[3]
ref_part.py = ref_part_final[4]
ref_part.pz = ref_part_final[5]
Expand All @@ -173,7 +217,7 @@ def surrogate_push(self, pc, step):
# ref_part.s += pge1.ds
# ref_part.t += pge1.ds / ref_beta

coordinate_transformation(pc, TransformationDirection.to_fixed_s)
coordinate_transformation(pc, direction=CoordSystem.s)
## Done!


Expand Down
26 changes: 17 additions & 9 deletions examples/pytorch_surrogate_model/surrogate_model_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@ def __init__(self, n_in, n_out, n_hidden_nodes, n_hidden_layers, act):
class surrogate_model:
""" """

def __init__(self, dataset_file, model_file):
def __init__(self, dataset_file, model_file, device):
self.dataset = torch.load(dataset_file)
model_dict = torch.load(model_file, map_location=torch.device("cpu"))
self.device = device
model_dict = torch.load(model_file)
n_in = model_dict["model_state_dict"]["stack.0.weight"].shape[1]
final_layer_key = list(model_dict["model_state_dict"].keys())[-1]
n_out = model_dict["model_state_dict"][final_layer_key].shape[0]
Expand All @@ -112,13 +113,20 @@ def __init__(self, dataset_file, model_file):
self.neural_network.load_state_dict(model_dict["model_state_dict"])
self.neural_network.eval()

def __call__(self, data_arr):
data_arr -= self.dataset["source_means"]
data_arr /= self.dataset["source_stds"]
data_arr = data_arr.float()
def __call__(self, data_arr, device=None):
data_arr -= torch.tensor(
self.dataset["source_means"], dtype=torch.float64, device=device
)
data_arr /= torch.tensor(
self.dataset["source_stds"], dtype=torch.float64, device=device
)
with torch.no_grad():
data_arr_post_model = self.neural_network(data_arr)
data_arr_post_model = self.neural_network(data_arr.float()).double()

data_arr_post_model *= self.dataset["target_stds"]
data_arr_post_model += self.dataset["target_means"]
data_arr_post_model *= torch.tensor(
self.dataset["target_stds"], dtype=torch.float64, device=device
)
data_arr_post_model += torch.tensor(
self.dataset["target_means"], dtype=torch.float64, device=device
)
return data_arr_post_model
18 changes: 10 additions & 8 deletions src/particles/CollectLost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <AMReX_GpuLaunch.H>
#include <AMReX_GpuQualifiers.H>
#include <AMReX_Math.H>
#include <AMReX_Particle.H>
#include <AMReX_ParticleTransformation.H>
#include <AMReX_RandomEngine.H>

Expand All @@ -27,9 +28,9 @@ namespace impactx
using DstData = ImpactXParticleContainer::ParticleTileType::ParticleTileDataType;

AMREX_GPU_HOST_DEVICE
void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept {
dst.m_aos[dst_ip] = src.m_aos[src_ip];

void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept
{
dst.m_idcpu[dst_ip] = src.m_idcpu[src_ip];
for (int j = 0; j < SrcData::NAR; ++j)
dst.m_rdata[j][dst_ip] = src.m_rdata[j][src_ip];
for (int j = 0; j < src.m_num_runtime_real; ++j)
Expand All @@ -42,7 +43,7 @@ namespace impactx
// dst.m_runtime_idata[j][dst_ip] = src.m_runtime_idata[j][src_ip];

// flip id to positive in destination
dst.id(dst_ip) = amrex::Math::abs(dst.id(dst_ip));
amrex::ParticleIDWrapper{dst.m_idcpu[dst_ip]}.make_valid();

// remember the current s of the ref particle when lost
dst.m_runtime_rdata[s_index][dst_ip] = s_lost;
Expand Down Expand Up @@ -85,7 +86,7 @@ namespace impactx
auto const predicate = [] AMREX_GPU_HOST_DEVICE (const SrcData& src, int ip)
/* NVCC 11.3.109 chokes in C++17 on this: noexcept */
{
return src.id(ip) < 0;
return !amrex::ConstParticleIDWrapper{src.m_idcpu[ip]}.is_valid();
};

auto& ptile_dest = dest.DefineAndReturnParticleTile(
Expand Down Expand Up @@ -130,9 +131,11 @@ namespace impactx
{
int n_removed = 0;
auto ptile_src_data = ptile_source.getParticleTileData();
auto const ptile_soa = ptile_source.GetStructOfArrays();
auto const ptile_idcpu = ptile_soa.GetIdCPUData().dataPtr();
for (int ip = 0; ip < np; ++ip)
{
if (ptile_source.id(ip) < 0)
if (!amrex::ConstParticleIDWrapper{ptile_idcpu[ip]}.is_valid())
n_removed++;
else
{
Expand All @@ -141,8 +144,7 @@ namespace impactx
// move down
int const new_index = ip - n_removed;

ptile_src_data.m_aos[new_index] = ptile_src_data.m_aos[ip];

ptile_src_data.m_idcpu[new_index] = ptile_src_data.m_idcpu[ip];
for (int j = 0; j < SrcData::NAR; ++j)
ptile_src_data.m_rdata[j][new_index] = ptile_src_data.m_rdata[j][ip];
for (int j = 0; j < ptile_src_data.m_num_runtime_real; ++j)
Expand Down
Loading
Loading