From 81af496a4a06e974605e23910897f8a1b4829037 Mon Sep 17 00:00:00 2001 From: JosePizarro3 Date: Tue, 28 May 2024 12:08:15 +0200 Subject: [PATCH] Ruff formatted python modules --- src/pyssmf/hopping_pruning.py | 4 +- src/pyssmf/input.py | 102 +++++++++++++++++----------------- src/pyssmf/parsing.py | 52 ++++++++--------- src/pyssmf/runner.py | 48 ++++++++-------- src/pyssmf/schema.py | 24 ++++---- src/pyssmf/tb_hamiltonian.py | 18 +++--- src/pyssmf/utils.py | 6 +- src/pyssmf/visualization.py | 30 +++++----- tests/conftest.py | 2 +- tests/test_parsing.py | 4 +- tests/test_schema.py | 2 +- 11 files changed, 146 insertions(+), 146 deletions(-) diff --git a/src/pyssmf/hopping_pruning.py b/src/pyssmf/hopping_pruning.py index 80ee565..51486c2 100644 --- a/src/pyssmf/hopping_pruning.py +++ b/src/pyssmf/hopping_pruning.py @@ -60,7 +60,7 @@ def prune_by_threshold( # problems for hoppings structures which are not so trivial. if not self.max_value and not self.hopping_matrix_norms: self.logger.warning( - "Could not extract the hopping_matrix norms and their max_value." + 'Could not extract the hopping_matrix norms and their max_value.' ) return threshold = threshold_factor * self.max_value @@ -88,7 +88,7 @@ def prune_by_threshold( ] except Exception: self.logger.warning( - "Could not update the model parameters after pruning." + 'Could not update the model parameters after pruning.' ) self.update_norms_and_max_value() diff --git a/src/pyssmf/input.py b/src/pyssmf/input.py index 1307c71..6189bfe 100644 --- a/src/pyssmf/input.py +++ b/src/pyssmf/input.py @@ -27,10 +27,10 @@ class ValidLatticeModels(ABC): def __init__(self, logger: logging.Logger = logging.getLogger(__name__)): self.logger = logger self._valid_lattice_models = [ - "linear", - "square", - "honeycomb", - "triangular", + 'linear', + 'square', + 'honeycomb', + 'triangular', ] # TODO extend this @@ -64,34 +64,34 @@ def __init__(self, **kwargs): """ super().__init__() # Initializing data - data = {"code": "SSMF"} + data = {'code': 'SSMF'} # Check working_directory and stores it in data - if not kwargs.get("working_directory"): + if not kwargs.get('working_directory'): self.logger.error( - "Could not find specified the working_directory in the input." + 'Could not find specified the working_directory in the input.' ) return - data["working_directory"] = kwargs.get("working_directory") + data['working_directory'] = kwargs.get('working_directory') # Read from input file if provided in the argument - read_input_file = kwargs.get("read_from_input_file", False) + read_input_file = kwargs.get('read_from_input_file', False) if read_input_file: - data["input_file"] = kwargs.get("input_file") + data['input_file'] = kwargs.get('input_file') input_file = os.path.join( - kwargs.get("working_directory"), kwargs.get("input_file") + kwargs.get('working_directory'), kwargs.get('input_file') ) - data["input_data"] = self.read_from_file(input_file) + data['input_data'] = self.read_from_file(input_file) else: # We check whether a model_file has been defined or a model_label - if kwargs.get("tb_model_file"): - data["tb_model"] = kwargs.get("tb_model_file") + if kwargs.get('tb_model_file'): + data['tb_model'] = kwargs.get('tb_model_file') # pruning only applies for tb_model_file cases - if kwargs.get("pruning", False): - data["prune_threshold"] = kwargs.get("prune_threshold", 0.01) - elif kwargs.get("lattice_model") in self._valid_lattice_models: - data["tb_model"] = kwargs.get("lattice_model") - hoppings = kwargs.get("hoppings", []) + if kwargs.get('pruning', False): + data['prune_threshold'] = kwargs.get('prune_threshold', 0.01) + elif kwargs.get('lattice_model') in self._valid_lattice_models: + data['tb_model'] = kwargs.get('lattice_model') + hoppings = kwargs.get('hoppings', []) # We check if 'hoppings' was empty if len(hoppings) == 0: n_hoppings = 1 @@ -101,8 +101,8 @@ def __init__(self, **kwargs): n_hoppings = len(hoppings) if n_hoppings > 3: self.logger.error( - "Maximum n_hoppings models supported is 3. Please, select " - "a smaller number." + 'Maximum n_hoppings models supported is 3. Please, select ' + 'a smaller number.' ) return n_orbitals = len(hoppings[0]) @@ -112,40 +112,40 @@ def __init__(self, **kwargs): for hop_point in hoppings ): self.logger.error( - "Dimensions of each hopping matrix do not coincide with" - "(n_orbitals, n_orbitals).", - data={"n_orbitals": n_orbitals}, + 'Dimensions of each hopping matrix do not coincide with' + '(n_orbitals, n_orbitals).', + data={'n_orbitals': n_orbitals}, ) # TODO improve this - onsite_energies = kwargs.get("onsite_energies", []) + onsite_energies = kwargs.get('onsite_energies', []) if len(onsite_energies) == 0: onsite_energies = [0.0] * n_orbitals - data["hoppings"] = hoppings - data["onsite_energies"] = onsite_energies - data["n_hoppings"] = n_hoppings - data["n_orbitals"] = n_orbitals + data['hoppings'] = hoppings + data['onsite_energies'] = onsite_energies + data['n_hoppings'] = n_hoppings + data['n_orbitals'] = n_orbitals else: self.logger.error( - "Could not find the initial model. Please, check your inputs: " - "1) define `model_file` pointing to your Wannier90 `*_hr.dat` " - "hoppings file, or 2) specify the `lattice_model` to study among the " - "accepted values.", - data={"lattice_model": self._valid_lattice_models}, + 'Could not find the initial model. Please, check your inputs: ' + '1) define `model_file` pointing to your Wannier90 `*_hr.dat` ' + 'hoppings file, or 2) specify the `lattice_model` to study among the ' + 'accepted values.', + data={'lattice_model': self._valid_lattice_models}, ) # KGrids # For band structure calculations - data["n_k_path"] = kwargs.get("n_k_path", 90) + data['n_k_path'] = kwargs.get('n_k_path', 90) # For full_bz diagonalization - data["k_grid"] = kwargs.get("k_grid", [1, 1, 1]) + data['k_grid'] = kwargs.get('k_grid', [1, 1, 1]) # Plotting arguments - data["plot_hoppings"] = kwargs.get("plot_hoppings", False) - data["plot_bands"] = kwargs.get("plot_bands", False) + data['plot_hoppings'] = kwargs.get('plot_hoppings', False) + data['plot_bands'] = kwargs.get('plot_bands', False) # DOS calculation and plotting - data["dos"] = kwargs.get("dos", False) - data["dos_gaussian_width"] = kwargs.get("dos_gaussian_width", 0.1) - data["dos_delta_energy"] = kwargs.get("dos_delta_energy", 0.01) + data['dos'] = kwargs.get('dos', False) + data['dos_gaussian_width'] = kwargs.get('dos_gaussian_width', 0.1) + data['dos_delta_energy'] = kwargs.get('dos_delta_energy', 0.01) self.data = data self.to_json() @@ -153,7 +153,7 @@ def to_json(self): """ Stores the input data in a JSON file in the working_directory. """ - with open(f"{self.data.get('working_directory')}/input_ssmf.json", "w") as file: + with open(f"{self.data.get('working_directory')}/input_ssmf.json", 'w') as file: json.dump(self.data, file, indent=4) def read_from_file(self, input_file: str) -> dict: @@ -167,22 +167,22 @@ def read_from_file(self, input_file: str) -> dict: (dict): dictionary with the input data read from the JSON input file. """ try: - with open(input_file, "r") as file: + with open(input_file, 'r') as file: input_data = json.load(file) except FileNotFoundError: self.logger.error( - "Input file not found.", - extra={"input_file": input_file}, + 'Input file not found.', + extra={'input_file': input_file}, ) except json.JSONDecodeError: self.logger.error( - "Failed to decode JSON in input file.", - extra={"input_file": input_file}, + 'Failed to decode JSON in input file.', + extra={'input_file': input_file}, ) - code_name = input_data.get("code", "") - if code_name != "SSMF": + code_name = input_data.get('code', '') + if code_name != 'SSMF': self.logger.error( - "Could not recognize the input JSON file as readable by the SSMF code.", - extra={"input_file": input_file}, + 'Could not recognize the input JSON file as readable by the SSMF code.', + extra={'input_file': input_file}, ) return input_data diff --git a/src/pyssmf/parsing.py b/src/pyssmf/parsing.py index b8155c6..f70fda2 100644 --- a/src/pyssmf/parsing.py +++ b/src/pyssmf/parsing.py @@ -26,7 +26,7 @@ from .schema import System, BravaisLattice, Model from .utils import get_files -re_n = r"[\n\r]" +re_n = r'[\n\r]' class WOutParser(TextParser): @@ -40,19 +40,19 @@ def __init__(self): def init_quantities(self): structure_quantities = [ - Quantity("labels", r"\|\s*([A-Z][a-z]*)", repeats=True), + Quantity('labels', r'\|\s*([A-Z][a-z]*)', repeats=True), Quantity( - "positions", - r"\|\s*([\-\d\.]+)\s*([\-\d\.]+)\s*([\-\d\.]+)", + 'positions', + r'\|\s*([\-\d\.]+)\s*([\-\d\.]+)\s*([\-\d\.]+)', repeats=True, ), ] self._quantities = [ - Quantity("lattice_vectors", r"\s*a_\d\s*([\d\-\s\.]+)", repeats=True), + Quantity('lattice_vectors', r'\s*a_\d\s*([\d\-\s\.]+)', repeats=True), Quantity( - "structure", - rf"(\s*Fractional Coordinate[\s\S]+?)(?:{re_n}\s*(PROJECTIONS|K-POINT GRID))", + 'structure', + rf'(\s*Fractional Coordinate[\s\S]+?)(?:{re_n}\s*(PROJECTIONS|K-POINT GRID))', repeats=False, sub_parser=TextParser(quantities=structure_quantities), ), @@ -70,8 +70,8 @@ def __init__(self): def init_quantities(self): self._quantities = [ - Quantity("degeneracy_factors", r"\s*written on[\s\w]*:\d*:\d*\s*([\d\s]+)"), - Quantity("hoppings", r"\s*([-\d\s.]+)", repeats=False), + Quantity('degeneracy_factors', r'\s*written on[\s\w]*:\d*:\d*\s*([\d\s]+)'), + Quantity('hoppings', r'\s*([-\d\s.]+)', repeats=False), ] @@ -95,19 +95,19 @@ def parse_system(self): """ sec_system = self.model.m_create(BravaisLattice).m_create(System) - structure = self.wout_parser.get("structure") + structure = self.wout_parser.get('structure') if structure is None: - self.logger.error("Error parsing the structure from .wout") + self.logger.error('Error parsing the structure from .wout') return - if self.wout_parser.get("lattice_vectors", []): + if self.wout_parser.get('lattice_vectors', []): lattice_vectors = np.vstack( - self.wout_parser.get("lattice_vectors", [])[-3:] + self.wout_parser.get('lattice_vectors', [])[-3:] ) sec_system.lattice_vectors = lattice_vectors * ureg.angstrom sec_system.periodic = [True, True, True] - sec_system.labels = structure.get("labels") - if structure.get("positions") is not None: - sec_system.positions = structure.get("positions") * ureg.angstrom + sec_system.labels = structure.get('labels') + if structure.get('positions') is not None: + sec_system.positions = structure.get('positions') * ureg.angstrom def parse_hoppings(self): """ @@ -115,8 +115,8 @@ def parse_hoppings(self): """ bravais_lattice = self.model.bravais_lattice - deg_factors = self.hr_parser.get("degeneracy_factors", []) - full_hoppings = self.hr_parser.get("hoppings", []) + deg_factors = self.hr_parser.get('degeneracy_factors', []) + full_hoppings = self.hr_parser.get('hoppings', []) if deg_factors is not None and full_hoppings is not None: n_orbitals = deg_factors[0] n_points = deg_factors[1] @@ -158,9 +158,9 @@ def parse(self, filepath: str, model: Model, logger: logging.Logger = None): logger (logging.Logger, optional): Logger object for debug messages. Defaults to None. """ basename = os.path.basename(filepath) # Getting filepath for *_hr.dat file - wout_files = get_files("*.wout", filepath, basename) + wout_files = get_files('*.wout', filepath, basename) if len(wout_files) > 1: - logger.warning("Multiple `*.wout` files found; we will parse the last one.") + logger.warning('Multiple `*.wout` files found; we will parse the last one.') mainfile = wout_files[-1] # Path to *.wout file self.filepath = filepath @@ -195,7 +195,7 @@ def __init__(self): def _bravais_vectors(self, n_neighbors, lattice_vectors): bravais_vectors = [] - if self.lattice_model == "linear": + if self.lattice_model == 'linear': for i in range(-n_neighbors, n_neighbors + 1): j = 0 k = 0 @@ -215,11 +215,11 @@ def linear(self): system = bravais_lattice.m_create(System) lattice_vectors = [[1, 0, 0], [0, 0, 0], [0, 0, 0]] _system_map = { - "lattice_vectors": lattice_vectors, - "periodic": [True, False, False], - "n_atoms": 1, - "labels": ["X"], - "positions": [[0, 0, 0]], + 'lattice_vectors': lattice_vectors, + 'periodic': [True, False, False], + 'n_atoms': 1, + 'labels': ['X'], + 'positions': [[0, 0, 0]], } for key in system.m_def.all_quantities.keys(): system.m_set(system.m_get_quantity_definition(key), _system_map.get(key)) diff --git a/src/pyssmf/runner.py b/src/pyssmf/runner.py index 93dcb45..2036442 100644 --- a/src/pyssmf/runner.py +++ b/src/pyssmf/runner.py @@ -34,7 +34,7 @@ class Runner(ValidLatticeModels): def __init__(self, **kwargs): super().__init__() - self.data = kwargs.get("data", {}) + self.data = kwargs.get('data', {}) # Initializing the model class self.model = Model() @@ -43,49 +43,49 @@ def parse_tb_model(self): Parses the tight-binding model from the input file. It can be obtained from a Wannier90 tight-binding calculation or a toy lattice model. """ - lattice_model = self.data.get("tb_model", "") - if "_hr.dat" in lattice_model: - model_file = os.path.join(self.data.get("working_directory"), lattice_model) + lattice_model = self.data.get('tb_model', '') + if '_hr.dat' in lattice_model: + model_file = os.path.join(self.data.get('working_directory'), lattice_model) MinimalWannier90Parser().parse(model_file, self.model, self.logger) elif lattice_model in self._valid_lattice_models: - onsite_energies = self.data.get("onsite_energies") - hoppings = self.data.get("hoppings") + onsite_energies = self.data.get('onsite_energies') + hoppings = self.data.get('hoppings') ToyModels().parse( lattice_model, onsite_energies, hoppings, self.model, self.logger ) else: self.logger.error( - "Could not recognize the input tight-binding model. Please " - "check your inputs." + 'Could not recognize the input tight-binding model. Please ' + 'check your inputs.' ) return - self.logger.info("Tight-binding model parsed successfully!") + self.logger.info('Tight-binding model parsed successfully!') def prune_hoppings(self): """ Prunes the hopping matrices by setting to zero all values below a certain `prune_threshold`. """ - prune_threshold = self.data.get("prune_threshold") + prune_threshold = self.data.get('prune_threshold') if prune_threshold: pruner = Pruner(self.model) pruner.prune_by_threshold(prune_threshold, self.logger) - if self.data.get("plot_hoppings"): + if self.data.get('plot_hoppings'): plot_hopping_matrices(pruner.hopping_matrix_norms / pruner.max_value) - self.logger.info("Hopping pruning finished!") + self.logger.info('Hopping pruning finished!') def calculate_band_structure(self): """ Calculates the band structure of the tight-binding model in a given `n_k_path`. """ - n_k_path = self.data.get("n_k_path", 90) + n_k_path = self.data.get('n_k_path', 90) tb_hamiltonian = TBHamiltonian( - self.model, k_grid_type="bands", n_k_path=n_k_path + self.model, k_grid_type='bands', n_k_path=n_k_path ) special_points = tb_hamiltonian.k_path.special_points kpoints = tb_hamiltonian.kpoints eigenvalues, _ = tb_hamiltonian.diagonalize(kpoints) plot_band_structure(eigenvalues, tb_hamiltonian, special_points) - self.logger.info("Band structure calculation finished!") + self.logger.info('Band structure calculation finished!') def gaussian_convolution( self, @@ -165,7 +165,7 @@ def calculate_dos( orbital_dos_histogram.append(orbital_dos_contribution_histogram) if not orbital_dos_histogram: self.logger.warning( - "Problem obtaining the orbital DOS histogram. Cannot resolve DOS." + 'Problem obtaining the orbital DOS histogram. Cannot resolve DOS.' ) # We convolute the histogram to obtain a smoother orbital DOS energies, orbital_dos = self.gaussian_convolution( @@ -180,23 +180,23 @@ def bz_diagonalization(self): Diagonalizes the tight-binding model in the full Brillouin zone and returns its eigenvalues and eigenvectors. """ - k_grid = self.data.get("k_grid", [1, 1, 1]) - tb_hamiltonian = TBHamiltonian(self.model, k_grid_type="full_bz", k_grid=k_grid) + k_grid = self.data.get('k_grid', [1, 1, 1]) + tb_hamiltonian = TBHamiltonian(self.model, k_grid_type='full_bz', k_grid=k_grid) kpoints = tb_hamiltonian.kpoints eigenvalues, eigenvectors = tb_hamiltonian.diagonalize(kpoints) # Calculating and plotting DOS - if self.data.get("dos"): + if self.data.get('dos'): bins = int(np.linalg.norm(k_grid)) - width = self.data.get("dos_gaussian_width") - delta_energy = self.data.get("dos_delta_energy") + width = self.data.get('dos_gaussian_width') + delta_energy = self.data.get('dos_delta_energy') energies, orbital_dos, total_dos = self.calculate_dos( eigenvalues, eigenvectors, bins, width, delta_energy ) plot_dos(energies, orbital_dos, total_dos) - self.logger.info("DOS calculation finished!") + self.logger.info('DOS calculation finished!') - self.logger.info("BZ diagonalization calculation finished!") + self.logger.info('BZ diagonalization calculation finished!') return eigenvalues, eigenvectors def run(self): @@ -204,7 +204,7 @@ def run(self): self.prune_hoppings() - if self.data.get("plot_bands"): + if self.data.get('plot_bands'): self.calculate_band_structure() self.bz_diagonalization() diff --git a/src/pyssmf/schema.py b/src/pyssmf/schema.py index 3101ee2..6174c2c 100644 --- a/src/pyssmf/schema.py +++ b/src/pyssmf/schema.py @@ -35,7 +35,7 @@ class System(MSection): labels = Quantity( type=str, - shape=["n_atoms"], + shape=['n_atoms'], description=""" List containing the labels of the atoms. In the usual case, these correspond to the chemical symbols of the atoms. One can also append an index if there is a @@ -50,8 +50,8 @@ class System(MSection): positions = Quantity( type=np.float64, - shape=["n_atoms", 3], - unit="angstrom", + shape=['n_atoms', 3], + unit='angstrom', description=""" Positions of all the species, in cartesian coordinates. This metadata defines a configuration and is therefore required. For alloys where concentrations of @@ -63,7 +63,7 @@ class System(MSection): lattice_vectors = Quantity( type=np.float64, shape=[3, 3], - unit="angstrom", + unit='angstrom', description=""" Lattice vectors of the simulation cell in cartesian coordinates. The last (fastest) index runs over the $x,y,z$ Cartesian coordinates, and the first @@ -74,7 +74,7 @@ class System(MSection): reciprocal_lattice_vectors = Quantity( type=np.float64, shape=[3, 3], - unit="1/angstrom", + unit='1/angstrom', description=""" Reciprocal lattice vectors of the simulation cell, in cartesian coordinates and with the 2 $pi$ pre-factor. The first index runs over the $x,y,z$ Cartesian coordinates, and the second index runs @@ -107,8 +107,8 @@ class BravaisLattice(MSection): points = Quantity( type=np.float64, - shape=["n_points", 3], - unit="angstrom", + shape=['n_points', 3], + unit='angstrom', description=""" Values of the Bravais lattice points used to obtain the hopping integrals. They are sorted from smaller to larger values of the norm. @@ -145,7 +145,7 @@ class Model(MSection): degeneracy_factors = Quantity( type=np.int32, - shape=["n_points"], + shape=['n_points'], description=""" Degeneracy of each Bravais lattice point. """, @@ -153,8 +153,8 @@ class Model(MSection): onsite_energies = Quantity( type=np.float64, - shape=["n_orbitals"], - unit="eV", + shape=['n_orbitals'], + unit='eV', description=""" Values of the onsite energies for each orbital. """, @@ -162,8 +162,8 @@ class Model(MSection): hopping_matrix = Quantity( type=np.complex128, - shape=["n_points", "n_orbitals", "n_orbitals"], - unit="eV", + shape=['n_points', 'n_orbitals', 'n_orbitals'], + unit='eV', description=""" Real space hopping matrix for each Bravais lattice point as a matrix of dimension (n_orbitals * n_orbitals). diff --git a/src/pyssmf/tb_hamiltonian.py b/src/pyssmf/tb_hamiltonian.py index fd51cfd..f9b8191 100644 --- a/src/pyssmf/tb_hamiltonian.py +++ b/src/pyssmf/tb_hamiltonian.py @@ -67,7 +67,7 @@ def set_ase_atoms(self) -> ase.Atoms: try: formula = Formula(atoms.get_chemical_formula()) - self.model.bravais_lattice.formula_hill = formula.format("hill") + self.model.bravais_lattice.formula_hill = formula.format('hill') except Exception: pass @@ -111,7 +111,7 @@ def k_mesh(self) -> np.ndarray: class TBHamiltonian(KSampling): - _valid_k_grid_types = ["bands", "full_bz"] + _valid_k_grid_types = ['bands', 'full_bz'] def __init__( self, @@ -137,9 +137,9 @@ def __init__( super().__init__(model, n_k_path, k_grid) self.k_grid_type = k_grid_type self.kpoints = np.empty((0, 3)) # initializing for mypy - if k_grid_type == "bands": + if k_grid_type == 'bands': self.kpoints = 2 * np.pi * self.k_path.cartesian_kpts() - elif k_grid_type == "full_bz": + elif k_grid_type == 'full_bz': self.kpoints = 2 * np.pi * self.k_mesh self.n_orbitals = self.model.n_orbitals self.n_k_points = len(self.kpoints) @@ -148,13 +148,13 @@ def __init__( def __repr__(self) -> str: cls_name = self.__class__.__name__ args = [ - f"n_orbitals={self.n_orbitals}", - f"n_k_points={self.n_k_points}", - f"n_r_points={self.n_r_points}", - f"k_grid_type={self.k_grid_type}", + f'n_orbitals={self.n_orbitals}', + f'n_k_points={self.n_k_points}', + f'n_r_points={self.n_r_points}', + f'k_grid_type={self.k_grid_type}', ] if self.spacegroup: - args.append(f"spacegroup.no={self.spacegroup.no}") + args.append(f'spacegroup.no={self.spacegroup.no}') return f"{cls_name}({', '.join(filter(None, args))})" def hamiltonian(self, kpoints: np.ndarray) -> np.ndarray: diff --git a/src/pyssmf/utils.py b/src/pyssmf/utils.py index 71f8aec..1287c35 100644 --- a/src/pyssmf/utils.py +++ b/src/pyssmf/utils.py @@ -21,7 +21,7 @@ import numpy as np -def get_files(pattern: str, filepath: str, stripname: str = "", deep: bool = True): +def get_files(pattern: str, filepath: str, stripname: str = '', deep: bool = True): """ Get files following the `pattern` with respect to the file `stripname` (usually this being the mainfile of the given parser) up to / down from the `filepath` (`deep=True` going @@ -37,8 +37,8 @@ def get_files(pattern: str, filepath: str, stripname: str = "", deep: bool = Tru list: List of found files. """ for _ in range(10): - filenames = glob(f"{os.path.dirname(filepath)}/{pattern}") - pattern = os.path.join("**" if deep else "..", pattern) + filenames = glob(f'{os.path.dirname(filepath)}/{pattern}') + pattern = os.path.join('**' if deep else '..', pattern) if filenames: break diff --git a/src/pyssmf/visualization.py b/src/pyssmf/visualization.py index 10344f0..7cd75ad 100644 --- a/src/pyssmf/visualization.py +++ b/src/pyssmf/visualization.py @@ -47,17 +47,17 @@ def update(val): for i, ax in enumerate(axes): ax.clear() if idx + i < total_matrices: - ax.imshow(matrices[idx + i], cmap="inferno", vmin=0, vmax=max_value) - ax.set_title(f"$N_R$ = {idx + i}") - ax.axis("off") + ax.imshow(matrices[idx + i], cmap='inferno', vmin=0, vmax=max_value) + ax.set_title(f'$N_R$ = {idx + i}') + ax.axis('off') fig.canvas.draw_idle() # Create the slider - ax_slider = plt.axes([0.2, 0.01, 0.65, 0.03], facecolor="lightgoldenrodyellow") + ax_slider = plt.axes([0.2, 0.01, 0.65, 0.03], facecolor='lightgoldenrodyellow') slider = Slider( ax_slider, - "Start $N_R$", + 'Start $N_R$', 0, total_matrices - display_count, valinit=0, @@ -72,7 +72,7 @@ def update(val): cbar_ax = fig.add_axes([0.93, 0.15, 0.02, 0.7]) fig.colorbar( plt.cm.ScalarMappable( - cmap="inferno", norm=plt.Normalize(vmin=0, vmax=max_value) + cmap='inferno', norm=plt.Normalize(vmin=0, vmax=max_value) ), cax=cbar_ax, ) @@ -102,12 +102,12 @@ def plot_band_structure(eigenvalues, tb_hamiltonian, special_points=None): plt.plot( np.arange(len(eigenvalues)), eigenvalues[:, band_idx], - label=f"Band {band_idx + 1}", + label=f'Band {band_idx + 1}', ) # Customize x-axis labeling based on special points if special_points is not None: - x_labels = [""] * (len(special_points)) # Initialize labels with empty strings + x_labels = [''] * (len(special_points)) # Initialize labels with empty strings x_ticks = [] i = 0 @@ -122,9 +122,9 @@ def plot_band_structure(eigenvalues, tb_hamiltonian, special_points=None): plt.xticks(x_ticks, x_labels) plt.xlim(0, len(eigenvalues) - 1) - plt.xlabel("k-points") - plt.ylabel("Energy (eV)") - plt.title("Band Structure") + plt.xlabel('k-points') + plt.ylabel('Energy (eV)') + plt.title('Band Structure') plt.legend() plt.grid(True) @@ -142,13 +142,13 @@ def plot_dos(energies, orbital_dos, total_dos): """ # Create a figure plt.figure(figsize=(8, 6)) - plt.plot(energies, total_dos, label="Total DOS", color="k", linewidth=3.5) + plt.plot(energies, total_dos, label='Total DOS', color='k', linewidth=3.5) # Plot orbital-resolved DOS for i, orb_dos in enumerate(orbital_dos): - plt.plot(energies, orb_dos, label=f"Orbital {i + 1}") + plt.plot(energies, orb_dos, label=f'Orbital {i + 1}') - plt.xlabel("Energy (eV)") - plt.ylabel("Density of States (DOS)") + plt.xlabel('Energy (eV)') + plt.ylabel('Density of States (DOS)') plt.legend() plt.show() diff --git a/tests/conftest.py b/tests/conftest.py index 8976c26..01fd1f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ def get_template_system(): system = System() system.n_atoms = 2 - system.labels = ["H", "O"] + system.labels = ['H', 'O'] system.positions = np.array([[0.0, 0.0, 0.0], [0.1, 0.1, 0.1]]) system.lattice_vectors = np.identity(3) return system diff --git a/tests/test_parsing.py b/tests/test_parsing.py index aa849db..9107725 100644 --- a/tests/test_parsing.py +++ b/tests/test_parsing.py @@ -31,7 +31,7 @@ def test_wannier90_parser(): model = Model() filepath = os.path.join( - os.path.dirname(__file__), "data/wannier90/wannier90_hr.dat" + os.path.dirname(__file__), 'data/wannier90/wannier90_hr.dat' ) MinimalWannier90Parser().parse(filepath, model, None) @@ -56,7 +56,7 @@ def test_wannier90_parser(): # System system = bravais_lattice.system - assert system.labels[:4] == ["Nb", "Nb", "Ta", "Ta"] + assert system.labels[:4] == ['Nb', 'Nb', 'Ta', 'Ta'] assert np.array_equal( system.positions[0].magnitude, np.array([0.0, 7.29193, 22.06006]) ) diff --git a/tests/test_schema.py b/tests/test_schema.py index 1d18a8c..0118675 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -20,7 +20,7 @@ def test_system(example_system): """Tests whether an instance of System is created and properly populated.""" assert example_system.n_atoms == 2 - assert example_system.labels == ["H", "O"] + assert example_system.labels == ['H', 'O'] assert np.array_equal( example_system.positions.magnitude, np.array([[0.0, 0.0, 0.0], [0.1, 0.1, 0.1]]) )