From 437ffd7dcf0bee108c42f223d5a45d2f9c480371 Mon Sep 17 00:00:00 2001 From: fabiocfabini Date: Tue, 19 Jul 2022 13:17:51 +0100 Subject: [PATCH] Improved performance in preprocessing.py --- python/preprocessing.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/python/preprocessing.py b/python/preprocessing.py index a7dbe81..4b74530 100644 --- a/python/preprocessing.py +++ b/python/preprocessing.py @@ -18,12 +18,16 @@ """ __version__ = "v0.3" +from multiprocessing import Pool +from itertools import product + import os import sys import time import xml.etree.ElementTree as ET import numpy as np + import contatempo import dft from headerfooter import header, footer @@ -32,6 +36,9 @@ # pylint: disable=C0103 ################################################################################### +def compute_rpoints(l, k, i): + return A1 * i / NR1 + A2 * k / NR2 + A3 * l / NR3 + if __name__ == "__main__": header("PREPROCESSING", __version__, time.asctime()) @@ -336,15 +343,17 @@ # Final calculations ######################################### print() - COUNT = 0 - RPOINT = np.zeros((NR, 3), dtype=float) - for l in range(NR3): - for k in range(NR2): - for i in range(NR1): - RPOINT[COUNT] = A1 * i / NR1 + A2 * k / NR2 + A3 * l / NR3 - COUNT += 1 + params = { + "NR3": range(NR3), + "NR2": range(NR2), + "NR1": range(NR1), + } + + + with Pool(processes=NPR) as pool: + RPOINT = np.array(pool.starmap(compute_rpoints, product(*params.values()))) - PHASE = np.exp(1j * np.dot(RPOINT, np.transpose(KPOINTS))) + PHASE = np.exp(1j * np.dot(RPOINT, KPOINTS.T)) # Start saving data to files ################################## with open("phase.npy", "wb") as fich: