Skip to content

Commit

Permalink
Tar generator timings.
Browse files: browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Mar 22, 2024
1 parent d3f4740 commit 9af2380
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions mlspm/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
proc_id = str(time.time_ns() + 1000000000 * i_proc)[-10:]
print(f"Starting worker {i_proc}, id {proc_id}")

if self._timings:
start_time = time.perf_counter()
n_sample_tot = 0
start_time = time.perf_counter()
total_bytes = 0
n_sample_total = 0

for sample_list in sample_lists:

Expand Down Expand Up @@ -240,9 +240,11 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m
rots = sample_list["rots"][i_sample]
pot, lvec_pot, xyzs, Zs = self._get_data(tar_hartree, name_list_hartree[i_sample])
pot *= -1 # Unit conversion, eV -> V
total_bytes += pot.nbytes
if use_rho:
rho, lvec_rho, _, _ = self._get_data(tar_rho, name_list_rho[i_sample])
rho_shape = rho.shape
total_bytes += rho.nbytes
else:
lvec_rho = None
rho_shape = None
Expand Down Expand Up @@ -272,7 +274,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m

if self._timings:
t3 = time.perf_counter()
n_sample_tot += 1
n_sample_total += 1
print(
f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Wait-unlink: "
f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} "
Expand All @@ -290,7 +292,10 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m

if self._timings:
dt = time.perf_counter() - start_time
print(f"[Worker {i_proc}]: Loaded {n_sample_tot} samples in {dt}s. Average load time: {dt / n_sample_tot}s.")
print(
f"[Worker {i_proc}]: Loaded {n_sample_total} samples in {dt}s, totaling {total_bytes / 2**30:.3f}GiB. "
f"Average load time: {dt / n_sample_total}s."
)

def _get_queue_sample(self):

Expand Down Expand Up @@ -331,19 +336,24 @@ def _get_queue_sample(self):

def _yield_samples(self):

for _ in range(len(self)):
start_time = time.perf_counter()
n_sample_yielded = 0

n_sample_total = sum([len(sample_list["rots"]) for sample_list in self.samples])

for _ in range(n_sample_total):

if self._timings:
t0 = time.perf_counter()

i_proc, xyzs, Zs, rots, pot, shm_pot, rho, shm_rho, sample_id = self._get_queue_sample()

if self._timings:
t1 = time.perf_counter()

for rot in rots:
sample_dict = {"xyzs": xyzs, "Zs": Zs, "qs": pot, "rho_sample": rho, "rot": rot}
yield sample_dict
n_sample_yielded += 1

if self._timings:
t2 = time.perf_counter()
Expand All @@ -358,6 +368,10 @@ def _yield_samples(self):
t3 = time.perf_counter()
print(f"[Main, id {sample_id}] Receive data / Yield / Event: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}")

if self._timings:
dt = time.perf_counter() - start_time
print(f"[Main]: Yielded {n_sample_yielded} samples in {dt}s. Average yield time: {dt / n_sample_yielded}s.")


def _put_to_shared_memory(array, name):
shm = mp.shared_memory.SharedMemory(create=True, size=array.nbytes, name=name)
Expand Down

0 comments on commit 9af2380

Please sign in to comment.