Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance stack validation #148

Merged
merged 7 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/mlstacks/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# permissions and limitations under the License.
"""MLStacks constants."""

from typing import Dict, List

MLSTACKS_PACKAGE_NAME = "mlstacks"
MLSTACKS_INITIALIZATION_FILE_FLAG = "IGNORE_ME"
MLSTACKS_STACK_COMPONENT_FLAGS = [
Expand Down Expand Up @@ -39,6 +41,52 @@
"model_deployer": ["seldon"],
"step_operator": ["sagemaker", "vertex"],
}
ALLOWED_COMPONENT_TYPES: Dict[str, Dict[str, List[str]]] = {
"aws": {
"artifact_store": ["s3"],
"container_registry": ["aws"],
"experiment_tracker": ["mlflow"],
"orchestrator": [
"kubeflow",
"kubernetes",
"sagemaker",
"skypilot",
"tekton",
],
"mlops_platform": ["zenml"],
"model_deployer": ["seldon"],
"step_operator": ["sagemaker"],
},
"azure": {},
"gcp": {
"artifact_store": ["gcp"],
"container_registry": ["gcp"],
"experiment_tracker": ["mlflow"],
"orchestrator": [
"kubeflow",
"kubernetes",
"skypilot",
"tekton",
"vertex",
],
"mlops_platform": ["zenml"],
"model_deployer": ["seldon"],
"step_operator": ["vertex"],
},
"k3d": {
"artifact_store": ["minio"],
"container_registry": ["default"],
"experiment_tracker": ["mlflow"],
"orchestrator": [
"kubeflow",
"kubernetes",
"sagemaker",
"tekton",
],
"mlops_platform": ["zenml"],
"model_deployer": ["seldon"],
},
}

PERMITTED_NAME_REGEX = r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$"
ANALYTICS_OPT_IN_ENV_VARIABLE = "MLSTACKS_ANALYTICS_OPT_IN"
Expand All @@ -49,5 +97,19 @@
"contain alphanumeric characters, underscores, and hyphens "
"thereafter."
)
INVALID_COMPONENT_TYPE_ERROR_MESSAGE = (
"Artifact Store, Container Registry, Experiment Tracker, Orchestrator, "
"MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d "
"providers. Step Operator may only be used with aws and gcp."
)
INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE = (
"Only certain flavors are allowed for a given provider-component type "
"combination. For more information, consult the tables for your specified "
"provider at the MLStacks documentation: "
"https://mlstacks.zenml.io/stacks/stack-specification."
)
STACK_COMPONENT_PROVIDER_MISMATCH_ERROR_MESSAGE = (
"Stack provider and component provider mismatch."
)
DEFAULT_REMOTE_STATE_BUCKET_NAME = "zenml-mlstacks-remote-state"
TERRAFORM_CONFIG_BUCKET_REPLACEMENT_STRING = "BUCKETNAMEREPLACEME"
20 changes: 20 additions & 0 deletions src/mlstacks/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class ComponentFlavorEnum(str, Enum):
TEKTON = "tekton"
VERTEX = "vertex"
ZENML = "zenml"
DEFAULT = "default"


class DeploymentMethodEnum(str, Enum):
Expand Down Expand Up @@ -77,3 +78,22 @@ class AnalyticsEventsEnum(str, Enum):
MLSTACKS_SOURCE = "MLStacks Source"
MLSTACKS_EXCEPTION = "MLStacks Exception"
MLSTACKS_VERSION = "MLStacks Version"


class SpecTypeEnum(str, Enum):
"""Spec type enum."""

STACK = "stack"
COMPONENT = "component"


class StackSpecVersionEnum(int, Enum):
"""Spec version enum."""

ONE = 1


class ComponentSpecVersionEnum(int, Enum):
"""Spec version enum."""

ONE = 1
77 changes: 70 additions & 7 deletions src/mlstacks/models/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,27 @@
# permissions and limitations under the License.
"""Component model."""

from typing import Dict, Optional
from typing import Any, Dict, Optional

from pydantic import BaseModel, validator

from mlstacks.constants import INVALID_NAME_ERROR_MESSAGE
from mlstacks.constants import (
INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE,
INVALID_COMPONENT_TYPE_ERROR_MESSAGE,
INVALID_NAME_ERROR_MESSAGE,
)
from mlstacks.enums import (
ComponentFlavorEnum,
ComponentSpecVersionEnum,
ComponentTypeEnum,
ProviderEnum,
SpecTypeEnum,
)
from mlstacks.utils.model_utils import (
is_valid_component_flavor,
is_valid_component_type,
is_valid_name,
)
from mlstacks.utils.model_utils import is_valid_name


class ComponentMetadata(BaseModel):
Expand All @@ -49,16 +59,16 @@ class Component(BaseModel):
metadata: The metadata of the component.
"""

spec_version: int = 1
spec_type: str = "component"
spec_version: ComponentSpecVersionEnum = ComponentSpecVersionEnum.ONE
spec_type: SpecTypeEnum = SpecTypeEnum.COMPONENT
name: str
provider: ProviderEnum
component_type: ComponentTypeEnum
component_flavor: ComponentFlavorEnum
provider: ProviderEnum
metadata: Optional[ComponentMetadata] = None

@validator("name")
def validate_name(cls, name: str) -> str: # noqa: N805
def validate_name(cls, name: str) -> str: # noqa
"""Validate the name.

Name must start with an alphanumeric character and can only contain
Expand All @@ -78,3 +88,56 @@ def validate_name(cls, name: str) -> str: # noqa: N805
if not is_valid_name(name):
raise ValueError(INVALID_NAME_ERROR_MESSAGE)
return name

