diff --git a/src/hwave/dos.py b/src/hwave/dos.py index db7bd61..0c2eefe 100644 --- a/src/hwave/dos.py +++ b/src/hwave/dos.py @@ -1,3 +1,10 @@ +"""Density of states (DoS) calculation and plotting utilities. + +This module provides functionality for calculating, plotting and analyzing the +density of states from Hartree-Fock calculations. + +""" + from __future__ import annotations import itertools @@ -9,10 +16,26 @@ class DoS: - dos: np.ndarray - ene: np.ndarray - ene_num: int - norb: int + """Class for storing and manipulating density of states data. + + Parameters + ---------- + ene : np.ndarray + Energy grid points + dos : np.ndarray + Density of states values for each orbital at each energy point + + Attributes + ---------- + dos : np.ndarray + Density of states array with shape (norb, nene) + ene : np.ndarray + Energy grid points array with shape (nene,) + ene_num : int + Number of energy points + norb : int + Number of orbitals + """ def __init__(self, ene: np.ndarray, dos: np.ndarray): assert ene.shape[0] == dos.shape[1] @@ -22,6 +45,17 @@ def __init__(self, ene: np.ndarray, dos: np.ndarray): self.norb = dos.shape[0] def plot(self, filename: str = "", verbose: bool = False): + """Plot the density of states. + + Creates a plot showing the total DoS and orbital-resolved DoS. + + Parameters + ---------- + filename : str, optional + If provided, save plot to this file + verbose : bool, optional + If True, print additional output + """ try: import matplotlib.pyplot as plt except ImportError: @@ -45,6 +79,15 @@ def plot(self, filename: str = "", verbose: bool = False): plt.close() def write_dos(self, output: str, verbose: bool = False): + """Write density of states data to file. + + Parameters + ---------- + output : str + Output filename + verbose : bool, optional + If True, print additional output + """ if verbose: print("Writing DOS to file: ", output) total_dos = np.sum(self.dos, axis=0) @@ -62,6 +105,18 @@ def write_dos(self, output: str, verbose: bool = False): def __read_geom(file_name="./dir-model/zvo_geom.dat"): + """Read geometry data from file. + + Parameters + ---------- + file_name : str, optional + Path to geometry file + + Returns + ------- + np.ndarray + Unit cell vectors array with shape (3,3) + """ with open(file_name, "r") as fr: uvec = np.zeros((3, 3)) for i, line in enumerate(itertools.islice(fr, 3)): # take first 3 lines @@ -75,6 +130,29 @@ def calc_dos( ene_num: int = 101, verbose: bool = False, ) -> DoS: + """Calculate density of states. + + Parameters + ---------- + input_dict : dict + Input parameters dictionary + ene_window : list, optional + Energy window [min, max] for DoS calculation + ene_num : int, optional + Number of energy points + verbose : bool, optional + If True, print additional output + + Returns + ------- + DoS + Calculated density of states object + + Raises + ------ + ImportError + If required libtetrabz package is not installed + """ try: import libtetrabz except ImportError: @@ -130,6 +208,10 @@ def calc_dos( def main(): + """Command-line interface for DoS calculation. + + Parses command line arguments and runs DoS calculation. + """ import tomli import argparse diff --git a/src/hwave/qlms.py b/src/hwave/qlms.py index c7b5068..a463183 100644 --- a/src/hwave/qlms.py +++ b/src/hwave/qlms.py @@ -16,53 +16,77 @@ from requests.structures import CaseInsensitiveDict def run(*, input_dict: Optional[dict] = None, input_file: Optional[str] = None): + """Run Hartree-Fock calculation with given input parameters. + + Parameters + ---------- + input_dict : dict, optional + Dictionary containing input parameters. Cannot be used with input_file. + input_file : str, optional + Path to TOML input file. Cannot be used with input_dict. + + Raises + ------ + RuntimeError + If neither or both input_dict and input_file are provided + """ + # Validate input arguments if input_dict is None: if input_file is None: raise RuntimeError("Neither input_dict nor input_file are passed") + # Load parameters from TOML file with open(input_file, "rb") as f: input_dict = tomli.load(f) else: if input_file is not None: raise RuntimeError("Both input_dict and input_file are passed") - # Initialize information about log + # Initialize logging configuration info_log = input_dict.get("log", {}) - info_log["print_level"] = info_log.get("print_level", 1) - info_log["print_step"] = info_log.get("print_step", 1) + info_log["print_level"] = info_log.get("print_level", 1) # Default print level + info_log["print_step"] = info_log.get("print_step", 1) # Default print step - # Initialize information about mode + # Get calculation mode and file paths info_mode = input_dict.get("mode", {}) info_file = input_dict.get("file", {"input": {}, "output": {}}) - # Initialize information about input files + + # Setup input file paths info_inputfile = info_file.get("input", {}) info_inputfile["path_to_input"] = info_inputfile.get("path_to_input", "") - # Initialize information about output files + # Setup output directory info_outputfile = info_file.get("output", {}) info_outputfile["path_to_output"] = info_outputfile.get("path_to_output", "output") path_to_output = info_outputfile["path_to_output"] os.makedirs(path_to_output, exist_ok=True) + # Configure logging logger = logging.getLogger("qlms") fmt = "%(asctime)s %(levelname)s %(name)s: %(message)s" - # logging.basicConfig(level=logging.DEBUG, format=fmt) logging.basicConfig(level=logging.INFO, format=fmt) + # Validate calculation mode if "mode" not in info_mode: logger.error("mode is not defined in [mode].") exit(1) mode = info_mode["mode"] + + # Initialize solver based on calculation mode if mode == "UHFr": + # Real-space unrestricted Hartree-Fock logger.info("Read def files") file_list = CaseInsensitiveDict() - #interaction files + + # Process interaction file paths for key, file_name in info_inputfile["interaction"].items(): - if key.lower() in ["trans", "coulombinter", "coulombintra", "pairhop", "hund", "exchange", "ising", "pairlift", "interall"]: + if key.lower() in ["trans", "coulombinter", "coulombintra", "pairhop", + "hund", "exchange", "ising", "pairlift", "interall"]: file_list[key] = os.path.join(info_inputfile["path_to_input"], file_name) else: logging.error("Keyword {} is incorrect.".format(key)) exit(1) - #initial and green + + # Process initial state and Green's function files for key, file_name in info_inputfile.items(): if key.lower() == "initial": file_list[key] = os.path.join(info_inputfile["path_to_input"], file_name) @@ -70,34 +94,30 @@ def run(*, input_dict: Optional[dict] = None, input_file: Optional[str] = None): file_list[key] = os.path.join(info_inputfile["path_to_input"], file_name) read_io = qlmsio.read_input.QLMSInput(file_list) + # Read Hamiltonian and Green's function parameters logger.info("Get Hamiltonian information") ham_info = read_io.get_param("ham") logger.info("Get Green function information") green_info = read_io.get_param("green") - # solver = sol_uhf.UHF(ham_info, info_log, info_mode, mod_param_info) solver = sol_uhfr.UHFr(ham_info, info_log, info_mode) elif mode == "UHFk": + # k-space unrestricted Hartree-Fock logger.info("Read definitions from files") read_io = qlmsio.read_input_k.QLMSkInput(info_inputfile) - # logger.info("Get parameter information") - # mod_param_info = info_mode #read_io.get_param("mod") - # pprint.pprint(mod_param_info, width = 1) - logger.info("Get Hamiltonian information") ham_info = read_io.get_param("ham") logger.info("Get Green function information") green_info = read_io.get_param("green") - # pprint.pprint(info_mode, width=1) - solver = sol_uhfk.UHFk(ham_info, info_log, info_mode) elif mode == "RPA": + # Random Phase Approximation logger.info("RPA mode") logger.info("Read interaction definitions from files") @@ -107,12 +127,13 @@ def run(*, input_dict: Optional[dict] = None, input_file: Optional[str] = None): solver = sol_rpa.RPA(ham_info, info_log, info_mode) green_info = read_io.get_param("green") - green_info.update( solver.read_init(info_inputfile) ) + green_info.update(solver.read_init(info_inputfile)) else: logger.warning("mode is incorrect: mode={}.".format(mode)) exit(0) + # Execute calculation logger.info("Start UHF calculation") solver.solve(green_info, path_to_output) logger.info("Save calculation results.") @@ -121,20 +142,28 @@ def run(*, input_dict: Optional[dict] = None, input_file: Optional[str] = None): def main(): + """Command-line interface entry point. + + Parses command line arguments and runs the calculation. + """ import argparse + # Setup argument parser parser = argparse.ArgumentParser(prog='hwave') parser.add_argument('input_toml', nargs='?', default=None, help='input parameter file') parser.add_argument('--version', action='store_true', help='show version') args = parser.parse_args() + # Handle version request if args.version: print('hwave', hwave.__version__) sys.exit(0) + # Validate input file if args.input_toml is None: parser.print_help() sys.exit(1) + # Run calculation run(input_file = args.input_toml) diff --git a/src/hwave/solver/base.py b/src/hwave/solver/base.py index 71780eb..1568e1d 100644 --- a/src/hwave/solver/base.py +++ b/src/hwave/solver/base.py @@ -1,11 +1,43 @@ +"""Base solver class for Hartree-Fock calculations. + +This module provides the base solver class that implements common functionality +for Hartree-Fock calculations, including parameter handling and validation. + +""" import sys from requests.structures import CaseInsensitiveDict -# from pprint import pprint - import logging logger = logging.getLogger("qlms").getChild("solver") + class solver_base(): + """Base solver class for Hartree-Fock calculations. + + Parameters + ---------- + param_ham : dict + Hamiltonian parameters + info_log : dict + Logging configuration + info_mode : dict + Calculation mode parameters + param_mod : dict, optional + Model parameters to override defaults + + Attributes + ---------- + param_mod : CaseInsensitiveDict + Model parameters + param_ham : dict + Hamiltonian parameters + info_log : dict + Logging configuration + threshold : float + Cutoff threshold for Green's function elements + relax_checks : bool + Whether to relax parameter validation checks + """ + def __init__(self, param_ham, info_log, info_mode, param_mod=None): logger = logging.getLogger("qlms").getChild(self.name) @@ -32,12 +64,8 @@ def __init__(self, param_ham, info_log, info_mode, param_mod=None): range_list = { "T": [ 0, None ], - # "2Sz": [ -param_mod["Nsite"], param_mod["Nsite"] ], - # "Nsite": [ 1, None ], - # "Ncond": [ 1, None ], "IterationMax": [ 0, None ], "Mix": [ 0.0, 1.0 ], - # "print_step": [ 1, None ], "EPS": [ 0, None ], } @@ -114,15 +142,23 @@ def __init__(self, param_ham, info_log, info_mode, param_mod=None): # canonicalize self.param_mod["EPS"] = pow(10, -self.param_mod["EPS"]) - # debug - # pprint(self.param_mod) - def _check_info_mode(self, info_mode): + """Check validity of info_mode parameters. + + Parameters + ---------- + info_mode : dict + Mode parameters to validate + + Returns + ------- + int + Number of validation errors found + """ logger = logging.getLogger("qlms").getChild(self.name) fix_list = { "mode": ["UHFr", "UHFk"], - # "flag_fock": [True, False] } exit_code = 0 @@ -136,6 +172,18 @@ def _check_info_mode(self, info_mode): return exit_code def _check_param_mod(self, param_mod): + """Check validity of model parameters. + + Parameters + ---------- + param_mod : dict + Model parameters to validate + + Returns + ------- + int + Number of validation errors found + """ logger = logging.getLogger("qlms").getChild(self.name) error_code = 0 @@ -149,6 +197,20 @@ def _check_param_mod(self, param_mod): return error_code def _check_param_range(self, param_mod, range_list): + """Check if parameters are within valid ranges. + + Parameters + ---------- + param_mod : dict + Model parameters to validate + range_list : dict + Valid ranges for parameters + + Returns + ------- + int + Number of validation errors found + """ logger = logging.getLogger("qlms").getChild(self.name) error_code = 0 @@ -166,6 +228,25 @@ def _check_param_range(self, param_mod, range_list): return error_code def _round_to_int(self, val, mode): + """Round a value to integer according to specified mode. + + Parameters + ---------- + val : float + Value to round + mode : str + Rounding mode to use + + Returns + ------- + int + Rounded integer value + + Raises + ------ + SystemExit + If rounding fails or mode is invalid + """ import math mode = mode.lower() # case-insensitive if mode == "as-is": @@ -197,10 +278,33 @@ def _round_to_int(self, val, mode): return ret def solve(self, path_to_output): + """Solve the Hartree-Fock equations. + + Parameters + ---------- + path_to_output : str + Path to output file + """ pass def get_results(self): + """Get calculation results. + + Returns + ------- + tuple + (physics, Green's function) results + """ return (self.physics, self.Green) def save_results(self, info_outputfile, green_info): + """Save calculation results. + + Parameters + ---------- + info_outputfile : dict + Output file configuration + green_info : dict + Green's function information + """ pass diff --git a/src/hwave/solver/perf.py b/src/hwave/solver/perf.py index b7c9677..5ee0bbc 100644 --- a/src/hwave/solver/perf.py +++ b/src/hwave/solver/perf.py @@ -1,11 +1,34 @@ +"""Performance profiling utilities. + +This module provides utilities for profiling function execution times and +collecting performance statistics. +""" + from functools import wraps import time + class PerfDB: + """Database for storing performance profiling data. + + Stores execution counts and total elapsed time for profiled functions. + Prints summary statistics on deletion. + + Attributes + ---------- + _db_count : dict + Number of calls per function + _db_value : dict + Total elapsed time per function + """ + def __init__(self): + """Initialize empty performance database.""" self._db_count = {} self._db_value = {} + def __del__(self): + """Print summary statistics when object is deleted.""" if len(self._db_count) == 0: return print("--------------------------------------------------------------------------------") @@ -20,13 +43,37 @@ def __del__(self): self._db_count[item] )) print("--------------------------------------------------------------------------------") + def put(self, name, value): + """Add a timing measurement. + + Parameters + ---------- + name : str + Function name + value : float + Elapsed time in seconds + """ self._db_count[name] = self._db_count.get(name, 0) + 1 self._db_value[name] = self._db_value.get(name, 0) + value + _perf_db_data = PerfDB() + def do_profile(func): + """Decorator for profiling function execution time. + + Parameters + ---------- + func : callable + Function to profile + + Returns + ------- + callable + Wrapped function that measures and records execution time + """ @wraps(func) def wrapper(*args, **kwargs): # start time