
Adding additional include support in spark component (#38537)
Support additional include in spark component
achauhan-scc authored Nov 15, 2024
1 parent c812e66 commit a0b7966
Showing 5 changed files with 98 additions and 3 deletions.
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
@@ -2,6 +2,7 @@
## 1.23.0 (unreleased)

### Features Added
- Add support for additional include in spark component.

### Bugs Fixed

@@ -7,7 +7,7 @@
from copy import deepcopy

import yaml
from marshmallow import INCLUDE, fields, post_load
from marshmallow import INCLUDE, fields, post_dump, post_load

from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
from azure.ai.ml._schema.component.component import ComponentSchema
@@ -20,6 +20,16 @@

class SparkComponentSchema(ComponentSchema, ParameterizedSparkSchema):
    type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
    additional_includes = fields.List(fields.Str())

    @post_dump
    def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
        if (
            component_schema_dict.get("additional_includes") is not None
            and len(component_schema_dict["additional_includes"]) == 0
        ):
            component_schema_dict.pop("additional_includes")
        return component_schema_dict


class RestSparkComponentSchema(SparkComponentSchema):
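
For readers less familiar with the pattern, here is a small standalone sketch of the same post_dump idea (a simplified schema, not the actual SparkComponentSchema), showing how an empty additional_includes list is dropped from the dumped dictionary while a populated one is kept:

from marshmallow import Schema, fields, post_dump


class ExampleSchema(Schema):
    name = fields.Str()
    additional_includes = fields.List(fields.Str())

    @post_dump
    def remove_unnecessary_fields(self, data, **kwargs):
        # Drop the key entirely when the list is present but empty, so the
        # serialized spec does not carry a no-op field.
        if data.get("additional_includes") is not None and len(data["additional_includes"]) == 0:
            data.pop("additional_includes")
        return data


print(ExampleSchema().dump({"name": "demo", "additional_includes": []}))
# "additional_includes" is absent from the output
print(ExampleSchema().dump({"name": "demo", "additional_includes": ["common_src"]}))
# "additional_includes": ["common_src"] is kept
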
@@ -17,12 +17,12 @@
from .._job.spark_job_entry_mixin import SparkJobEntry, SparkJobEntryMixin
from .._util import convert_ordered_dict_to_dict, validate_attribute_type
from .._validation import MutableValidationResult
from .code import ComponentCodeMixin
from ._additional_includes import AdditionalIncludesMixin
from .component import Component


class SparkComponent(
    Component, ParameterizedSpark, SparkJobEntryMixin, ComponentCodeMixin
    Component, ParameterizedSpark, SparkJobEntryMixin, AdditionalIncludesMixin
):  # pylint: disable=too-many-instance-attributes
    """Spark component version, used to define a Spark Component or Job.
@@ -79,6 +79,8 @@ class SparkComponent(
    :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]]
    :keyword args: The arguments for the job. Defaults to None.
    :paramtype args: Optional[str]
    :keyword additional_includes: A list of shared additional files to be included in the component. Defaults to None.
    :paramtype additional_includes: Optional[List[str]]
    .. admonition:: Example:
@@ -112,6 +114,7 @@ def __init__(
        inputs: Optional[Dict] = None,
        outputs: Optional[Dict] = None,
        args: Optional[str] = None,
        additional_includes: Optional[List] = None,
        **kwargs: Any,
    ) -> None:
        # validate init params are valid type
@@ -134,6 +137,7 @@ def __init__(
        self.conf = conf
        self.environment = environment
        self.args = args
        self.additional_includes = additional_includes or []
        # For pipeline spark job, we also allow user to set driver_cores, driver_memory and so on by setting conf.
        # If root level fields are not set by user, we promote conf setting to root level to facilitate subsequent
        # verification. This usually happens when we use to_component(SparkJob) or builder function spark() as a node
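
As a usage illustration of the new keyword argument, a minimal sketch (assumes the azure-ai-ml package is installed; the entry file and the common_src folder are placeholders mirroring the tests below):

from azure.ai.ml.entities import SparkComponent

# additional_includes lists extra files/folders to package with the component
# besides its entry code; entries are assumed to resolve relative to the
# component's base path, as in the tests and YAML config in this commit.
component = SparkComponent(
    name="wordcount_spark_component",
    entry={"file": "wordcount.py"},
    driver_cores=1,
    driver_memory="2g",
    executor_cores=2,
    executor_memory="2g",
    executor_instances=4,
    additional_includes=["common_src"],
)
print(component.additional_includes)  # ['common_src']
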
@@ -29,6 +29,18 @@ def test_component_load(self):
        }
        assert spark_component.args == "--file_input ${{inputs.file_input}}"

    def test_component_load_with_additional_include(self):
        # code is specified in yaml, value is respected
        component_yaml = "./tests/test_configs/components/hello_spark_component_with_additional_include.yml"
        spark_component = load_component(
            component_yaml,
        )

        assert (
            isinstance(spark_component.additional_includes, list)
            and spark_component.additional_includes[0] == "common_src"
        )

    def test_spark_component_to_dict(self):
        # Test optional params exists in component dict
        yaml_path = "./tests/test_configs/dsl_pipeline/spark_job_in_pipeline/add_greeting_column_component.yml"
@@ -37,6 +49,14 @@ def test_spark_component_to_dict(self):
        spark_component = SparkComponent._load(data=yaml_dict, yaml_path=yaml_path)
        assert spark_component._other_parameter.get("mock_option_param") == yaml_dict["mock_option_param"]

    def test_spark_component_to_dict_additional_include(self):
        # Test optional params exists in component dict
        yaml_path = "./tests/test_configs/dsl_pipeline/spark_job_in_pipeline/add_greeting_column_component.yml"
        yaml_dict = load_yaml(yaml_path)
        yaml_dict["additional_includes"] = ["common_src"]
        spark_component = SparkComponent._load(data=yaml_dict, yaml_path=yaml_path)
        assert spark_component.additional_includes[0] == yaml_dict["additional_includes"][0]

    def test_spark_component_entity(self):
        component = SparkComponent(
            name="add_greeting_column_spark_component",
@@ -73,6 +93,39 @@

        assert component_dict == yaml_component_dict

    def test_spark_component_entity_additional_include(self):
        component = SparkComponent(
            name="wordcount_spark_component",
            display_name="Spark word count",
            description="Spark word count",
            version="3",
            inputs={
                "file_input": {"type": "uri_file", "mode": "direct"},
            },
            driver_cores=1,
            driver_memory="2g",
            executor_cores=2,
            executor_memory="2g",
            executor_instances=4,
            entry={"file": "wordcount.py"},
            args="--input1 ${{inputs.file_input}}",
            base_path="./tests/test_configs/components",
            additional_includes=["common_src"],
        )
        omit_fields = [
            "properties.component_spec.$schema",
            "properties.component_spec._source",
            "properties.properties.client_component_hash",
        ]
        component_dict = component._to_rest_object().as_dict()
        component_dict = pydash.omit(component_dict, *omit_fields)

        yaml_path = "./tests/test_configs/components/hello_spark_component_with_additional_include.yml"
        yaml_component = load_component(yaml_path)
        yaml_component_dict = yaml_component._to_rest_object().as_dict()
        yaml_component_dict = pydash.omit(yaml_component_dict, *omit_fields)
        assert component_dict == yaml_component_dict

    def test_spark_component_version_as_a_function_with_inputs(self):
        expected_rest_component = {
            "type": "spark",
@@ -0,0 +1,27 @@
$schema: https://azuremlschemas.azureedge.net/latest/sparkComponent.schema.json
name: wordcount_spark_component
type: spark
version: 3
display_name: Spark word count
description: Spark word count


inputs:
  file_input:
    type: uri_file
    mode: direct

entry:
  file: wordcount.py

args: >-
  --input1 ${{inputs.file_input}}
conf:
  spark.driver.cores: 1
  spark.driver.memory: "2g"
  spark.executor.cores: 2
  spark.executor.memory: "2g"
  spark.executor.instances: 4
additional_includes:
  - common_src
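
For completeness, this new test config can be loaded through the public load_component entry point, as the tests above do; the path assumes the repository's test layout:

from azure.ai.ml import load_component

spark_component = load_component(
    "./tests/test_configs/components/hello_spark_component_with_additional_include.yml"
)
# The additional include declared in the YAML surfaces on the loaded entity.
print(spark_component.additional_includes)  # ['common_src']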