@validator("component_type")
def validate_component_type(
cls, # noqa
component_type: str,
values: Dict[str, Any],
) -> str:
"""Validate the component type.

Artifact Store, Container Registry, Experiment Tracker, Orchestrator,
MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d
providers. Step Operator may only be used with aws and gcp.

Args:
component_type: The component type.
values: The previously validated component specs.

Returns:
The validated component type.

Raises:
ValueError: If the component type is invalid.
"""
if not is_valid_component_type(component_type, values["provider"]):
raise ValueError(INVALID_COMPONENT_TYPE_ERROR_MESSAGE)
return component_type

@validator("component_flavor")
def validate_component_flavor(
cls, # noqa
component_flavor: str,
values: Dict[str, Any],
) -> str:
"""Validate the component flavor.

Only certain flavors are allowed for a given provider-component
type combination. For more information, consult the tables for
your specified provider at the MLStacks documentation:
https://mlstacks.zenml.io/stacks/stack-specification.

Args:
component_flavor: The component flavor.
values: The previously validated component specs.

Returns:
The validated component flavor.

Raises:
ValueError: If the component flavor is invalid.
"""
if not is_valid_component_flavor(component_flavor, values):
raise ValueError(INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE)
return component_flavor
8 changes: 5 additions & 3 deletions src/mlstacks/models/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from mlstacks.enums import (
DeploymentMethodEnum,
ProviderEnum,
SpecTypeEnum,
StackSpecVersionEnum,
)
from mlstacks.models.component import Component
from mlstacks.utils.model_utils import is_valid_name
Expand All @@ -38,8 +40,8 @@ class Stack(BaseModel):
components: The components of the stack.
"""

spec_version: int = 1
spec_type: str = "stack"
spec_version: StackSpecVersionEnum = StackSpecVersionEnum.ONE
spec_type: SpecTypeEnum = SpecTypeEnum.STACK
name: str
provider: ProviderEnum
default_region: Optional[str]
Expand All @@ -50,7 +52,7 @@ class Stack(BaseModel):
components: List[Component] = []

@validator("name")
def validate_name(cls, name: str) -> str: # noqa: N805
def validate_name(cls, name: str) -> str: # noqa
"""Validate the name.

Name must start with an alphanumeric character and can only contain
Expand Down
46 changes: 45 additions & 1 deletion src/mlstacks/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
"""Util functions for Pydantic models and validation."""

import re
from typing import Any, Dict

from mlstacks.constants import PERMITTED_NAME_REGEX
from mlstacks.constants import ALLOWED_COMPONENT_TYPES, PERMITTED_NAME_REGEX


def is_valid_name(name: str) -> bool:
Expand All @@ -29,3 +30,46 @@ def is_valid_name(name: str) -> bool:
True if the name is valid, False otherwise.
"""
return re.match(PERMITTED_NAME_REGEX, name) is not None


def is_valid_component_type(component_type: str, provider: str) -> bool:
"""Check if the component type is valid.

Used for components.

Args:
component_type: The component type.
provider: The provider.

Returns:
True if the component type is valid, False otherwise.
"""
allowed_types = list(ALLOWED_COMPONENT_TYPES[provider].keys())
return component_type in allowed_types


def is_valid_component_flavor(
component_flavor: str, specs: Dict[str, Any]
) -> bool:
"""Check if the component flavor is valid.

Used for components.

Args:
component_flavor: The component flavor.
specs: The previously validated component specs.

Returns:
True if the component flavor is valid, False otherwise.
"""
try:
is_valid = (
component_flavor
in ALLOWED_COMPONENT_TYPES[specs["provider"]][
specs["component_type"]
]
)
except KeyError:
return False

return is_valid
25 changes: 22 additions & 3 deletions src/mlstacks/utils/yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import yaml

from mlstacks.constants import STACK_COMPONENT_PROVIDER_MISMATCH_ERROR_MESSAGE
from mlstacks.models.component import (
Component,
ComponentMetadata,
Expand Down Expand Up @@ -57,9 +58,17 @@ def load_component_yaml(path: str) -> Component:

Returns:
The component model.

Raises:
FileNotFoundError: If the file is not found.
"""
with open(path) as file:
component_data = yaml.safe_load(file)
try:
with open(path) as file:
component_data = yaml.safe_load(file)
except FileNotFoundError as exc:
error_message = f"""Component file at "{path}" specified in
the stack spec file could not be found."""
raise FileNotFoundError(error_message) from exc

if component_data.get("metadata") is None:
component_data["metadata"] = {}
Expand Down Expand Up @@ -88,14 +97,18 @@ def load_stack_yaml(path: str) -> Stack:

Returns:
The stack model.

Raises:
ValueError: If the stack and component have different providers
"""
with open(path) as yaml_file:
stack_data = yaml.safe_load(yaml_file)
component_data = stack_data.get("components")

if component_data is None:
component_data = []
return Stack(

stack = Stack(
spec_version=stack_data.get("spec_version"),
spec_type=stack_data.get("spec_type"),
name=stack_data.get("name"),
Expand All @@ -107,3 +120,9 @@ def load_stack_yaml(path: str) -> Stack:
load_component_yaml(component) for component in component_data
],
)

for component in stack.components:
if component.provider != stack.provider:
raise ValueError(STACK_COMPONENT_PROVIDER_MISMATCH_ERROR_MESSAGE)

return stack
Loading
Loading