Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Jun 6, 2024
1 parent d64dea2 commit 6d45741
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -257,27 +257,29 @@ def equivalent_site_index_and_transform(self, psite):
Equivalent site in the unit cell, translations and symmetry transformation.
"""
# Get the index of the site in the unit cell of which the PeriodicSite psite is a replica.
isite = 0
site_idx = 0
try:
isite = self.structure_environments.structure.index(psite)
site_idx = self.structure_environments.structure.index(psite)
except ValueError:
try:
uc_psite = psite.to_unit_cell()
isite = self.structure_environments.structure.index(uc_psite)
site_idx = self.structure_environments.structure.index(uc_psite)
except ValueError:
for isite2, site2 in enumerate(self.structure_environments.structure):
if psite.is_periodic_image(site2):
isite = isite2
site_idx = isite2
break
# Get the translation between psite and its corresponding site in the unit cell (Translation I)
this_site = self.structure_environments.structure[isite]
dthis_site = psite.frac_coords - this_site.frac_coords
this_site = self.structure_environments.structure[site_idx]
dist_this_site = psite.frac_coords - this_site.frac_coords
# Get the translation between the equivalent site for which the neighbors have been computed and the site in
# the unit cell that corresponds to psite (Translation II)
equiv_site = self.structure_environments.structure[self.structure_environments.sites_map[isite]].to_unit_cell()
equiv_site = self.structure_environments.structure[
self.structure_environments.sites_map[site_idx]
].to_unit_cell()
# equivsite = self.structure_environments.structure[self.structure_environments.sites_map[isite]]
dequivsite = (
self.structure_environments.structure[self.structure_environments.sites_map[isite]].frac_coords
dist_equiv_site = (
self.structure_environments.structure[self.structure_environments.sites_map[site_idx]].frac_coords
- equiv_site.frac_coords
)
found = False
Expand Down Expand Up @@ -317,7 +319,9 @@ def equivalent_site_index_and_transform(self, psite):
break
if not found:
raise EquivalentSiteSearchError(psite)
return self.structure_environments.sites_map[isite], dequivsite, dthis_site + d_this_site2, sym_trafo

equivalent_site_map = self.structure_environments.sites_map[site_idx]
return equivalent_site_map, dist_equiv_site, dist_this_site + d_this_site2, sym_trafo

@abc.abstractmethod
def get_site_neighbors(self, site):
Expand Down Expand Up @@ -408,19 +412,19 @@ def get_site_ce_fractions_and_neighbors(self, site, full_ce_info=False, strategy
The list of neighbors of the site. For complex strategies, where one allows multiple solutions, this
can return a list of list of neighbors.
"""
isite, dequivsite, dthissite, mysym = self.equivalent_site_index_and_transform(site)
site_idx, dist_equiv_site, dist_this_site, mysym = self.equivalent_site_index_and_transform(site)
geoms_and_maps_list = self.get_site_coordination_environments_fractions(
site=site,
isite=isite,
dequivsite=dequivsite,
dthissite=dthissite,
isite=site_idx,
dequivsite=dist_equiv_site,
dthissite=dist_this_site,
mysym=mysym,
return_maps=True,
return_strategy_dict_info=True,
)
if geoms_and_maps_list is None:
return None
site_nbs_sets = self.structure_environments.neighbors_sets[isite]
site_nbs_sets = self.structure_environments.neighbors_sets[site_idx]
ce_and_neighbors = []
for fractions_dict in geoms_and_maps_list:
ce_map = fractions_dict["ce_map"]
Expand Down
14 changes: 7 additions & 7 deletions pymatgen/ext/matproj_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,16 +1276,16 @@ def get_cohesive_energy(self, material_id, per_atom=False):
Cohesive energy (eV).
"""
entry = self.get_entry_by_material_id(material_id)
ebulk = entry.energy / entry.composition.get_integer_formula_and_factor()[1]
e_bulk = entry.energy / entry.composition.get_integer_formula_and_factor()[1]
comp_dict = entry.composition.reduced_composition.as_dict()

isolated_atom_e_sum, n = 0, 0
isolated_atom_e_sum = 0
for el in comp_dict:
e = self._make_request(f"/element/{el}/tasks/isolated_atom", mp_decode=False)[0]
isolated_atom_e_sum += e["output"]["final_energy_per_atom"] * comp_dict[el]
n += comp_dict[el]
ecoh_per_formula = isolated_atom_e_sum - ebulk
return ecoh_per_formula / n if per_atom else ecoh_per_formula
ent = self._make_request(f"/element/{el}/tasks/isolated_atom", mp_decode=False)[0]
isolated_atom_e_sum += ent["output"]["final_energy_per_atom"] * comp_dict[el]
e_coh_per_formula = isolated_atom_e_sum - e_bulk
n_atoms = entry.composition.num_atoms
return e_coh_per_formula / n_atoms if per_atom else e_coh_per_formula

def get_reaction(self, reactants, products):
"""Get a reaction from the Materials Project.
Expand Down
6 changes: 1 addition & 5 deletions pymatgen/io/abinit/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,7 @@ def __str__(self):
float_decimal = 0

