Skip to content

Commit

Permalink
Merge pull request #320 from laserkelvin/self-loop-mask-fix
Browse files Browse the repository at this point in the history
Self loop mask fix
  • Loading branch information
laserkelvin authored Nov 22, 2024
2 parents e446eea + 5458b12 commit 895c4ab
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
7 changes: 5 additions & 2 deletions matsciml/datasets/transforms/pbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,12 @@ def __call__(self, data: DataDict) -> DataDict:
raise RuntimeError(f"Requested backend f{self.backend} not available.")
data.update(graph_props)
if not self.allow_self_loops:
mask = data["src_nodes"] == data["dst_nodes"]
# this looks for src and dst nodes that are the same, i.e. self-loops
loop_mask = data["src_nodes"] == data["dst_nodes"]
# only mask out self-loops within the same image
mask &= data["unit_offsets"].sum(dim=-1) == 0
image_mask = data["images"].sum(dim=-1) == 0
# we negate the mask because we want to *exclude* what we've found
mask = ~torch.logical_and(loop_mask, image_mask)
# apply mask to each of the tensors that depend on edges
for key in ["src_nodes", "dst_nodes", "images", "unit_offsets", "offsets"]:
data[key] = data[key][mask]
Expand Down
33 changes: 29 additions & 4 deletions matsciml/datasets/transforms/tests/test_pbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@
)
@pytest.mark.parametrize("self_loops", [True, False])
@pytest.mark.parametrize("backend", ["pymatgen", "ase"])
@pytest.mark.parametrize(
"cutoff_radius", [6.0, 9.0, 15.0]
) # TODO figure out why pmg fails on 3
@pytest.mark.parametrize("cutoff_radius", [6.0, 9.0, 15.0])
def test_periodic_generation(
coords: np.ndarray,
cell: np.ndarray,
Expand All @@ -84,4 +82,31 @@ def test_periodic_generation(
counts = Counter(src_nodes)
for index, count in counts.items():
if not self_loops:
assert count < 10, print(f"Node {index} has too many counts. {src_nodes}")
# TODO pymatgen backend fails this check at cutoff radius = 15
# and I don't know why
assert count <= 10, f"Node {index} has too many counts. {src_nodes}"


def test_self_loop_condition():
"""Tests for whether the self-loops exclusion is behaving as intended"""
coords = torch.FloatTensor(alumina.cart_coords)
cell = torch.FloatTensor(alumina.lattice.matrix)
num_atoms = coords.size(0)
atomic_numbers = torch.ones(num_atoms)
packed_data = {"pos": coords, "cell": cell, "atomic_numbers": atomic_numbers}
no_loop_transform = PeriodicPropertiesTransform(
cutoff_radius=6.0, backend="ase", allow_self_loops=False
)
no_loop_result = no_loop_transform(packed_data)
# since it's no self loops this sum should be zero
same_node = no_loop_result["src_nodes"] == no_loop_result["dst_nodes"]
same_image = no_loop_result["images"].sum(dim=-1) == 0
assert torch.sum(torch.logical_and(same_node, same_image)) == 0
allow_loop_transform = PeriodicPropertiesTransform(
cutoff_radius=6.0, backend="ase", allow_self_loops=True
)
loop_result = allow_loop_transform(packed_data)
# there should be some self-loops in this graph
same_node = loop_result["src_nodes"] == loop_result["dst_nodes"]
same_image = loop_result["images"].sum(dim=-1) == 0
assert torch.sum(torch.logical_and(same_node, same_image)) > 0

0 comments on commit 895c4ab

Please sign in to comment.