diff --git a/ci/tests/test_hyena_dna/test_hyena_dna_model.py b/ci/tests/test_hyena_dna/test_hyena_dna_model.py index a4228262..93b6956c 100644 --- a/ci/tests/test_hyena_dna/test_hyena_dna_model.py +++ b/ci/tests/test_hyena_dna/test_hyena_dna_model.py @@ -1,4 +1,4 @@ -from helical.models.hyena_dna.model import HyenaDNA,HyenaDNAConfig +from helical.models.hyena_dna.model import HyenaDNAConfig import pytest @pytest.mark.parametrize("model_name, d_model, d_inner", [ @@ -15,10 +15,9 @@ def test_hyena_dna__valid_model_names(model_name, d_model, d_inner): d_inner (int): The dimensionality of the inner layers. """ configurer = HyenaDNAConfig(model_name=model_name) - model = HyenaDNA(configurer=configurer) - assert model.config["model_name"] == model_name - assert model.config["d_model"] == d_model - assert model.config["d_inner"] == d_inner + assert configurer.config["model_path"].name == f"{model_name}.ckpt" + assert configurer.config["d_model"] == d_model + assert configurer.config["d_inner"] == d_inner @pytest.mark.parametrize("model_name", [ ("wrong_name")