diff --git a/megnet/data/graph.py b/megnet/data/graph.py index 8fe2744b3..f086b2344 100644 --- a/megnet/data/graph.py +++ b/megnet/data/graph.py @@ -580,6 +580,13 @@ def _generate_inputs(self, batch_index: list) -> tuple: - [ndarray]: List of indices for the start of each bond - [ndarray]: List of indices for the end of each bond """ + import collections + if isinstance(self.atom_features[0][0], collections.defaultdict): + from megnet.data.crystal import CrystalGraphDisordered + cgd = CrystalGraphDisordered() + for i in range(len(self.atom_features)): + self.atom_features[i] = cgd.atom_converter.convert( + self.atom_features[i].tolist()).tolist() # Get the features and connectivity lists for this batch feature_list_temp = itemgetter_list(self.atom_features, batch_index)