Skip to content

Commit

Permalink
Working on updating unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisiacovella committed Sep 15, 2023
1 parent 914d339 commit 7b27580
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
12 changes: 7 additions & 5 deletions modelforge/dataset/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ class QM9Dataset(HDF5Dataset):
_property_names = PropertyNames(
"atomic_numbers",
"geometry",
"return_energy",
"internal_energy_at_0K",
)

_available_properties = [
"geometry",
"atomic_numbers",
"return_energy",
"internal_energy_at_0K",
] # NOTE: Any way to set this automatically?

def __init__(
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
_default_properties_of_interest = [
"geometry",
"atomic_numbers",
"return_energy",
"internal_energy_at_0K",
] # NOTE: Default values

self._properties_of_interest = _default_properties_of_interest
Expand All @@ -82,8 +82,10 @@ def __init__(
)
self.dataset_name = dataset_name
self.for_unit_testing = for_unit_testing
self.test_id = "17oZ07UOxv2fkEmu-d5mLk6aGIuhV0mJ7"
self.full_id = "1_bSdQjEvI67Tk_LKYbW0j8nmggnb5MoU"
# self.test_id = "17oZ07UOxv2fkEmu-d5mLk6aGIuhV0mJ7"
# self.full_id = "1_bSdQjEvI67Tk_LKYbW0j8nmggnb5MoU"
self.test_id = "18C9Iq_7VZLx0gZbJYje8X6tybZb5m3JY"
self.full_id = "1damjPgjKviTogDJ2UJvhYjyBZxGvRPP-"

@property
def properties_of_interest(self) -> List[str]:
Expand Down
10 changes: 5 additions & 5 deletions modelforge/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,25 +64,25 @@ def test_different_properties_of_interest(dataset):
assert data.properties_of_interest == [
"geometry",
"atomic_numbers",
"return_energy",
"internal_energy_at_0K",
]

dataset = factory.create_dataset(data)
raw_data_item = dataset[0]
assert isinstance(raw_data_item, dict)
assert len(raw_data_item) == 4

data.properties_of_interest = ["return_energy", "geometry"]
data.properties_of_interest = ["internal_energy_at_0K", "geometry"]
assert data.properties_of_interest == [
"return_energy",
"internal_energy_at_0K",
"geometry",
]

dataset = factory.create_dataset(data)
raw_data_item = dataset[0]
print(raw_data_item)
assert isinstance(raw_data_item, dict)
assert len(raw_data_item) != 3
assert len(raw_data_item) != 3


@pytest.mark.parametrize("dataset", DATASETS)
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_dataset_generation(dataset):
pass

# the dataloader automatically splits and batches the dataset
# for the trianing set it batches the 80 datapoints in
# for the training set it batches the 80 datapoints in
# a batch of 64 and a batch of 16 samples
assert len(train_dataloader) == 2 # nr of batches
v = [v_ for v_ in train_dataloader]
Expand Down

0 comments on commit 7b27580

Please sign in to comment.