Skip to content

Commit

Permalink
Update docstrings to numpydoc format in dos.py, qlms.py, solver/base.…
Browse files Browse the repository at this point in the history
…py, and solver/perf.py
  • Loading branch information
k-yoshimi committed Nov 11, 2024
1 parent aa3701c commit 85ae6f1
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 32 deletions.
90 changes: 86 additions & 4 deletions src/hwave/dos.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""Density of states (DoS) calculation and plotting utilities.
This module provides functionality for calculating, plotting and analyzing the
density of states from Hartree-Fock calculations.
"""

from __future__ import annotations

import itertools
Expand All @@ -9,10 +16,26 @@


class DoS:
dos: np.ndarray
ene: np.ndarray
ene_num: int
norb: int
"""Class for storing and manipulating density of states data.
Parameters
----------
ene : np.ndarray
Energy grid points
dos : np.ndarray
Density of states values for each orbital at each energy point
Attributes
----------
dos : np.ndarray
Density of states array with shape (norb, nene)
ene : np.ndarray
Energy grid points array with shape (nene,)
ene_num : int
Number of energy points
norb : int
Number of orbitals
"""

def __init__(self, ene: np.ndarray, dos: np.ndarray):
assert ene.shape[0] == dos.shape[1]
Expand All @@ -22,6 +45,17 @@ def __init__(self, ene: np.ndarray, dos: np.ndarray):
self.norb = dos.shape[0]

def plot(self, filename: str = "", verbose: bool = False):
"""Plot the density of states.
Creates a plot showing the total DoS and orbital-resolved DoS.
Parameters
----------
filename : str, optional
If provided, save plot to this file
verbose : bool, optional
If True, print additional output
"""
try:
import matplotlib.pyplot as plt
except ImportError:
Expand All @@ -45,6 +79,15 @@ def plot(self, filename: str = "", verbose: bool = False):
plt.close()

def write_dos(self, output: str, verbose: bool = False):
"""Write density of states data to file.
Parameters
----------
output : str
Output filename
verbose : bool, optional
If True, print additional output
"""
if verbose:
print("Writing DOS to file: ", output)
total_dos = np.sum(self.dos, axis=0)
Expand All @@ -62,6 +105,18 @@ def write_dos(self, output: str, verbose: bool = False):


def __read_geom(file_name="./dir-model/zvo_geom.dat"):
"""Read geometry data from file.
Parameters
----------
file_name : str, optional
Path to geometry file
Returns
-------
np.ndarray
Unit cell vectors array with shape (3,3)
"""
with open(file_name, "r") as fr:
uvec = np.zeros((3, 3))
for i, line in enumerate(itertools.islice(fr, 3)): # take first 3 lines
Expand All @@ -75,6 +130,29 @@ def calc_dos(
ene_num: int = 101,
verbose: bool = False,
) -> DoS:
"""Calculate density of states.
Parameters
----------
input_dict : dict
Input parameters dictionary
ene_window : list, optional
Energy window [min, max] for DoS calculation
ene_num : int, optional
Number of energy points
verbose : bool, optional
If True, print additional output
Returns
-------
DoS
Calculated density of states object
Raises
------
ImportError
If required libtetrabz package is not installed
"""
try:
import libtetrabz
except ImportError:
Expand Down Expand Up @@ -130,6 +208,10 @@ def calc_dos(


def main():
"""Command-line interface for DoS calculation.
Parses command line arguments and runs DoS calculation.
"""
import tomli
import argparse

Expand Down
65 changes: 47 additions & 18 deletions src/hwave/qlms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,88 +16,108 @@
from requests.structures import CaseInsensitiveDict

def run(*, input_dict: Optional[dict] = None, input_file: Optional[str] = None):
"""Run Hartree-Fock calculation with given input parameters.
Parameters
----------
input_dict : dict, optional
Dictionary containing input parameters. Cannot be used with input_file.
input_file : str, optional
Path to TOML input file. Cannot be used with input_dict.
Raises
------
RuntimeError
If neither or both input_dict and input_file are provided
"""
# Validate input arguments
if input_dict is None:
if input_file is None:
raise RuntimeError("Neither input_dict nor input_file are passed")
# Load parameters from TOML file
with open(input_file, "rb") as f:
input_dict = tomli.load(f)
else:
if input_file is not None:
raise RuntimeError("Both input_dict and input_file are passed")

# Initialize information about log
# Initialize logging configuration
info_log = input_dict.get("log", {})
info_log["print_level"] = info_log.get("print_level", 1)
info_log["print_step"] = info_log.get("print_step", 1)
info_log["print_level"] = info_log.get("print_level", 1) # Default print level
info_log["print_step"] = info_log.get("print_step", 1) # Default print step

# Initialize information about mode
# Get calculation mode and file paths
info_mode = input_dict.get("mode", {})
info_file = input_dict.get("file", {"input": {}, "output": {}})
# Initialize information about input files

# Setup input file paths
info_inputfile = info_file.get("input", {})
info_inputfile["path_to_input"] = info_inputfile.get("path_to_input", "")

# Initialize information about output files
# Setup output directory
info_outputfile = info_file.get("output", {})
info_outputfile["path_to_output"] = info_outputfile.get("path_to_output", "output")
path_to_output = info_outputfile["path_to_output"]
os.makedirs(path_to_output, exist_ok=True)

# Configure logging
logger = logging.getLogger("qlms")
fmt = "%(asctime)s %(levelname)s %(name)s: %(message)s"
# logging.basicConfig(level=logging.DEBUG, format=fmt)
logging.basicConfig(level=logging.INFO, format=fmt)

# Validate calculation mode
if "mode" not in info_mode:
logger.error("mode is not defined in [mode].")
exit(1)
mode = info_mode["mode"]

# Initialize solver based on calculation mode
if mode == "UHFr":
# Real-space unrestricted Hartree-Fock
logger.info("Read def files")
file_list = CaseInsensitiveDict()
#interaction files

# Process interaction file paths
for key, file_name in info_inputfile["interaction"].items():
if key.lower() in ["trans", "coulombinter", "coulombintra", "pairhop", "hund", "exchange", "ising", "pairlift", "interall"]:
if key.lower() in ["trans", "coulombinter", "coulombintra", "pairhop",
"hund", "exchange", "ising", "pairlift", "interall"]:
file_list[key] = os.path.join(info_inputfile["path_to_input"], file_name)
else:
logging.error("Keyword {} is incorrect.".format(key))
exit(1)
#initial and green

# Process initial state and Green's function files
for key, file_name in info_inputfile.items():
if key.lower() == "initial":
file_list[key] = os.path.join(info_inputfile["path_to_input"], file_name)
elif key.lower() == "onebodyg":
file_list[key] = os.path.join(info_inputfile["path_to_input"], file_name)
read_io = qlmsio.read_input.QLMSInput(file_list)

# Read Hamiltonian and Green's function parameters
logger.info("Get Hamiltonian information")
ham_info = read_io.get_param("ham")

logger.info("Get Green function information")
green_info = read_io.get_param("green")

# solver = sol_uhf.UHF(ham_info, info_log, info_mode, mod_param_info)
solver = sol_uhfr.UHFr(ham_info, info_log, info_mode)

elif mode == "UHFk":
# k-space unrestricted Hartree-Fock
logger.info("Read definitions from files")
read_io = qlmsio.read_input_k.QLMSkInput(info_inputfile)

# logger.info("Get parameter information")
# mod_param_info = info_mode #read_io.get_param("mod")
# pprint.pprint(mod_param_info, width = 1)

logger.info("Get Hamiltonian information")
ham_info = read_io.get_param("ham")

logger.info("Get Green function information")
green_info = read_io.get_param("green")

# pprint.pprint(info_mode, width=1)

solver = sol_uhfk.UHFk(ham_info, info_log, info_mode)

elif mode == "RPA":
# Random Phase Approximation
logger.info("RPA mode")

logger.info("Read interaction definitions from files")
Expand All @@ -107,12 +127,13 @@ def run(*, input_dict: Optional[dict] = None, input_file: Optional[str] = None):
solver = sol_rpa.RPA(ham_info, info_log, info_mode)

green_info = read_io.get_param("green")
green_info.update( solver.read_init(info_inputfile) )
green_info.update(solver.read_init(info_inputfile))

else:
logger.warning("mode is incorrect: mode={}.".format(mode))
exit(0)

# Execute calculation
logger.info("Start UHF calculation")
solver.solve(green_info, path_to_output)
logger.info("Save calculation results.")
Expand All @@ -121,20 +142,28 @@ def run(*, input_dict: Optional[dict] = None, input_file: Optional[str] = None):


def main():
"""Command-line interface entry point.
Parses command line arguments and runs the calculation.
"""
import argparse

# Setup argument parser
parser = argparse.ArgumentParser(prog='hwave')
parser.add_argument('input_toml', nargs='?', default=None, help='input parameter file')
parser.add_argument('--version', action='store_true', help='show version')

args = parser.parse_args()

# Handle version request
if args.version:
print('hwave', hwave.__version__)
sys.exit(0)

# Validate input file
if args.input_toml is None:
parser.print_help()
sys.exit(1)

# Run calculation
run(input_file = args.input_toml)
Loading

0 comments on commit 85ae6f1

Please sign in to comment.