From 9af2380ea4bf4fb0129b464deed066c97b7644bb Mon Sep 17 00:00:00 2001 From: NikoOinonen Date: Fri, 22 Mar 2024 20:34:53 +0200 Subject: [PATCH] Tar generator timings. --- mlspm/data_generation.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/mlspm/data_generation.py b/mlspm/data_generation.py index 55b2af2..32e1c99 100644 --- a/mlspm/data_generation.py +++ b/mlspm/data_generation.py @@ -206,9 +206,9 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m proc_id = str(time.time_ns() + 1000000000 * i_proc)[-10:] print(f"Starting worker {i_proc}, id {proc_id}") - if self._timings: - start_time = time.perf_counter() - n_sample_tot = 0 + start_time = time.perf_counter() + total_bytes = 0 + n_sample_total = 0 for sample_list in sample_lists: @@ -240,9 +240,11 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m rots = sample_list["rots"][i_sample] pot, lvec_pot, xyzs, Zs = self._get_data(tar_hartree, name_list_hartree[i_sample]) pot *= -1 # Unit conversion, eV -> V + total_bytes += pot.nbytes if use_rho: rho, lvec_rho, _, _ = self._get_data(tar_rho, name_list_rho[i_sample]) rho_shape = rho.shape + total_bytes += rho.nbytes else: lvec_rho = None rho_shape = None @@ -272,7 +274,7 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m if self._timings: t3 = time.perf_counter() - n_sample_tot += 1 + n_sample_total += 1 print( f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Wait-unlink: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} " @@ -290,7 +292,10 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m if self._timings: dt = time.perf_counter() - start_time - print(f"[Worker {i_proc}]: Loaded {n_sample_tot} samples in {dt}s. Average load time: {dt / n_sample_tot}s.") + print( + f"[Worker {i_proc}]: Loaded {n_sample_total} samples in {dt}s, totaling {total_bytes / 2**30:.3f}GiB. " + f"Average load time: {dt / n_sample_total}s." + ) def _get_queue_sample(self): @@ -331,19 +336,24 @@ def _get_queue_sample(self): def _yield_samples(self): - for _ in range(len(self)): + start_time = time.perf_counter() + n_sample_yielded = 0 + + n_sample_total = sum([len(sample_list["rots"]) for sample_list in self.samples]) + + for _ in range(n_sample_total): if self._timings: t0 = time.perf_counter() i_proc, xyzs, Zs, rots, pot, shm_pot, rho, shm_rho, sample_id = self._get_queue_sample() - if self._timings: t1 = time.perf_counter() for rot in rots: sample_dict = {"xyzs": xyzs, "Zs": Zs, "qs": pot, "rho_sample": rho, "rot": rot} yield sample_dict + n_sample_yielded += 1 if self._timings: t2 = time.perf_counter() @@ -358,6 +368,10 @@ def _yield_samples(self): t3 = time.perf_counter() print(f"[Main, id {sample_id}] Receive data / Yield / Event: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}") + if self._timings: + dt = time.perf_counter() - start_time + print(f"[Main]: Yielded {n_sample_yielded} samples in {dt}s. Average yield time: {dt / n_sample_yielded}s.") + def _put_to_shared_memory(array, name): shm = mp.shared_memory.SharedMemory(create=True, size=array.nbytes, name=name)