
Commit

Merge pull request #12 from samblau/BonDNet
BonDNet
samblau authored Nov 6, 2023
2 parents 7792880 + 64c7d40 commit 539d552
Showing 14 changed files with 1,721 additions and 121 deletions.
329 changes: 329 additions & 0 deletions HiPRGen/lmdb_dataset.py
@@ -0,0 +1,329 @@
# Given DGL graphs, reaction features, and metadata, write them into an LMDB file.
# TODO: 1. check that the LMDB expands reasonably

import glob
import multiprocessing as mp
import os
import pickle
from pathlib import Path

import lmdb
import numpy as np
from torch.utils.data import Dataset, random_split
from tqdm import tqdm


class LmdbDataset(Dataset):
"""
Dataset class to
1. write Reaction networks objecs to lmdb
2. load lmdb files
"""
def __init__(self, config, transform=None):
super(LmdbDataset, self).__init__()

self.config = config
self.path = Path(self.config["src"])

        # Get metadata in case it is needed later.
        # self.metadata_path = self.path.parent / "metadata.npz"
self.env = self.connect_db(self.path)

# If "length" encoded as ascii is present, use that
# If there are additional properties, there must be length.
length_entry = self.env.begin().get("length".encode("ascii"))
if length_entry is not None:
num_entries = pickle.loads(length_entry)
else:
# Get the number of stores data from the number of entries
# in the LMDB
num_entries = self.env.stat()["entries"]

self._keys = list(range(num_entries))
self.num_samples = num_entries

        # Optionally restrict this process to a portion of the total dataset.
        self.sharded = False
        if "shard" in self.config and "total_shards" in self.config:
            self.sharded = True
            self.indices = range(self.num_samples)
            # Split all available indices into `total_shards` bins.
            self.shards = np.array_split(
                self.indices, self.config.get("total_shards", 1)
            )
            # Limit each process to the subset of data in its configured shard.
            self.available_indices = self.shards[self.config.get("shard", 0)]
            self.num_samples = len(self.available_indices)

        # Optional transform applied on load (see __getitem__).
        self.transform = transform

def __len__(self):
return self.num_samples

def __getitem__(self, idx):
# if sharding, remap idx to appropriate idx of the sharded set
if self.sharded:
idx = self.available_indices[idx]

        # NOTE: len(self._keys) is smaller than the total number of LMDB keys,
        # since property entries are stored alongside the indexed samples.
        datapoint_pickled = self.env.begin().get(
            f"{self._keys[idx]}".encode("ascii")
        )

data_object = pickle.loads(datapoint_pickled)

        # Apply the optional transform.
if self.transform is not None:
data_object = self.transform(data_object)

return data_object

def connect_db(self, lmdb_path=None):
env = lmdb.open(
str(lmdb_path),
subdir=False,
readonly=True,
lock=False,
readahead=True,
meminit=False,
max_readers=1,
)
return env

    def close_db(self):
        # Only a single environment is opened in connect_db.
        self.env.close()

    def get_metadata(self, num_samples=100):
        # Not implemented.
        pass

@property
def dtype(self):
dtype = self.env.begin().get("dtype".encode("ascii"))
return pickle.loads(dtype)

@property
def feature_size(self):
feature_size = self.env.begin().get("feature_size".encode("ascii"))
return pickle.loads(feature_size)

@property
def feature_name(self):
feature_name = self.env.begin().get("feature_name".encode("ascii"))
return pickle.loads(feature_name)

@property
def mean(self):
mean = self.env.begin().get("mean".encode("ascii"))
return pickle.loads(mean)

@property
def std(self):
std = self.env.begin().get("std".encode("ascii"))
return pickle.loads(std)
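
# A minimal read-back sketch (hypothetical usage, not part of the module):
# the config needs only a "src" path to an existing LMDB file, plus optional
# "shard"/"total_shards" keys to restrict this process to a slice of the
# data. The file path below is an assumption for illustration.
#
#     dataset = LmdbDataset({"src": "data/reactions.lmdb"})
#     print(len(dataset), dataset.feature_size, dataset.dtype)
#     sample = dataset[0]
#
#     # Read only the second of four shards:
#     shard = LmdbDataset(
#         {"src": "data/reactions.lmdb", "shard": 1, "total_shards": 4}
#     )
#     dataset.close_db()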


def divide_to_list(a, b):
    """Split the integer a into b near-equal integer parts that sum to a."""
    quotient = a // b
    remainder = a % b

result = []
for i in range(b):
increment = 1 if i < remainder else 0
result.append(quotient + increment)

return result
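
# For example, divide_to_list(10, 3) returns [4, 3, 3]: the remainder is
# handed out one item at a time to the leading bins, so the parts always
# sum to the original total. This is how a sample list can be split evenly
# across worker processes.
#
#     assert divide_to_list(10, 3) == [4, 3, 3]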

def cleanup_lmdb_files(directory, pattern):
"""
Cleans up files matching the given pattern in the specified directory.
"""
file_list = glob.glob(os.path.join(directory, pattern))

for file_path in file_list:
try:
os.remove(file_path)
print(f"Deleted file: {file_path}")
except OSError as e:
print(f"Error deleting file: {file_path}. {str(e)}")

def CRNs2lmdb(
    dtype,
    feature_size,
    feature_name,
    mean,
    std,
    lmdb_dir,
    lmdb_name,
):

    os.makedirs(lmdb_dir, exist_ok=True)

    # Temporary per-worker shard files; %04d is filled in with the worker id.
    db_paths = os.path.join(lmdb_dir, "_tmp_data.%04d.lmdb")

    meta_keys = {
        "dtype": dtype,
        "feature_size": feature_size,
        "feature_name": feature_name,
        "mean": mean,
        "std": std,
    }
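
    # A worker-dispatch step is expected here before merging: split the samples
    # across workers and let each write its own temporary shard. A minimal
    # sketch, assuming hypothetical `samples` and `num_workers` values that are
    # not part of this function's signature:
    #
    #     counts = divide_to_list(len(samples), num_workers)
    #     mp_args, start = [], 0
    #     for pid, count in enumerate(counts):
    #         mp_args.append(
    #             (db_paths % pid, samples[start:start + count], pid, meta_keys)
    #         )
    #         start += count
    #     with mp.Pool(num_workers) as pool:
    #         pool.map(write_crns_to_lmdb, mp_args)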

    # Merge the per-worker LMDB shards into a single file, then clean up.
    shard_paths = sorted(glob.glob(os.path.join(lmdb_dir, "_tmp_data.*.lmdb")))
    merge_lmdbs(shard_paths, lmdb_dir, lmdb_name)
    cleanup_lmdb_files(lmdb_dir, "_tmp_data*")



def write_crns_to_lmdb(mp_args):
    # Each worker receives its own shard path, its slice of the samples,
    # its worker index (pid), and the shared metadata entries.
    db_path, samples, pid, meta_keys = mp_args

db = lmdb.open(
db_path,
map_size=1099511627776 * 2,
subdir=False,
meminit=False,
map_async=True,
)

pbar = tqdm(
total=len(samples),
position=pid,
desc=f"Worker {pid}: Writing CRNs Objects into LMDBs",
)

    # Write the indexed samples, one transaction per sample.
    idx = 0
    for sample in samples:
        txn = db.begin(write=True)
        txn.put(
            f"{idx}".encode("ascii"),
            pickle.dumps(sample, protocol=-1),
        )
        idx += 1
        pbar.update(1)
        txn.commit()

    # Write the property entries.
    txn = db.begin(write=True)
    txn.put("length".encode("ascii"), pickle.dumps(len(samples), protocol=-1))
    txn.commit()

    for key, value in meta_keys.items():
        txn = db.begin(write=True)
        txn.put(key.encode("ascii"), pickle.dumps(value, protocol=-1))
        txn.commit()

db.sync()
db.close()


def merge_lmdbs(db_paths, out_path, output_file):
    """
    Merge LMDB files, reindexing the samples sequentially.
    """
env_out = lmdb.open(
os.path.join(out_path, output_file),
map_size=1099511627776 * 2,
subdir=False,
meminit=False,
map_async=True,
)


idx = 0
for db_path in db_paths:
env_in = lmdb.open(
str(db_path),
subdir=False,
readonly=True,
lock=False,
readahead=True,
meminit=False,
)

        # Reindex the sample entries; property entries (non-integer keys)
        # are copied through unchanged.
with env_out.begin(write=True) as txn_out, env_in.begin(write=False) as txn_in:
            cursor = txn_in.cursor()
            for key, value in cursor:
                try:
                    # Integer keys are samples: re-key them with a running index.
                    int(key.decode("ascii"))
                    txn_out.put(
                        f"{idx}".encode("ascii"),
                        value,
                    )
                    idx += 1
                except ValueError:
                    # Non-integer keys are property entries: copy them as-is.
                    txn_out.put(key, value)
env_in.close()

    # Update the total length after merging.
    txn_out = env_out.begin(write=True)
    txn_out.put("length".encode("ascii"), pickle.dumps(idx, protocol=-1))
    txn_out.commit()

env_out.sync()
env_out.close()
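
# A minimal merge sketch (hypothetical paths, for illustration): combine two
# worker shards into a single "merged.lmdb" in the same directory.
#
#     merge_lmdbs(
#         ["out/_tmp_data.0000.lmdb", "out/_tmp_data.0001.lmdb"],
#         "out",
#         "merged.lmdb",
#     )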



def write_to_lmdb(new_samples, current_length, lmdb_update, db_path):
    """Append new samples to an existing LMDB file and update its properties."""
    db = lmdb.open(
db_path,
map_size=1099511627776 * 2,
subdir=False,
meminit=False,
map_async=True,
)

    pbar = tqdm(
        total=len(new_samples),
        desc="Adding new samples into LMDB",
    )

    # Write the new samples, continuing from the current length.
    idx = current_length
    for sample in new_samples:
        txn = db.begin(write=True)
        txn.put(
            f"{idx}".encode("ascii"),
            pickle.dumps(sample, protocol=-1),
        )
        idx += 1
        pbar.update(1)
        txn.commit()

    # Update the stored length.
    total_length = current_length + len(new_samples)

    txn = db.begin(write=True)
    txn.put("length".encode("ascii"), pickle.dumps(total_length, protocol=-1))
    txn.commit()

    # Update mean, std, feature_size, feature_name, dtype, etc.
    for key, value in lmdb_update.items():
        txn = db.begin(write=True)
        txn.put(key.encode("ascii"), pickle.dumps(value, protocol=-1))
        txn.commit()

db.sync()
db.close()
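
# A minimal append sketch (hypothetical names, for illustration): read the
# current length via LmdbDataset, then append and refresh the statistics.
#
#     ds = LmdbDataset({"src": "out/merged.lmdb"})
#     n = len(ds)
#     ds.close_db()
#     write_to_lmdb(new_samples, n, {"mean": new_mean, "std": new_std},
#                   "out/merged.lmdb")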