Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bputzeys committed May 21, 2024
1 parent 6adb98f commit 7fe3952
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions ci/tests/test_hyena_dna/test_hyena_dna_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from helical.models.hyena_dna.model import HyenaDNA,HyenaDNAConfig
from helical.models.hyena_dna.model import HyenaDNAConfig
import pytest

@pytest.mark.parametrize("model_name, d_model, d_inner", [
Expand All @@ -15,10 +15,9 @@ def test_hyena_dna__valid_model_names(model_name, d_model, d_inner):
d_inner (int): The dimensionality of the inner layers.
"""
configurer = HyenaDNAConfig(model_name=model_name)
model = HyenaDNA(configurer=configurer)
assert model.config["model_name"] == model_name
assert model.config["d_model"] == d_model
assert model.config["d_inner"] == d_inner
assert configurer.config["model_path"].name == f"{model_name}.ckpt"
assert configurer.config["d_model"] == d_model
assert configurer.config["d_inner"] == d_inner

@pytest.mark.parametrize("model_name", [
("wrong_name")
Expand Down

0 comments on commit 7fe3952

Please sign in to comment.