From e40bb849e98467d2802673364e1e79b5fee7c24b Mon Sep 17 00:00:00 2001 From: fabiocfabini Date: Tue, 19 Jul 2022 11:47:17 +0100 Subject: [PATCH 1/2] Improved preformance to dotproduct --- python/dotproduct.py | 128 ++++++++++++++++++++++++++----------------- 1 file changed, 77 insertions(+), 51 deletions(-) diff --git a/python/dotproduct.py b/python/dotproduct.py index a426def..fc4c327 100644 --- a/python/dotproduct.py +++ b/python/dotproduct.py @@ -2,47 +2,75 @@ This program calculates the dot product of the wfc Bloch factor with their neighbors """ +from itertools import product +from typing import Tuple, Optional +from multiprocessing import Array, Pool import sys import time +import ctypes +import multiprocessing import numpy as np # This are the subroutines and functions -import contatempo from headerfooter import header, footer -import loaddata as d +from NestablePool import NestablePool +import contatempo +import loaddata as d # pylint: disable=C0103 ################################################################################### -def connection(nkconn, neighborconn, dphaseconn): +def _connect(nk, j, neighbor, jNeighbor, dphase, band0, band1): + "Reads the data from file and " + + with open(f"{d.wfcdirectory}k0{nk}b0{band0}.wfc", "rb") as fichconn: + wfc0 = np.load(fichconn) + with open(f"{d.wfcdirectory}k0{neighbor}b0{band1}.wfc", "rb") as fichconn: + wfc1 = np.load(fichconn) + + dpc[nk, j, band0, band1] = np.sum(dphase * wfc0 * np.conjugate(wfc1)) / d.nr + dpc[neighbor, jNeighbor, band1, band0] = np.conjugate(dpc[nk, j, band0, band1]) + + +def pre_connection(nkconn, j, neighborconn, jNeighbor, dphaseconn): """Calculates the dot product of all combinations of wfc in nkconn and neighborconn.""" + params = { + "nkconn": (nkconn,), + "j": (j,), + "neighborconn": (neighborconn,), + "jNeighbor": (jNeighbor,), + "dphaseconn": (dphaseconn,), + "banda0": range(d.nbnd), + "banda1": range(d.nbnd), + } + + with Pool(processes=min(d.nbnd, multiprocessing.cpu_count())) as pool: + pool.starmap(_connect, product(*params.values())) + + +def connection( + nk: int, j: int, neighbor: int, jNeighbor: Tuple[np.ndarray] +) -> None: - dpc1 = np.zeros((d.nbnd, d.nbnd), dtype=complex) - dpc2 = np.zeros((d.nbnd, d.nbnd), dtype=complex) + dphase = d.phase[:, nk] * np.conjugate(d.phase[:, neighbor]) - for banda0 in range(d.nbnd): - # reads first file for dot product - infile = d.wfcdirectory + "k0" + str(nkconn) + "b0" + str(banda0) + ".wfc" - with open(infile, "rb") as fichconn: - wfc0 = np.load(fichconn) - fichconn.close() + print(" Calculating nk = " + str(nk) + " neighbor = " + str(neighbor)) + sys.stdout.flush() - for banda1 in range(d.nbnd): - # reads second file for dot product - infile = ( - d.wfcdirectory + "k0" + str(neighborconn) + "b0" + str(banda1) + ".wfc" - ) - with open(infile, "rb") as fichconn: - wfc1 = np.load(fichconn) - fichconn.close() + pre_connection(nk, j, neighbor, jNeighbor, dphase) - # calculates the dot products u_1.u_2* and u_2.u_1* - dpc1[banda0, banda1] = np.sum(dphaseconn * wfc0 * np.conjugate(wfc1)) / d.nr - dpc2[banda1, banda0] = np.conjugate(dpc1[banda0, banda1]) - return dpc1, dpc2 +def _generate_pre_connection_args( + nk: int, j: int +) -> Optional[Tuple[int, int, int, Tuple[np.ndarray]]]: + """Generates the arguments for the pre_connection function.""" + neighbor = d.neighbors[nk, j] + if neighbor != -1 and neighbor > nk: + jNeighbor = np.where(d.neighbors[neighbor] == nk) + return (nk, j, neighbor, jNeighbor) + return None ################################################################################### @@ -50,6 +78,12 @@ def connection(nkconn, neighborconn, dphaseconn): header("DOTPRODUCT", d.version, time.asctime()) STARTTIME = time.time() # Starts counting time + DPC_SIZE = d.nks * 4 * d.nbnd * d.nbnd + DPC_SHAPE = (d.nks, 4, d.nbnd, d.nbnd) + RUN_PARAMS = { + "nk_points": range(d.nks), + "num_neibhors": range(4), # TODO Fix Hardcoded value + } # Reading data needed for the run @@ -61,50 +95,42 @@ def connection(nkconn, neighborconn, dphaseconn): print(" Number of bands:", d.nbnd) print() print(" Phases loaded") - # print(d.phase[10000,10]) # d.phase[d.nr,d.nks] print(" Neighbors loaded") # Finished reading data needed for the run print() ########################################################## - dpc = np.full((d.nks, 4, d.nbnd, d.nbnd), 0 + 0j, dtype=complex) - dp = np.zeros((d.nks, 4, d.nbnd, d.nbnd)) + # Creating a buffer for the dpc np.ndarray + dpc_base = Array(ctypes.c_double, 2 * DPC_SIZE, lock=False) + # Initializing shared instance of np.ndarray 'dpc' + dpc = np.frombuffer(dpc_base, dtype=complex).reshape(DPC_SHAPE) - for nk in range(d.nks): # runs through all k-points - for j in range(4): # runs through all neighbors - neighbor = d.neighbors[nk, j] - - if neighbor != -1 and neighbor > nk: # exclude invalid neighbors - jNeighbor = np.where(d.neighbors[neighbor] == nk) - # Calculates the diference in phases to convert \psi to u - dphase = d.phase[:, nk] * np.conjugate(d.phase[:, neighbor]) - - print( - " Calculating nk = " - + str(nk) - + " neighbor = " - + str(neighbor) - ) - sys.stdout.flush() + ########################################################## - dpc[nk, j, :, :], dpc[neighbor, jNeighbor, :, :] = connection( - nk, neighbor, dphase - ) + # Creating a list of tuples with the neighbors of each k-point + with NestablePool(d.npr) as pool: + pre_connection_args = ( + filter( + None, + pool.starmap( + _generate_pre_connection_args, product(*RUN_PARAMS.values()) + ), + ), + ) + pool.starmap(connection, pre_connection_args) dp = np.abs(dpc) # Save dot products to file - with open("dpc.npy", "wb") as fich: + with open("dpc_pool.npy", "wb") as fich: np.save(fich, dpc) - fich.close() - print(" Dot products saved to file dpc.npy") + print(" Dot products saved to file dpc_pool.npy") # Save dot products modulus to file - with open("dp.npy", "wb") as fich: + with open("dp_pool.npy", "wb") as fich: np.save(fich, dp) - fich.close() - print(" Dot products modulus saved to file dp.npy") + print(" Dot products modulus saved to file dp_pool.npy") ################################################################################### # Finished From 8c31f037a4a8c85c473b62514514cebc037e0f30 Mon Sep 17 00:00:00 2001 From: fabiocfabini Date: Tue, 19 Jul 2022 11:51:26 +0100 Subject: [PATCH 2/2] Added NestablePool class --- python/NestablePool.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 python/NestablePool.py diff --git a/python/NestablePool.py b/python/NestablePool.py new file mode 100644 index 0000000..9f9cd3c --- /dev/null +++ b/python/NestablePool.py @@ -0,0 +1,22 @@ +import multiprocessing +import multiprocessing.pool + +class NoDaemonProcess(multiprocessing.Process): + @property + def daemon(self): + return False + + @daemon.setter + def daemon(self, value): + pass + + +class NoDaemonContext(type(multiprocessing.get_context())): + Process = NoDaemonProcess + +# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool +# because the latter is only a wrapper function, not a proper class. +class NestablePool(multiprocessing.pool.Pool): + def __init__(self, *args, **kwargs): + kwargs['context'] = NoDaemonContext() + super(NestablePool, self).__init__(*args, **kwargs) \ No newline at end of file