Skip to content

Commit

Permalink
cleaned codes
Browse files Browse the repository at this point in the history
  • Loading branch information
BowenD-UCB committed Mar 1, 2023
1 parent 68373c7 commit e1f49ea
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
Empty file removed __init__.py
Empty file.
36 changes: 26 additions & 10 deletions examples/make_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
datatype = torch.float32
random.seed(100)

# This runnable script show

# This runnable script shows an example to convert a Structure json dataset to graphs
# and save them. So the you don't have to do graph conversion in each training

def main():
data_path = ""
graph_dir = ""
Expand All @@ -18,14 +21,22 @@ def main():
make_graphs(data, graph_dir)


def make_graphs(data: StructureJsonData, graph_dir, train_ratio=0.8, val_ratio=0.1):
def make_graphs(
data: StructureJsonData,
graph_dir: str,
train_ratio: float = 0.8,
val_ratio: float = 0.1
):
"""
Make cifs form the MPtrj dataset
:param train_ratio:
:param val_ratio:
:param test_ratio:
:return:
Make graphs from a StructureJsonData dataset
Args:
data (StructureJsonData): a StructureJsonData
graph_dir (str): a directory to save the graphs
train_ratio (float): train ratio
val_ratio (float): val ratio
"""
utils.mkdir(graph_dir)
random.shuffle(data.keys)
labels = {}
failed_graphs = []
Expand All @@ -42,13 +53,16 @@ def make_graphs(data: StructureJsonData, graph_dir, train_ratio=0.8, val_ratio=0
failed_graphs += [(mp_id, graph_id)]
if i % 1000 == 0:
print(i)
# torch.save(graphs, 'test_MPtrj_graphs.pt')

utils.write_json(labels, os.path.join(graph_dir, "labels.json"))
utils.write_json(failed_graphs, os.path.join(graph_dir, "failed_graphs.json"))
make_partition(labels, graph_dir, train_ratio, val_ratio)


def make_one_graph(mp_id, graph_id, data, graph_dir):
"""
convert a structure to a Crystal_Graph and save it
"""
dic = data.data[mp_id].pop(graph_id)
struc = Structure.from_dict(dic.pop("structure"))
try:
Expand All @@ -60,9 +74,11 @@ def make_one_graph(mp_id, graph_id, data, graph_dir):


def make_partition(
data, graph_dir, train_ratio=0.8, val_ratio=0.1, partition_with_frame=False
data, graph_dir, train_ratio=0.8, val_ratio=0.1, partition_with_frame=False
):

"""
Make a train val test partition
"""
random.seed(42)
if partition_with_frame is False:
material_ids = list(data.keys())
Expand Down

0 comments on commit e1f49ea

Please sign in to comment.