Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lewisblake committed Nov 18, 2024
1 parent 14bd371 commit d68a9f8
Showing 1 changed file with 57 additions and 2 deletions.
59 changes: 57 additions & 2 deletions tests/aeroval/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from pyaerocom.aeroval.collections import ObsCollection, ModelCollection
from pyaerocom.aeroval.obsentry import ObsEntry
from pyaerocom.aeroval.modelentry import ModelEntry
import pytest


def test_obscollection():
def test_obscollection_init_and_add_entry():
oc = ObsCollection()
oc.add_entry("model1", dict(obs_id="bla", obs_vars="od550aer", obs_vert_type="Column"))
assert oc
Expand All @@ -19,7 +22,42 @@ def test_obscollection():
assert "AN-EEA-MP" in oc.keylist()


def test_modelcollection():
def test_obscollection_add_and_get_entry():
collection = ObsCollection()
entry = ObsEntry(obs_id="obs1", obs_vars=("var1",))
collection.add_entry("key1", entry)
retrieved_entry = collection.get_entry("key1")
assert retrieved_entry == entry


def test_obscollection_add_and_remove_entry():
collection = ObsCollection()
entry = ObsEntry(obs_id="obs1", obs_vars=("var1",))
collection.add_entry("key1", entry)
collection.remove_entry("key1")
with pytest.raises(KeyError):
collection.get_entry("key1")


def test_obscollection_get_web_interface_name():
collection = ObsCollection()
entry = ObsEntry(obs_id="obs1", obs_vars=("var1",), web_interface_name="web_name")
collection.add_entry("key1", entry)
assert collection.get_web_interface_name("key1") == "web_name"


def test_obscollection_all_vert_types():
collection = ObsCollection()
entry1 = ObsEntry(
obs_id="obs1", obs_vars=("var1",), obs_vert_type="Surface"
) # Assuming ObsEntry has an obs_vert_type attribute
entry2 = ObsEntry(obs_id="obs2", obs_vars=("var2",), obs_vert_type="Profile")
collection.add_entry("key1", entry1)
collection.add_entry("key2", entry2)
assert set(collection.all_vert_types) == {"Surface", "Profile"}


def test_modelcollection_init_and_add_entry():
mc = ModelCollection()
mc.add_entry("model1", dict(model_id="bla", obs_vars="od550aer", obs_vert_type="Column"))
assert mc
Expand All @@ -34,3 +72,20 @@ def test_modelcollection():
)

assert "ECMWF_OSUITE" in mc.keylist()


def test_modelcollection_add_and_get_entry():
collection = ModelCollection()
entry = ModelEntry(model_id="mod1")
collection.add_entry("key1", entry)
retrieved_entry = collection.get_entry("key1")
assert retrieved_entry == entry


def test_modelcollection_add_and_remove_entry():
collection = ModelCollection()
entry = ModelEntry(model_id="obs1")
collection.add_entry("key1", entry)
collection.remove_entry("key1")
with pytest.raises(KeyError):
collection.get_entry("key1")

0 comments on commit d68a9f8

Please sign in to comment.