Skip to content

Commit

Permalink
Merge pull request #10 from fabiocfabini/master
Browse files Browse the repository at this point in the history
dotproduct parallelized
  • Loading branch information
ricardoribeiro-2020 authored Jul 19, 2022
2 parents ddc3e13 + e2febc2 commit 289c8f8
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 51 deletions.
22 changes: 22 additions & 0 deletions python/NestablePool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import multiprocessing
import multiprocessing.pool

class NoDaemonProcess(multiprocessing.Process):
@property
def daemon(self):
return False

@daemon.setter
def daemon(self, value):
pass


class NoDaemonContext(type(multiprocessing.get_context())):
Process = NoDaemonProcess

# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool
# because the latter is only a wrapper function, not a proper class.
class NestablePool(multiprocessing.pool.Pool):
def __init__(self, *args, **kwargs):
kwargs['context'] = NoDaemonContext()
super(NestablePool, self).__init__(*args, **kwargs)
128 changes: 77 additions & 51 deletions python/dotproduct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,88 @@
This program calculates the dot product of the wfc Bloch factor with their neighbors
"""
from itertools import product
from typing import Tuple, Optional
from multiprocessing import Array, Pool

import sys
import time
import ctypes
import multiprocessing

import numpy as np

# This are the subroutines and functions
import contatempo
from headerfooter import header, footer
import loaddata as d
from NestablePool import NestablePool

import contatempo
import loaddata as d

# pylint: disable=C0103
###################################################################################
def connection(nkconn, neighborconn, dphaseconn):
def _connect(nk, j, neighbor, jNeighbor, dphase, band0, band1):
"Reads the data from file and "

with open(f"{d.wfcdirectory}k0{nk}b0{band0}.wfc", "rb") as fichconn:
wfc0 = np.load(fichconn)
with open(f"{d.wfcdirectory}k0{neighbor}b0{band1}.wfc", "rb") as fichconn:
wfc1 = np.load(fichconn)

dpc[nk, j, band0, band1] = np.sum(dphase * wfc0 * np.conjugate(wfc1)) / d.nr
dpc[neighbor, jNeighbor, band1, band0] = np.conjugate(dpc[nk, j, band0, band1])


def pre_connection(nkconn, j, neighborconn, jNeighbor, dphaseconn):
"""Calculates the dot product of all combinations of wfc in nkconn and neighborconn."""
params = {
"nkconn": (nkconn,),
"j": (j,),
"neighborconn": (neighborconn,),
"jNeighbor": (jNeighbor,),
"dphaseconn": (dphaseconn,),
"banda0": range(d.nbnd),
"banda1": range(d.nbnd),
}

with Pool(processes=min(d.nbnd, multiprocessing.cpu_count())) as pool:
pool.starmap(_connect, product(*params.values()))


def connection(
nk: int, j: int, neighbor: int, jNeighbor: Tuple[np.ndarray]
) -> None:

dpc1 = np.zeros((d.nbnd, d.nbnd), dtype=complex)
dpc2 = np.zeros((d.nbnd, d.nbnd), dtype=complex)
dphase = d.phase[:, nk] * np.conjugate(d.phase[:, neighbor])

for banda0 in range(d.nbnd):
# reads first file for dot product
infile = d.wfcdirectory + "k0" + str(nkconn) + "b0" + str(banda0) + ".wfc"
with open(infile, "rb") as fichconn:
wfc0 = np.load(fichconn)
fichconn.close()
print(" Calculating nk = " + str(nk) + " neighbor = " + str(neighbor))
sys.stdout.flush()

for banda1 in range(d.nbnd):
# reads second file for dot product
infile = (
d.wfcdirectory + "k0" + str(neighborconn) + "b0" + str(banda1) + ".wfc"
)
with open(infile, "rb") as fichconn:
wfc1 = np.load(fichconn)
fichconn.close()
pre_connection(nk, j, neighbor, jNeighbor, dphase)

# calculates the dot products u_1.u_2* and u_2.u_1*
dpc1[banda0, banda1] = np.sum(dphaseconn * wfc0 * np.conjugate(wfc1)) / d.nr
dpc2[banda1, banda0] = np.conjugate(dpc1[banda0, banda1])

return dpc1, dpc2
def _generate_pre_connection_args(
nk: int, j: int
) -> Optional[Tuple[int, int, int, Tuple[np.ndarray]]]:
"""Generates the arguments for the pre_connection function."""
neighbor = d.neighbors[nk, j]
if neighbor != -1 and neighbor > nk:
jNeighbor = np.where(d.neighbors[neighbor] == nk)
return (nk, j, neighbor, jNeighbor)
return None


###################################################################################
if __name__ == "__main__":
header("DOTPRODUCT", d.version, time.asctime())

STARTTIME = time.time() # Starts counting time
DPC_SIZE = d.nks * 4 * d.nbnd * d.nbnd
DPC_SHAPE = (d.nks, 4, d.nbnd, d.nbnd)
RUN_PARAMS = {
"nk_points": range(d.nks),
"num_neibhors": range(4), # TODO Fix Hardcoded value
}

# Reading data needed for the run

Expand All @@ -61,50 +95,42 @@ def connection(nkconn, neighborconn, dphaseconn):
print(" Number of bands:", d.nbnd)
print()
print(" Phases loaded")
# print(d.phase[10000,10]) # d.phase[d.nr,d.nks]
print(" Neighbors loaded")

# Finished reading data needed for the run
print()
##########################################################

dpc = np.full((d.nks, 4, d.nbnd, d.nbnd), 0 + 0j, dtype=complex)
dp = np.zeros((d.nks, 4, d.nbnd, d.nbnd))
# Creating a buffer for the dpc np.ndarray
dpc_base = Array(ctypes.c_double, 2 * DPC_SIZE, lock=False)
# Initializing shared instance of np.ndarray 'dpc'
dpc = np.frombuffer(dpc_base, dtype=complex).reshape(DPC_SHAPE)

for nk in range(d.nks): # runs through all k-points
for j in range(4): # runs through all neighbors
neighbor = d.neighbors[nk, j]

if neighbor != -1 and neighbor > nk: # exclude invalid neighbors
jNeighbor = np.where(d.neighbors[neighbor] == nk)
# Calculates the diference in phases to convert \psi to u
dphase = d.phase[:, nk] * np.conjugate(d.phase[:, neighbor])

print(
" Calculating nk = "
+ str(nk)
+ " neighbor = "
+ str(neighbor)
)
sys.stdout.flush()
##########################################################

dpc[nk, j, :, :], dpc[neighbor, jNeighbor, :, :] = connection(
nk, neighbor, dphase
)
# Creating a list of tuples with the neighbors of each k-point
with NestablePool(d.npr) as pool:
pre_connection_args = (
filter(
None,
pool.starmap(
_generate_pre_connection_args, product(*RUN_PARAMS.values())
),
),
)
pool.starmap(connection, pre_connection_args)

dp = np.abs(dpc)

# Save dot products to file
with open("dpc.npy", "wb") as fich:
with open("dpc_pool.npy", "wb") as fich:
np.save(fich, dpc)
fich.close()
print(" Dot products saved to file dpc.npy")
print(" Dot products saved to file dpc_pool.npy")

# Save dot products modulus to file
with open("dp.npy", "wb") as fich:
with open("dp_pool.npy", "wb") as fich:
np.save(fich, dp)
fich.close()
print(" Dot products modulus saved to file dp.npy")
print(" Dot products modulus saved to file dp_pool.npy")

###################################################################################
# Finished
Expand Down

0 comments on commit 289c8f8

Please sign in to comment.