diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index 6780ff91e..62fbfaceb 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -8,9 +8,9 @@ from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ - FastMRIWorkload as JaxWorkload + FastMRILayerNormWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIWorkload as PytWorkload + FastMRILayerNormWorkload as PytWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index 6780ff91e..47bad372a 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -8,9 +8,9 @@ from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ - FastMRIWorkload as JaxWorkload + FastMRITanhWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIWorkload as PytWorkload + FastMRITanhWorkload as PytWorkload from tests.modeldiffs.diff import out_diff