diff --git a/src/codemodder/codetf.py b/src/codemodder/codetf.py index c67acbd5..331cb06a 100644 --- a/src/codemodder/codetf.py +++ b/src/codemodder/codetf.py @@ -21,18 +21,27 @@ from codemodder.context import CodemodExecutionContext -class Action(Enum): +class CaseInsensitiveEnum(str, Enum): + @classmethod + def _missing_(cls, value: object): + if not isinstance(value, str): + return super()._missing_(value) + + return cls.__members__.get(value.upper()) + + +class Action(CaseInsensitiveEnum): ADD = "add" REMOVE = "remove" -class PackageResult(Enum): +class PackageResult(CaseInsensitiveEnum): COMPLETED = "completed" FAILED = "failed" SKIPPED = "skipped" -class DiffSide(Enum): +class DiffSide(CaseInsensitiveEnum): LEFT = "left" RIGHT = "right" diff --git a/tests/test_codetf.py b/tests/test_codetf.py index 0bf7ed22..5c9bab57 100644 --- a/tests/test_codetf.py +++ b/tests/test_codetf.py @@ -4,6 +4,7 @@ import jsonschema import pytest import requests +from pydantic import ValidationError from codemodder.codetf import Change, ChangeSet, CodeTF, DiffSide, Reference, Result @@ -130,3 +131,39 @@ def test_write_codetf_with_results(tmpdir, mocker, codetf_schema): def test_reference_use_url_for_description(): ref = Reference(url="https://example.com") assert ref.description == "https://example.com" + + +def test_case_insensitive_change_validation(): + json = { + "lineNumber": 1, + "description": "Change 1 to 2", + "diffSide": "RIGHT", + "packageActions": [ + { + "action": "ADD", + "package": "foo", + "result": "COMPLETED", + } + ], + } + + Change.model_validate(json) + + +@pytest.mark.parametrize("bad_value", ["MIDDLE", "middle"]) +def test_still_invalidates_bad_value(bad_value): + json = { + "lineNumber": 1, + "description": "Change 1 to 2", + "diffSide": bad_value, + "packageActions": [ + { + "action": "ADD", + "package": "foo", + "result": "COMPLETED", + } + ], + } + + with pytest.raises(ValidationError): + Change.model_validate(json)