if isinstance(value, np.ndarray):
n = 1
for i in np.shape(value):
n *= i
value = np.reshape(value, n)
value = list(value)
value = list(value.flatten())

# values in lists
if isinstance(value, (list, tuple)):
Expand Down
27 changes: 13 additions & 14 deletions pymatgen/io/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ def __init__(
"""
self.loops = loops
self.data = data
# AJ (@computron) says: CIF Block names cannot be
# more than 75 characters or you get an Exception
# AJ (@computron) says: CIF Block names can't be more than 75 characters or you get an Exception
self.header = header[:74]

def __eq__(self, other: object) -> bool:
Expand Down Expand Up @@ -199,30 +198,30 @@ def from_str(cls, string: str) -> Self:
loops: list[list[str]] = []

while deq:
_string = deq.popleft()
# cif keys aren't in quotes, so show up as _string[0]
if _string[0] == "_eof":
_str = deq.popleft()
# cif keys aren't in quotes, so show up as _str[0]
if _str[0] == "_eof":
break

if _string[0].startswith("_"):
if _str[0].startswith("_"):
try:
data[_string[0]] = "".join(deq.popleft())
data[_str[0]] = "".join(deq.popleft())
except IndexError:
data[_string[0]] = ""
data[_str[0]] = ""

elif _string[0].startswith("loop_"):
elif _str[0].startswith("loop_"):
columns: list[str] = []
items: list[str] = []
while deq:
_string = deq[0]
if _string[0].startswith("loop_") or not _string[0].startswith("_"):
_str = deq[0]
if _str[0].startswith("loop_") or not _str[0].startswith("_"):
break
columns.append("".join(deq.popleft()))
data[columns[-1]] = []

while deq:
_string = deq[0]
if _string[0].startswith(("loop_", "_")):
_str = deq[0]
if _str[0].startswith(("loop_", "_")):
break
items.append("".join(deq.popleft()))

Expand All @@ -232,7 +231,7 @@ def from_str(cls, string: str) -> Self:
for k, v in zip(columns * n, items):
data[k].append(v.strip())

elif issue := "".join(_string).strip():
elif issue := "".join(_str).strip():
warnings.warn(f"Possible issue in CIF file at line: {issue}")

return cls(data, loops, header)
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/symmetry/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ def get_lattice_type(self) -> LatticeType:
Returns:
str: Lattice type for structure
"""
n = self._space_group_data["number"]
spg_num = self._space_group_data["number"]
system = self.get_crystal_system()
if n in [146, 148, 155, 160, 161, 166, 167]:
if spg_num in (146, 148, 155, 160, 161, 166, 167):
return "rhombohedral"
if system == "trigonal":
return "hexagonal"
Expand Down
6 changes: 3 additions & 3 deletions pymatgen/symmetry/maggroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _parse_operators(b):
return None
raw_symops = [b[i : i + 6] for i in range(0, len(b), 6)]

symops = []
symm_ops = []

for r in raw_symops:
point_operator = _get_point_operator(r[0])
Expand All @@ -173,9 +173,9 @@ def _parse_operators(b):
)
if time_reversal == -1:
seitz += "'"
symops.append({"op": op, "str": seitz})
symm_ops.append({"op": op, "str": seitz})

return symops
return symm_ops

