diff --git a/ci/tests/test_hyena_dna/test_model.py b/ci/tests/test_hyena_dna/test_hyena_dna_model.py similarity index 100% rename from ci/tests/test_hyena_dna/test_model.py rename to ci/tests/test_hyena_dna/test_hyena_dna_model.py diff --git a/ci/tests/test_uce/test_uce_model.py b/ci/tests/test_uce/test_uce_model.py new file mode 100644 index 00000000..52339b89 --- /dev/null +++ b/ci/tests/test_uce/test_uce_model.py @@ -0,0 +1,36 @@ +from helical.models.uce.model import UCE, UCEConfig +import pytest + +@pytest.mark.parametrize("model_name, n_layers", [ + ("33l_8ep_1024t_1280", 33), + ("4layer_model", 4) +]) +def test_uce__valid_model_names(model_name, n_layers): + """ + Test case for the UCE class initialization. + + Args: + model_name (str): The name of the model. + n_layers (int): The number of layers of the model. + """ + configurer = UCEConfig(model_name=model_name) + model = UCE(configurer=configurer) + assert model.config["model_name"] == model_name + assert model.config["n_layers"] == n_layers + +@pytest.mark.parametrize("model_name", [ + ("wrong_name") +]) +def test_uce__invalid_model_names(model_name): + """ + Test case when an invalid model name is provided. + Verifies that a ValueError is raised when an invalid model name is passed to the UCEConfig constructor. + + Parameters: + - model_name (str): The invalid model name. + + Raises: + - ValueError: If the model name is invalid. + """ + with pytest.raises(ValueError): + UCEConfig(model_name=model_name)