From 3635f11304d410378d9d79b9f2396673e86a70fd Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Wed, 21 Aug 2024 12:06:09 +0200 Subject: [PATCH 1/3] Fixed optional typing on non-serializable types --- src/zenml/steps/entrypoint_function_utils.py | 50 +++++-- tests/unit/steps/test_base_step_new.py | 135 ++++++++++++++++++- 2 files changed, 170 insertions(+), 15 deletions(-) diff --git a/src/zenml/steps/entrypoint_function_utils.py b/src/zenml/steps/entrypoint_function_utils.py index 2ed6c53f5fc..68b99b17023 100644 --- a/src/zenml/steps/entrypoint_function_utils.py +++ b/src/zenml/steps/entrypoint_function_utils.py @@ -25,9 +25,11 @@ Sequence, Type, Union, + get_args, + get_origin, ) -from pydantic import ConfigDict, ValidationError, create_model +from pydantic import ValidationError, create_model from zenml.constants import ENFORCE_TYPE_ANNOTATIONS from zenml.exceptions import StepInterfaceError @@ -185,7 +187,6 @@ def validate_input(self, key: str, value: Any) -> None: ) parameter = self.inputs[key] - if isinstance( value, ( @@ -235,17 +236,40 @@ def _validate_input_value( parameter: The function parameter for which the value was provided. value: The input value. """ - config_dict = ConfigDict(arbitrary_types_allowed=False) - - # Create a pydantic model with just a single required field with the - # type annotation of the parameter to verify the input type including - # pydantics type coercion - validation_model_class = create_model( - "input_validation_model", - __config__=config_dict, - value=(parameter.annotation, ...), - ) - validation_model_class(value=value) + annotation = parameter.annotation + + # Handle Optional types + origin = get_origin(annotation) + if origin is Union: + args = get_args(annotation) + if type(None) in args: + if value is None: + return # None is valid for Optional types + # Remove NoneType from args as this case is handled from here + args = tuple(arg for arg in args if arg is not type(None)) + annotation = args[0] if len(args) == 1 else Union[args] + + # Handle None values for non-Optional types + if value is None and annotation is not type(None): + raise ValueError(f"Expected {annotation}, but got None") + + # Use Pydantic for all types to take advantage of its coercion abilities + try: + config_dict = {"arbitrary_types_allowed": True} + validation_model_class = create_model( + "input_validation_model", + __config__=type("Config", (), config_dict), + value=(annotation, ...), + ) + validation_model_class(value=value) + except ValidationError as e: + raise ValueError(f"Invalid input: {e}") + except Exception: + # If Pydantic can't handle it, fall back to isinstance + if not isinstance(value, annotation): + raise TypeError( + f"Expected {annotation}, but got {type(value)}" + ) def validate_entrypoint_function( diff --git a/tests/unit/steps/test_base_step_new.py b/tests/unit/steps/test_base_step_new.py index 96c9c3abab9..bb91cb918fd 100644 --- a/tests/unit/steps/test_base_step_new.py +++ b/tests/unit/steps/test_base_step_new.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. import sys from contextlib import ExitStack as does_not_raise -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple, Union import pytest from pydantic import BaseModel @@ -50,7 +50,7 @@ def test_input_validation_inside_pipeline(): def test_pipeline(step_input): return step_with_int_input(step_input) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): test_pipeline(step_input="wrong_type") with does_not_raise(): @@ -187,3 +187,134 @@ def test_pipeline(): with pytest.raises(StepInterfaceError): test_pipeline() + + +# ------------------------ Optional Input types +@step +def some_step(some_optional_int: Optional[int]) -> None: + pass + + +def test_step_can_have_optional_input_types(): + """Tests that a step allows None values for optional input types""" + + @pipeline + def p(): + some_step(some_optional_int=None) + + with does_not_raise(): + p() + + +def test_step_fails_on_none_inputs_for_non_optional_input_types(): + """Tests that a step does not allow None values for non-optional input types""" + + @step + def some_step(some_optional_int: int) -> None: + pass + + @pipeline + def p(): + some_step(some_optional_int=None) + + with pytest.raises(ValueError): + p().run(unlisted=True) + + +# --------- Test type coercion + + +@step +def coerce_step(some_int: int, some_float: float) -> None: + pass + + +def test_step_with_type_coercion(): + """Tests that a step can coerce types when possible""" + + @pipeline + def p(): + coerce_step(some_int="42", some_float="3.14") + + with does_not_raise(): + p() + + +def test_step_fails_on_invalid_type_coercion(): + """Tests that a step fails when type coercion is not possible""" + + @step + def coerce_step(some_int: int) -> None: + pass + + @pipeline + def p(): + coerce_step(some_int="not an int") + + with pytest.raises(ValueError): + p().run(unlisted=True) + + +# ------------- Non-Json-Serializable types + + +class NonSerializable: + def __init__(self, value): + self.value = value + + +@step +def non_serializable_step(some_obj: NonSerializable) -> None: + pass + + +def test_step_with_non_serializable_type_as_parameter_fails(): + """Tests that a step can handle non-JSON serializable types, but fails if these are passed as parameters""" + + @pipeline + def p(): + non_serializable_step(some_obj=NonSerializable(42)) + + with pytest.raises(StepInterfaceError): + p().run(unlisted=True) + + +def test_step_fails_on_wrong_non_serializable_type(): + """Tests that a step fails when given the wrong non-serializable type""" + + @step + def non_serializable_step(some_obj: NonSerializable) -> None: + pass + + @pipeline + def p(): + non_serializable_step(some_obj=None) + + with pytest.raises(ValueError): + p().run(unlisted=True) + + +# --------- Test union types + + +@step +def union_step(some_union: Union[int, str]) -> None: + pass + + +def test_step_with_union_type(): + """Tests that a step can handle Union types""" + + @pipeline + def p(): + union_step(some_union=42) + + with does_not_raise(): + p() + + @pipeline + def p(): + union_step(some_union="fourtytwo") + + with does_not_raise(): + p() From ba67f401ff48de89761a1beab92aa203650bdd50 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Tue, 27 Aug 2024 14:12:50 +0200 Subject: [PATCH 2/3] Applied reviews and simplified the fix --- src/zenml/steps/entrypoint_function_utils.py | 24 +++----------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/src/zenml/steps/entrypoint_function_utils.py b/src/zenml/steps/entrypoint_function_utils.py index 68b99b17023..c7b67f7d0c2 100644 --- a/src/zenml/steps/entrypoint_function_utils.py +++ b/src/zenml/steps/entrypoint_function_utils.py @@ -25,11 +25,9 @@ Sequence, Type, Union, - get_args, - get_origin, ) -from pydantic import ValidationError, create_model +from pydantic import ValidationError, create_model, ConfigDict from zenml.constants import ENFORCE_TYPE_ANNOTATIONS from zenml.exceptions import StepInterfaceError @@ -187,6 +185,7 @@ def validate_input(self, key: str, value: Any) -> None: ) parameter = self.inputs[key] + if isinstance( value, ( @@ -238,32 +237,15 @@ def _validate_input_value( """ annotation = parameter.annotation - # Handle Optional types - origin = get_origin(annotation) - if origin is Union: - args = get_args(annotation) - if type(None) in args: - if value is None: - return # None is valid for Optional types - # Remove NoneType from args as this case is handled from here - args = tuple(arg for arg in args if arg is not type(None)) - annotation = args[0] if len(args) == 1 else Union[args] - - # Handle None values for non-Optional types - if value is None and annotation is not type(None): - raise ValueError(f"Expected {annotation}, but got None") - # Use Pydantic for all types to take advantage of its coercion abilities try: - config_dict = {"arbitrary_types_allowed": True} + config_dict = ConfigDict(arbitrary_types_allowed=True) validation_model_class = create_model( "input_validation_model", __config__=type("Config", (), config_dict), value=(annotation, ...), ) validation_model_class(value=value) - except ValidationError as e: - raise ValueError(f"Invalid input: {e}") except Exception: # If Pydantic can't handle it, fall back to isinstance if not isinstance(value, annotation): From 04e627a3f1656178ae3344da807489fb63852114 Mon Sep 17 00:00:00 2001 From: AlexejPenner Date: Tue, 27 Aug 2024 14:22:38 +0200 Subject: [PATCH 3/3] Adjusted test cases --- tests/unit/steps/test_base_step_new.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/steps/test_base_step_new.py b/tests/unit/steps/test_base_step_new.py index bb91cb918fd..f7aedff08ba 100644 --- a/tests/unit/steps/test_base_step_new.py +++ b/tests/unit/steps/test_base_step_new.py @@ -189,6 +189,7 @@ def test_pipeline(): test_pipeline() + # ------------------------ Optional Input types @step def some_step(some_optional_int: Optional[int]) -> None: @@ -217,7 +218,7 @@ def some_step(some_optional_int: int) -> None: def p(): some_step(some_optional_int=None) - with pytest.raises(ValueError): + with pytest.raises(TypeError): p().run(unlisted=True) @@ -290,7 +291,7 @@ def non_serializable_step(some_obj: NonSerializable) -> None: def p(): non_serializable_step(some_obj=None) - with pytest.raises(ValueError): + with pytest.raises(TypeError): p().run(unlisted=True)