def _parse_wyckoff(b):
"""Parse compact binary representation into list of Wyckoff sites."""
Expand Down
7 changes: 3 additions & 4 deletions pymatgen/vis/structure_vtk.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,13 +673,12 @@ def add_bonds(self, neighbors, center, color=None, opacity=None, radius=0.1):
"""
points = vtk.vtkPoints()
points.InsertPoint(0, center.x, center.y, center.z)
n = len(neighbors)
lines = vtk.vtkCellArray()
for i in range(n):
points.InsertPoint(i + 1, neighbors[i].coords)
for idx, neighbor in enumerate(neighbors):
points.InsertPoint(idx + 1, neighbor.coords)
lines.InsertNextCell(2)
lines.InsertCellPoint(0)
lines.InsertCellPoint(i + 1)
lines.InsertCellPoint(idx + 1)
pd = vtk.vtkPolyData()
pd.SetPoints(points)
pd.SetLines(lines)
Expand Down
4 changes: 2 additions & 2 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def update_changelog(ctx: Context, version: str | None = None, dry_run: bool = F
json_resp = response.json()
if body := json_resp["body"]:
for ll in map(str.strip, body.split("\n")):
if ll in ["", "## Summary"]:
if ll in ("", "## Summary"):
continue
if ll.startswith(("## Checklist", "## TODO")):
break
Expand Down Expand Up @@ -248,5 +248,5 @@ def lint(ctx: Context) -> None:
Args:
ctx (invoke.Context): The context object.
"""
for cmd in ["ruff", "mypy", "ruff format"]:
for cmd in ("ruff", "mypy", "ruff format"):
ctx.run(f"{cmd} pymatgen")
12 changes: 6 additions & 6 deletions tests/symmetry/test_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,12 @@ def setUp(self):
self.sg1 = SpacegroupAnalyzer(self.structure, 0.001).get_space_group_operations()

def test_are_symmetrically_equivalent(self):
sites1 = [self.structure[idx] for idx in [0, 1]]
sites2 = [self.structure[idx] for idx in [2, 3]]
sites1 = [self.structure[idx] for idx in (0, 1)]
sites2 = [self.structure[idx] for idx in (2, 3)]
assert self.sg1.are_symmetrically_equivalent(sites1, sites2, 1e-3)

sites1 = [self.structure[idx] for idx in [0, 1]]
sites2 = [self.structure[idx] for idx in [0, 2]]
sites1 = [self.structure[idx] for idx in (0, 1)]
sites2 = [self.structure[idx] for idx in (0, 2)]
assert not self.sg1.are_symmetrically_equivalent(sites1, sites2, 1e-3)


Expand Down Expand Up @@ -593,7 +593,7 @@ def test_symmetrize_molecule2(self):
assert pa3.get_pointgroup().sch_symbol == "Ci"

def test_get_kpoint_weights(self):
for name in ["SrTiO3", "LiFePO4", "Graphite"]:
for name in ("SrTiO3", "LiFePO4", "Graphite"):
struct = PymatgenTest.get_structure(name)
spga = SpacegroupAnalyzer(struct)
ir_mesh = spga.get_ir_reciprocal_mesh((4, 4, 4))
Expand All @@ -602,7 +602,7 @@ def test_get_kpoint_weights(self):
for expected, weight in zip(weights, spga.get_kpoint_weights([i[0] for i in ir_mesh])):
assert weight == approx(expected)

for name in ["SrTiO3", "LiFePO4", "Graphite"]:
for name in ("SrTiO3", "LiFePO4", "Graphite"):
struct = PymatgenTest.get_structure(name)
spga = SpacegroupAnalyzer(struct)
ir_mesh = spga.get_ir_reciprocal_mesh((1, 2, 3))
Expand Down
2 changes: 1 addition & 1 deletion tests/transformations/test_standard_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_from_scaling_factors(self):
def test_from_boundary_distance(self):
struct_cubic = Structure.from_spacegroup("Pm-3m", 4 * np.eye(3), ["H"], [[0, 0, 0]])

for struct in [struct_cubic, self.struct]:
for struct in (struct_cubic, self.struct):
for min_dist in range(6, 19, 4):
trafo = SupercellTransformation.from_boundary_distance(
structure=struct, min_boundary_dist=min_dist, allow_rotation=False
Expand Down

0 comments on commit 6d45741

Please sign in to comment.