From 34ec602a16e85614c09153206440082cc3565948 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 29 Nov 2023 16:51:05 +0100 Subject: [PATCH] remove duplicated code --- ...st_pointer_network_for_joint_taskmodule.py | 35 +++++-------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/tests/taskmodules/test_pointer_network_for_joint_taskmodule.py b/tests/taskmodules/test_pointer_network_for_joint_taskmodule.py index 41b0fc7ea..76901afbc 100644 --- a/tests/taskmodules/test_pointer_network_for_joint_taskmodule.py +++ b/tests/taskmodules/test_pointer_network_for_joint_taskmodule.py @@ -15,29 +15,26 @@ # FIXTURES_DIR = FIXTURES_ROOT / "taskmodules" / "gmam_taskmodule" +DUMP_FIXTURE_DATA = False + def _config_to_str(cfg: Dict[str, str]) -> str: result = "-".join([f"{k}={cfg[k]}" for k in sorted(cfg)]) return result -CONFIGS = [ - {"span_end_mode": "first_token_of_last_word"}, - {"span_end_mode": "last_token"}, -] -CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} - -DUMP_FIXTURE_DATA = False +CONFIGS = [{}, {"partition_layer_name": "sentences"}] +CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} -@pytest.fixture(scope="module", params=CONFIGS_DICT.keys()) -def config(request): - return CONFIGS_DICT[request.param] +@pytest.fixture(scope="module", params=CONFIG_DICT.keys()) +def config_str(request): + return request.param @pytest.fixture(scope="module") -def config_str(config): - return _config_to_str(config) +def config(config_str): + return CONFIG_DICT[config_str] @pytest.fixture(scope="module") @@ -87,20 +84,6 @@ def test_document(document): assert str(sentences[1]) == "Trust me." -CONFIGS = [{}, {"partition_layer_name": "sentences"}] -CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} - - -@pytest.fixture(scope="module", params=CONFIG_DICT.keys()) -def config_str(request): - return request.param - - -@pytest.fixture(scope="module") -def config(config_str): - return CONFIG_DICT[config_str] - - @pytest.fixture(scope="module") def taskmodule(document, config): taskmodule = PointerNetworkForJointTaskModule(