diff --git a/sdk/ml/azure-ai-ml/CHANGELOG.md b/sdk/ml/azure-ai-ml/CHANGELOG.md
index fdb2dc673393..c4db1221e125 100644
--- a/sdk/ml/azure-ai-ml/CHANGELOG.md
+++ b/sdk/ml/azure-ai-ml/CHANGELOG.md
@@ -2,6 +2,7 @@
 
 ## 1.23.0 (unreleased)
 
 ### Features Added
+- Add support for additional include in spark component.
 
 ### Bugs Fixed
diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/spark_component.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/spark_component.py
index d5c48016a898..87ec85fc0be7 100644
--- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/spark_component.py
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/spark_component.py
@@ -7,7 +7,7 @@
 from copy import deepcopy
 
 import yaml
-from marshmallow import INCLUDE, fields, post_load
+from marshmallow import INCLUDE, fields, post_dump, post_load
 
 from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
 from azure.ai.ml._schema.component.component import ComponentSchema
@@ -20,6 +20,16 @@ class SparkComponentSchema(ComponentSchema, ParameterizedSparkSchema):
     type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
+    additional_includes = fields.List(fields.Str())
+
+    @post_dump
+    def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
+        if (
+            component_schema_dict.get("additional_includes") is not None
+            and len(component_schema_dict["additional_includes"]) == 0
+        ):
+            component_schema_dict.pop("additional_includes")
+        return component_schema_dict
 
 
 class RestSparkComponentSchema(SparkComponentSchema):
diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/spark_component.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/spark_component.py
index dac831cfb0c7..7da65fb69732 100644
--- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/spark_component.py
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/spark_component.py
@@ -17,12 +17,12 @@
 from .._job.spark_job_entry_mixin import SparkJobEntry, SparkJobEntryMixin
 from .._util import convert_ordered_dict_to_dict, validate_attribute_type
 from .._validation import MutableValidationResult
-from .code import ComponentCodeMixin
+from ._additional_includes import AdditionalIncludesMixin
 from .component import Component
 
 
 class SparkComponent(
-    Component, ParameterizedSpark, SparkJobEntryMixin, ComponentCodeMixin
+    Component, ParameterizedSpark, SparkJobEntryMixin, AdditionalIncludesMixin
 ):  # pylint: disable=too-many-instance-attributes
     """Spark component version, used to define a Spark Component or Job.
 
@@ -79,6 +79,8 @@ class SparkComponent(
     :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]]
     :keyword args: The arguments for the job. Defaults to None.
     :paramtype args: Optional[str]
+    :keyword additional_includes: A list of shared additional files to be included in the component. Defaults to None.
+    :paramtype additional_includes: Optional[List[str]]
 
     .. admonition:: Example:
 
@@ -112,6 +114,7 @@ def __init__(
         inputs: Optional[Dict] = None,
         outputs: Optional[Dict] = None,
         args: Optional[str] = None,
+        additional_includes: Optional[List] = None,
         **kwargs: Any,
     ) -> None:
         # validate init params are valid type
@@ -134,6 +137,7 @@ def __init__(
         self.conf = conf
         self.environment = environment
         self.args = args
+        self.additional_includes = additional_includes or []
         # For pipeline spark job, we also allow user to set driver_cores, driver_memory and so on by setting conf.
         # If root level fields are not set by user, we promote conf setting to root level to facilitate subsequent
         # verification. This usually happens when we use to_component(SparkJob) or builder function spark() as a node
diff --git a/sdk/ml/azure-ai-ml/tests/component/unittests/test_spark_component_entity.py b/sdk/ml/azure-ai-ml/tests/component/unittests/test_spark_component_entity.py
index b7ea4845fad8..679c6741e3c3 100644
--- a/sdk/ml/azure-ai-ml/tests/component/unittests/test_spark_component_entity.py
+++ b/sdk/ml/azure-ai-ml/tests/component/unittests/test_spark_component_entity.py
@@ -29,6 +29,18 @@ def test_component_load(self):
         }
         assert spark_component.args == "--file_input ${{inputs.file_input}}"
 
+    def test_component_load_with_additional_include(self):
+        # additional_includes is specified in yaml, value is respected
+        component_yaml = "./tests/test_configs/components/hello_spark_component_with_additional_include.yml"
+        spark_component = load_component(
+            component_yaml,
+        )
+
+        assert (
+            isinstance(spark_component.additional_includes, list)
+            and spark_component.additional_includes[0] == "common_src"
+        )
+
     def test_spark_component_to_dict(self):
         # Test optional params exists in component dict
         yaml_path = "./tests/test_configs/dsl_pipeline/spark_job_in_pipeline/add_greeting_column_component.yml"
@@ -37,6 +49,14 @@ def test_spark_component_to_dict(self):
         spark_component = SparkComponent._load(data=yaml_dict, yaml_path=yaml_path)
         assert spark_component._other_parameter.get("mock_option_param") == yaml_dict["mock_option_param"]
 
+    def test_spark_component_to_dict_additional_include(self):
+        # Test additional_includes set in the dict is respected on the loaded component
+        yaml_path = "./tests/test_configs/dsl_pipeline/spark_job_in_pipeline/add_greeting_column_component.yml"
+        yaml_dict = load_yaml(yaml_path)
+        yaml_dict["additional_includes"] = ["common_src"]
+        spark_component = SparkComponent._load(data=yaml_dict, yaml_path=yaml_path)
+        assert spark_component.additional_includes[0] == yaml_dict["additional_includes"][0]
+
     def test_spark_component_entity(self):
         component = SparkComponent(
             name="add_greeting_column_spark_component",
@@ -73,6 +93,39 @@ def test_spark_component_entity(self):
 
         assert component_dict == yaml_component_dict
 
+    def test_spark_component_entity_additional_include(self):
+        component = SparkComponent(
+            name="wordcount_spark_component",
+            display_name="Spark word count",
+            description="Spark word count",
+            version="3",
+            inputs={
+                "file_input": {"type": "uri_file", "mode": "direct"},
+            },
+            driver_cores=1,
+            driver_memory="2g",
+            executor_cores=2,
+            executor_memory="2g",
+            executor_instances=4,
+            entry={"file": "wordcount.py"},
+            args="--input1 ${{inputs.file_input}}",
+            base_path="./tests/test_configs/components",
+            additional_includes=["common_src"],
+        )
+        omit_fields = [
+            "properties.component_spec.$schema",
+            "properties.component_spec._source",
+            "properties.properties.client_component_hash",
+        ]
+        component_dict = component._to_rest_object().as_dict()
+        component_dict = pydash.omit(component_dict, *omit_fields)
+
+        yaml_path = "./tests/test_configs/components/hello_spark_component_with_additional_include.yml"
+        yaml_component = load_component(yaml_path)
+        yaml_component_dict = yaml_component._to_rest_object().as_dict()
+        yaml_component_dict = pydash.omit(yaml_component_dict, *omit_fields)
+        assert component_dict == yaml_component_dict
+
     def test_spark_component_version_as_a_function_with_inputs(self):
         expected_rest_component = {
             "type": "spark",
diff --git a/sdk/ml/azure-ai-ml/tests/test_configs/components/hello_spark_component_with_additional_include.yml b/sdk/ml/azure-ai-ml/tests/test_configs/components/hello_spark_component_with_additional_include.yml
new file mode 100644
index 000000000000..52cd3f9ab484
--- /dev/null
+++ b/sdk/ml/azure-ai-ml/tests/test_configs/components/hello_spark_component_with_additional_include.yml
@@ -0,0 +1,27 @@
+$schema: https://azuremlschemas.azureedge.net/latest/sparkComponent.schema.json
+name: wordcount_spark_component
+type: spark
+version: 3
+display_name: Spark word count
+description: Spark word count
+
+
+inputs:
+  file_input:
+    type: uri_file
+    mode: direct
+
+entry:
+  file: wordcount.py
+
+args: >-
+  --input1 ${{inputs.file_input}}
+
+conf:
+  spark.driver.cores: 1
+  spark.driver.memory: "2g"
+  spark.executor.cores: 2
+  spark.executor.memory: "2g"
+  spark.executor.instances: 4
+additional_includes:
+  - common_src
\ No newline at end of file
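
For reference, a minimal usage sketch of the new `additional_includes` option, mirroring the tests and the YAML fixture added above. This is illustrative only and not part of the patch: the YAML path is the repo's test config, and the `common_src` folder is assumed to exist next to the component definition.

```python
# Illustrative sketch based on the tests in this diff; paths and the common_src folder are assumptions.
from azure.ai.ml import load_component
from azure.ai.ml.entities import SparkComponent

# Load a Spark component whose YAML declares `additional_includes`.
spark_component = load_component(
    "./tests/test_configs/components/hello_spark_component_with_additional_include.yml"
)
print(spark_component.additional_includes)  # ['common_src']

# Or declare the shared folder directly on the entity.
component = SparkComponent(
    name="wordcount_spark_component",
    display_name="Spark word count",
    version="3",
    inputs={"file_input": {"type": "uri_file", "mode": "direct"}},
    driver_cores=1,
    driver_memory="2g",
    executor_cores=2,
    executor_memory="2g",
    executor_instances=4,
    entry={"file": "wordcount.py"},
    args="--input1 ${{inputs.file_input}}",
    additional_includes=["common_src"],  # shared files/folders bundled with the component code
)
```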