Skip to content

Commit

Permalink
Ruff formatted python modules
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed May 28, 2024
1 parent c216a76 commit 81af496
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 146 deletions.
4 changes: 2 additions & 2 deletions src/pyssmf/hopping_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def prune_by_threshold(
# problems for hoppings structures which are not so trivial.
if not self.max_value and not self.hopping_matrix_norms:
self.logger.warning(
"Could not extract the hopping_matrix norms and their max_value."
'Could not extract the hopping_matrix norms and their max_value.'
)
return
threshold = threshold_factor * self.max_value
Expand Down Expand Up @@ -88,7 +88,7 @@ def prune_by_threshold(
]
except Exception:
self.logger.warning(
"Could not update the model parameters after pruning."
'Could not update the model parameters after pruning.'
)

self.update_norms_and_max_value()
102 changes: 51 additions & 51 deletions src/pyssmf/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ class ValidLatticeModels(ABC):
def __init__(self, logger: logging.Logger = logging.getLogger(__name__)):
self.logger = logger
self._valid_lattice_models = [
"linear",
"square",
"honeycomb",
"triangular",
'linear',
'square',
'honeycomb',
'triangular',
] # TODO extend this


Expand Down Expand Up @@ -64,34 +64,34 @@ def __init__(self, **kwargs):
"""
super().__init__()
# Initializing data
data = {"code": "SSMF"}
data = {'code': 'SSMF'}

# Check working_directory and stores it in data
if not kwargs.get("working_directory"):
if not kwargs.get('working_directory'):
self.logger.error(
"Could not find specified the working_directory in the input."
'Could not find specified the working_directory in the input.'
)
return
data["working_directory"] = kwargs.get("working_directory")
data['working_directory'] = kwargs.get('working_directory')

# Read from input file if provided in the argument
read_input_file = kwargs.get("read_from_input_file", False)
read_input_file = kwargs.get('read_from_input_file', False)
if read_input_file:
data["input_file"] = kwargs.get("input_file")
data['input_file'] = kwargs.get('input_file')
input_file = os.path.join(
kwargs.get("working_directory"), kwargs.get("input_file")
kwargs.get('working_directory'), kwargs.get('input_file')
)
data["input_data"] = self.read_from_file(input_file)
data['input_data'] = self.read_from_file(input_file)
else:
# We check whether a model_file has been defined or a model_label
if kwargs.get("tb_model_file"):
data["tb_model"] = kwargs.get("tb_model_file")
if kwargs.get('tb_model_file'):
data['tb_model'] = kwargs.get('tb_model_file')
# pruning only applies for tb_model_file cases
if kwargs.get("pruning", False):
data["prune_threshold"] = kwargs.get("prune_threshold", 0.01)
elif kwargs.get("lattice_model") in self._valid_lattice_models:
data["tb_model"] = kwargs.get("lattice_model")
hoppings = kwargs.get("hoppings", [])
if kwargs.get('pruning', False):
data['prune_threshold'] = kwargs.get('prune_threshold', 0.01)
elif kwargs.get('lattice_model') in self._valid_lattice_models:
data['tb_model'] = kwargs.get('lattice_model')
hoppings = kwargs.get('hoppings', [])
# We check if 'hoppings' was empty
if len(hoppings) == 0:
n_hoppings = 1
Expand All @@ -101,8 +101,8 @@ def __init__(self, **kwargs):
n_hoppings = len(hoppings)
if n_hoppings > 3:
self.logger.error(
"Maximum n_hoppings models supported is 3. Please, select "
"a smaller number."
'Maximum n_hoppings models supported is 3. Please, select '
'a smaller number.'
)
return
n_orbitals = len(hoppings[0])
Expand All @@ -112,48 +112,48 @@ def __init__(self, **kwargs):
for hop_point in hoppings
):
self.logger.error(
"Dimensions of each hopping matrix do not coincide with"
"(n_orbitals, n_orbitals).",
data={"n_orbitals": n_orbitals},
'Dimensions of each hopping matrix do not coincide with'
'(n_orbitals, n_orbitals).',
data={'n_orbitals': n_orbitals},
)
# TODO improve this
onsite_energies = kwargs.get("onsite_energies", [])
onsite_energies = kwargs.get('onsite_energies', [])
if len(onsite_energies) == 0:
onsite_energies = [0.0] * n_orbitals
data["hoppings"] = hoppings
data["onsite_energies"] = onsite_energies
data["n_hoppings"] = n_hoppings
data["n_orbitals"] = n_orbitals
data['hoppings'] = hoppings
data['onsite_energies'] = onsite_energies
data['n_hoppings'] = n_hoppings
data['n_orbitals'] = n_orbitals
else:
self.logger.error(
"Could not find the initial model. Please, check your inputs: "
"1) define `model_file` pointing to your Wannier90 `*_hr.dat` "
"hoppings file, or 2) specify the `lattice_model` to study among the "
"accepted values.",
data={"lattice_model": self._valid_lattice_models},
'Could not find the initial model. Please, check your inputs: '
'1) define `model_file` pointing to your Wannier90 `*_hr.dat` '
'hoppings file, or 2) specify the `lattice_model` to study among the '
'accepted values.',
data={'lattice_model': self._valid_lattice_models},
)

# KGrids
# For band structure calculations
data["n_k_path"] = kwargs.get("n_k_path", 90)
data['n_k_path'] = kwargs.get('n_k_path', 90)
# For full_bz diagonalization
data["k_grid"] = kwargs.get("k_grid", [1, 1, 1])
data['k_grid'] = kwargs.get('k_grid', [1, 1, 1])

# Plotting arguments
data["plot_hoppings"] = kwargs.get("plot_hoppings", False)
data["plot_bands"] = kwargs.get("plot_bands", False)
data['plot_hoppings'] = kwargs.get('plot_hoppings', False)
data['plot_bands'] = kwargs.get('plot_bands', False)
# DOS calculation and plotting
data["dos"] = kwargs.get("dos", False)
data["dos_gaussian_width"] = kwargs.get("dos_gaussian_width", 0.1)
data["dos_delta_energy"] = kwargs.get("dos_delta_energy", 0.01)
data['dos'] = kwargs.get('dos', False)
data['dos_gaussian_width'] = kwargs.get('dos_gaussian_width', 0.1)
data['dos_delta_energy'] = kwargs.get('dos_delta_energy', 0.01)
self.data = data
self.to_json()

def to_json(self):
"""
Stores the input data in a JSON file in the working_directory.
"""
with open(f"{self.data.get('working_directory')}/input_ssmf.json", "w") as file:
with open(f"{self.data.get('working_directory')}/input_ssmf.json", 'w') as file:
json.dump(self.data, file, indent=4)

def read_from_file(self, input_file: str) -> dict:
Expand All @@ -167,22 +167,22 @@ def read_from_file(self, input_file: str) -> dict:
(dict): dictionary with the input data read from the JSON input file.
"""
try:
with open(input_file, "r") as file:
with open(input_file, 'r') as file:
input_data = json.load(file)
except FileNotFoundError:
self.logger.error(
"Input file not found.",
extra={"input_file": input_file},
'Input file not found.',
extra={'input_file': input_file},
)
except json.JSONDecodeError:
self.logger.error(
"Failed to decode JSON in input file.",
extra={"input_file": input_file},
'Failed to decode JSON in input file.',
extra={'input_file': input_file},
)
code_name = input_data.get("code", "")
if code_name != "SSMF":
code_name = input_data.get('code', '')
if code_name != 'SSMF':
self.logger.error(
"Could not recognize the input JSON file as readable by the SSMF code.",
extra={"input_file": input_file},
'Could not recognize the input JSON file as readable by the SSMF code.',
extra={'input_file': input_file},
)
return input_data
52 changes: 26 additions & 26 deletions src/pyssmf/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .schema import System, BravaisLattice, Model
from .utils import get_files

re_n = r"[\n\r]"
re_n = r'[\n\r]'


class WOutParser(TextParser):
Expand All @@ -40,19 +40,19 @@ def __init__(self):

def init_quantities(self):
structure_quantities = [
Quantity("labels", r"\|\s*([A-Z][a-z]*)", repeats=True),
Quantity('labels', r'\|\s*([A-Z][a-z]*)', repeats=True),
Quantity(
"positions",
r"\|\s*([\-\d\.]+)\s*([\-\d\.]+)\s*([\-\d\.]+)",
'positions',
r'\|\s*([\-\d\.]+)\s*([\-\d\.]+)\s*([\-\d\.]+)',
repeats=True,
),
]

self._quantities = [
Quantity("lattice_vectors", r"\s*a_\d\s*([\d\-\s\.]+)", repeats=True),
Quantity('lattice_vectors', r'\s*a_\d\s*([\d\-\s\.]+)', repeats=True),
Quantity(
"structure",
rf"(\s*Fractional Coordinate[\s\S]+?)(?:{re_n}\s*(PROJECTIONS|K-POINT GRID))",
'structure',
rf'(\s*Fractional Coordinate[\s\S]+?)(?:{re_n}\s*(PROJECTIONS|K-POINT GRID))',
repeats=False,
sub_parser=TextParser(quantities=structure_quantities),
),
Expand All @@ -70,8 +70,8 @@ def __init__(self):

def init_quantities(self):
self._quantities = [
Quantity("degeneracy_factors", r"\s*written on[\s\w]*:\d*:\d*\s*([\d\s]+)"),
Quantity("hoppings", r"\s*([-\d\s.]+)", repeats=False),
Quantity('degeneracy_factors', r'\s*written on[\s\w]*:\d*:\d*\s*([\d\s]+)'),
Quantity('hoppings', r'\s*([-\d\s.]+)', repeats=False),
]


Expand All @@ -95,28 +95,28 @@ def parse_system(self):
"""
sec_system = self.model.m_create(BravaisLattice).m_create(System)

structure = self.wout_parser.get("structure")
structure = self.wout_parser.get('structure')
if structure is None:
self.logger.error("Error parsing the structure from .wout")
self.logger.error('Error parsing the structure from .wout')
return
if self.wout_parser.get("lattice_vectors", []):
if self.wout_parser.get('lattice_vectors', []):
lattice_vectors = np.vstack(
self.wout_parser.get("lattice_vectors", [])[-3:]
self.wout_parser.get('lattice_vectors', [])[-3:]
)
sec_system.lattice_vectors = lattice_vectors * ureg.angstrom
sec_system.periodic = [True, True, True]
sec_system.labels = structure.get("labels")
if structure.get("positions") is not None:
sec_system.positions = structure.get("positions") * ureg.angstrom
sec_system.labels = structure.get('labels')
if structure.get('positions') is not None:
sec_system.positions = structure.get('positions') * ureg.angstrom

def parse_hoppings(self):
"""
Parses the hoppings metadata and stores them under `Model`.
"""
bravais_lattice = self.model.bravais_lattice

deg_factors = self.hr_parser.get("degeneracy_factors", [])
full_hoppings = self.hr_parser.get("hoppings", [])
deg_factors = self.hr_parser.get('degeneracy_factors', [])
full_hoppings = self.hr_parser.get('hoppings', [])
if deg_factors is not None and full_hoppings is not None:
n_orbitals = deg_factors[0]
n_points = deg_factors[1]
Expand Down Expand Up @@ -158,9 +158,9 @@ def parse(self, filepath: str, model: Model, logger: logging.Logger = None):
logger (logging.Logger, optional): Logger object for debug messages. Defaults to None.
"""
basename = os.path.basename(filepath) # Getting filepath for *_hr.dat file
wout_files = get_files("*.wout", filepath, basename)
wout_files = get_files('*.wout', filepath, basename)
if len(wout_files) > 1:
logger.warning("Multiple `*.wout` files found; we will parse the last one.")
logger.warning('Multiple `*.wout` files found; we will parse the last one.')
mainfile = wout_files[-1] # Path to *.wout file

self.filepath = filepath
Expand Down Expand Up @@ -195,7 +195,7 @@ def __init__(self):

def _bravais_vectors(self, n_neighbors, lattice_vectors):
bravais_vectors = []
if self.lattice_model == "linear":
if self.lattice_model == 'linear':
for i in range(-n_neighbors, n_neighbors + 1):
j = 0
k = 0
Expand All @@ -215,11 +215,11 @@ def linear(self):
system = bravais_lattice.m_create(System)
lattice_vectors = [[1, 0, 0], [0, 0, 0], [0, 0, 0]]
_system_map = {
"lattice_vectors": lattice_vectors,
"periodic": [True, False, False],
"n_atoms": 1,
"labels": ["X"],
"positions": [[0, 0, 0]],
'lattice_vectors': lattice_vectors,
'periodic': [True, False, False],
'n_atoms': 1,
'labels': ['X'],
'positions': [[0, 0, 0]],
}
for key in system.m_def.all_quantities.keys():
system.m_set(system.m_get_quantity_definition(key), _system_map.get(key))
Expand Down
Loading

0 comments on commit 81af496

Please sign in to comment.