diff --git a/.bazelrc b/.bazelrc index 6c1123c5..1ef20a99 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,8 +1,8 @@ # Common Default # Wrapper to make sure tests are run. -# Allow at most 3 hours for eternal tests. -test --run_under='//bazel:test_wrapper' --test_timeout=-1,-1,-1,10800 +# Allow at most 4 hours for eternal tests. +test --run_under='//bazel:test_wrapper' --test_timeout=-1,-1,-1,14400 # Since integration tests are located in different packages than code under test, # the default instrumentation filter would exclude the code under test. This diff --git a/CHANGELOG.md b/CHANGELOG.md index f880471a..ff594211 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,20 @@ # Release History -## 1.7.1 +## 1.7.2 + +### Bug Fixes + +- Model Explainability: Fix issue that explain is enabled for scikit-learn pipeline +whose task is UNKNOWN and fails later when invoked. + +### Behavior Changes + +### New Features + +- Registry: Support asynchronous model inference service creation with the `block` option + in `ModelVersion.create_service()` set to True by default. + +## 1.7.1 (2024-11-05) ### Bug Fixes diff --git a/README.md b/README.md index b81c8891..12ca0558 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ and deployment process, and includes two key components. ### Snowpark ML Development -[Snowpark ML Development](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowpark-ml-development) +[Snowpark ML Development](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#ml-modeling) provides a collection of python APIs enabling efficient ML model development directly in Snowflake: 1. Modeling API (`snowflake.ml.modeling`) for data preprocessing, feature engineering and model training in Snowflake. @@ -26,14 +26,21 @@ their native data loader formats. 1. FileSet API: FileSet provides a Python fsspec-compliant API for materializing data into a Snowflake internal stage from a query or Snowpark Dataframe along with a number of convenience APIs. -### Snowpark Model Management [Public Preview] +### Snowflake MLOps -[Snowpark Model Management](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowpark-ml-ops) complements -the Snowpark ML Development API, and provides model management capabilities along with integrated deployment into Snowflake. +Snowflake MLOps contains suit of tools and objects to make ML development cycle. It complements +the Snowpark ML Development API, and provides end to end development to deployment within Snowflake. Currently, the API consists of: -1. Registry: A python API for managing models within Snowflake which also supports deployment of ML models into Snowflake -as native MODEL object running with Snowflake Warehouse. +1. [Registry](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowflake-model-registry): A python API + allows secure deployment and management of models in Snowflake, supporting models trained both inside and outside of + Snowflake. +2. [Feature Store](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowflake-feature-store): A fully + integrated solution for defining, managing, storing and discovering ML features derived from your data. The + Snowflake Feature Store supports automated, incremental refresh from batch and streaming data sources, so that + feature pipelines need be defined only once to be continuously updated with new data. +3. [Datasets](https://docs.snowflake.com/developer-guide/snowflake-ml/overview#snowflake-datasets): Dataset provide an + immutable, versioned snapshot of your data suitable for ingestion by your machine learning models. ## Getting started @@ -80,3 +87,19 @@ conda install \ Note that until a `snowflake-ml-python` package version is available in the official Snowflake conda channel, there may be compatibility issues. Server-side functionality that `snowflake-ml-python` depends on may not yet be released. + +### Verifying the package + +1. Install cosign. + This example is using golang installation: [installing-cosign-with-go](https://edu.chainguard.dev/open-source/sigstore/cosign/how-to-install-cosign/#installing-cosign-with-go). +1. Download the file from the repository like [pypi](https://pypi.org/project/snowflake-ml-python/#files). +1. Download the signature files from the [release tag](https://github.com/snowflakedb/snowflake-ml-python/releases/tag/1.7.0). +1. Verify signature on projects signed using Jenkins job: + + ```sh + cosign verify-blob snowflake_ml_python-1.7.0.tar.gz --key snowflake-ml-python-1.7.0.pub --signature resources.linux.snowflake_ml_python-1.7.0.tar.gz.sig + + cosign verify-blob snowflake_ml_python-1.7.0.tar.gz --key snowflake-ml-python-1.7.0.pub --signature resources.linux.snowflake_ml_python-1.7.0 + ``` + +NOTE: Version 1.7.0 is used as example here. Please choose the the latest version. diff --git a/bazel/environments/conda-env-snowflake.yml b/bazel/environments/conda-env-snowflake.yml index ae87b83b..f271999b 100644 --- a/bazel/environments/conda-env-snowflake.yml +++ b/bazel/environments/conda-env-snowflake.yml @@ -36,6 +36,7 @@ dependencies: - protobuf==3.20.3 - psutil==5.9.0 - pyarrow==10.0.1 + - pyjwt==2.8.0 - pytest-rerunfailures==12.0 - pytest-xdist==3.5.0 - pytest==7.4.0 diff --git a/bazel/environments/conda-env.yml b/bazel/environments/conda-env.yml index 6e102884..bcf829d6 100644 --- a/bazel/environments/conda-env.yml +++ b/bazel/environments/conda-env.yml @@ -36,6 +36,7 @@ dependencies: - protobuf==3.20.3 - psutil==5.9.0 - pyarrow==10.0.1 + - pyjwt==2.8.0 - pytest-rerunfailures==12.0 - pytest-xdist==3.5.0 - pytest==7.4.0 diff --git a/bazel/environments/conda-gpu-env.yml b/bazel/environments/conda-gpu-env.yml index 00d8958e..e4170ad3 100755 --- a/bazel/environments/conda-gpu-env.yml +++ b/bazel/environments/conda-gpu-env.yml @@ -37,6 +37,7 @@ dependencies: - protobuf==3.20.3 - psutil==5.9.0 - pyarrow==10.0.1 + - pyjwt==2.8.0 - pytest-rerunfailures==12.0 - pytest-xdist==3.5.0 - pytest==7.4.0 diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index a8dba137..65bd495c 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,7 +17,7 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.7.1 + version: 1.7.2 requirements: build: - python @@ -35,6 +35,7 @@ requirements: - packaging>=20.9,<25 - pandas>=1.0.0,<3 - pyarrow + - pyjwt>=2.0.0, <3 - pytimeparse>=1.1.8,<2 - pyyaml>=6.0,<7 - requests diff --git a/ci/targets/quarantine/prod3.txt b/ci/targets/quarantine/prod3.txt index d725a7ca..2055a3e6 100644 --- a/ci/targets/quarantine/prod3.txt +++ b/ci/targets/quarantine/prod3.txt @@ -5,4 +5,3 @@ //tests/integ/snowflake/ml/modeling/preprocessing:k_bins_discretizer_test //tests/integ/snowflake/ml/modeling/linear_model:logistic_regression_test //tests/integ/snowflake/ml/registry/model:registry_mlflow_model_test -//tests/integ/snowflake/ml/registry/services/... diff --git a/docs/README.md b/docs/README.md index 3d88e47c..b380ac05 100644 --- a/docs/README.md +++ b/docs/README.md @@ -32,4 +32,5 @@ The following files are in the `docs/source` directory: - `index.rst`: ReStructuredText (RST) file that will be built as the index page. It mainly as a landing point and indicates the subp-ackages to include in the API reference. Currently these include the Modeling and FileSet/FileSystem APIs. -- `fileset.rst`, `modeling.rst`, `registry.rst`: RST files that direct Sphinx to include the specific classes in each submodule. +- RST files that direct Sphinx to include the specific classes in each submodule. + - `fileset.rst`, `modeling.rst`, `monitoring.rst`, `registry.rst` diff --git a/docs/source/index.rst b/docs/source/index.rst index fc3bf458..0636c4fa 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -32,4 +32,5 @@ Table of Contents fileset model modeling + monitoring registry diff --git a/docs/source/monitoring.rst b/docs/source/monitoring.rst new file mode 100644 index 00000000..dec8ad09 --- /dev/null +++ b/docs/source/monitoring.rst @@ -0,0 +1,31 @@ +=========================== +snowflake.ml.monitoring +=========================== + +.. automodule:: snowflake.ml.monitoring + :noindex: + +snowflake.ml.monitoring.model_monitor +------------------------------------- + +.. currentmodule:: snowflake.ml.monitoring.model_monitor + +.. rubric:: Classes + +.. autosummary:: + :toctree: api/monitoring + + ModelMonitor + +snowflake.ml.monitoring.entities +------------------------------------- + +.. currentmodule:: snowflake.ml.monitoring.entities + +.. rubric:: Classes + +.. autosummary:: + :toctree: api/monitoring + + model_monitor_config.ModelMonitorConfig + model_monitor_config.ModelMonitorSourceConfig diff --git a/requirements.txt b/requirements.txt index 452f792a..36bec55d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,6 +32,7 @@ peft==0.5.0 protobuf==3.20.3 psutil==5.9.0 pyarrow==10.0.1 +pyjwt==2.8.0 pytest-rerunfailures==12.0 pytest-xdist==3.5.0 pytest==7.4.0 diff --git a/requirements.yml b/requirements.yml index 6b335ce6..af4bb7e3 100644 --- a/requirements.yml +++ b/requirements.yml @@ -174,6 +174,9 @@ - name: pyarrow dev_version: 10.0.1 version_requirements: '' +- name: pyjwt + dev_version: 2.8.0 + version_requirements: '>=2.0.0, <3' - name: pytest dev_version: 7.4.0 tags: diff --git a/snowflake/ml/_internal/utils/BUILD.bazel b/snowflake/ml/_internal/utils/BUILD.bazel index ed417956..57c5c369 100644 --- a/snowflake/ml/_internal/utils/BUILD.bazel +++ b/snowflake/ml/_internal/utils/BUILD.bazel @@ -249,3 +249,8 @@ py_test( "//snowflake/ml/test_utils:mock_session", ], ) + +py_library( + name = "jwt_generator", + srcs = ["jwt_generator.py"], +) diff --git a/snowflake/ml/_internal/utils/jwt_generator.py b/snowflake/ml/_internal/utils/jwt_generator.py new file mode 100644 index 00000000..dd6e5753 --- /dev/null +++ b/snowflake/ml/_internal/utils/jwt_generator.py @@ -0,0 +1,141 @@ +import base64 +import hashlib +import logging +from datetime import datetime, timedelta, timezone +from typing import Optional + +import jwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import types + +logger = logging.getLogger(__name__) + +ISSUER = "iss" +EXPIRE_TIME = "exp" +ISSUE_TIME = "iat" +SUBJECT = "sub" + + +class JWTGenerator: + """ + Creates and signs a JWT with the specified private key file, username, and account identifier. The JWTGenerator + keeps the generated token and only regenerates the token if a specified period of time has passed. + """ + + _DEFAULT_LIFETIME = timedelta(minutes=59) # The tokens will have a 59-minute lifetime + _DEFAULT_RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54 minutes + ALGORITHM = "RS256" # Tokens will be generated using RSA with SHA256 + + def __init__( + self, + account: str, + user: str, + private_key: types.PRIVATE_KEY_TYPES, + lifetime: Optional[timedelta] = None, + renewal_delay: Optional[timedelta] = None, + ) -> None: + """ + Create a new JWTGenerator object. + + Args: + account: The account identifier. + user: The username. + private_key: The private key used to sign the JWT. + lifetime: The lifetime of the token. + renewal_delay: The time before the token expires to renew it. + """ + + # Construct the fully qualified name of the user in uppercase. + self.account = JWTGenerator._prepare_account_name_for_jwt(account) + self.user = user.upper() + self.qualified_username = self.account + "." + self.user + self.private_key = private_key + self.public_key_fp = JWTGenerator._calculate_public_key_fingerprint(self.private_key) + + self.issuer = self.qualified_username + "." + self.public_key_fp + self.lifetime = lifetime or JWTGenerator._DEFAULT_LIFETIME + self.renewal_delay = renewal_delay or JWTGenerator._DEFAULT_RENEWAL_DELTA + self.renew_time = datetime.now(timezone.utc) + self.token: Optional[str] = None + + logger.info( + """Creating JWTGenerator with arguments + account : %s, user : %s, lifetime : %s, renewal_delay : %s""", + self.account, + self.user, + self.lifetime, + self.renewal_delay, + ) + + @staticmethod + def _prepare_account_name_for_jwt(raw_account: str) -> str: + account = raw_account + if ".global" not in account: + # Handle the general case. + idx = account.find(".") + if idx > 0: + account = account[0:idx] + else: + # Handle the replication case. + idx = account.find("-") + if idx > 0: + account = account[0:idx] + # Use uppercase for the account identifier. + return account.upper() + + def get_token(self) -> str: + now = datetime.now(timezone.utc) # Fetch the current time + if self.token is not None and self.renew_time > now: + return self.token + + # If the token has expired or doesn't exist, regenerate the token. + logger.info( + "Generating a new token because the present time (%s) is later than the renewal time (%s)", + now, + self.renew_time, + ) + # Calculate the next time we need to renew the token. + self.renew_time = now + self.renewal_delay + + # Create our payload + payload = { + # Set the issuer to the fully qualified username concatenated with the public key fingerprint. + ISSUER: self.issuer, + # Set the subject to the fully qualified username. + SUBJECT: self.qualified_username, + # Set the issue time to now. + ISSUE_TIME: now, + # Set the expiration time, based on the lifetime specified for this object. + EXPIRE_TIME: now + self.lifetime, + } + + # Regenerate the actual token + token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM) + # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string. + # If the token is a byte string, convert it to a string. + if isinstance(token, bytes): + token = token.decode("utf-8") + self.token = token + logger.info( + "Generated a JWT with the following payload: %s", + jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]), + ) + + return token + + @staticmethod + def _calculate_public_key_fingerprint(private_key: types.PRIVATE_KEY_TYPES) -> str: + # Get the raw bytes of public key. + public_key_raw = private_key.public_key().public_bytes( + serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo + ) + + # Get the sha256 hash of the raw bytes. + sha256hash = hashlib.sha256() + sha256hash.update(public_key_raw) + + # Base64-encode the value and prepend the prefix 'SHA256:'. + public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8") + logger.info("Public key fingerprint is %s", public_key_fp) + + return public_key_fp diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py index 7ea9c0ab..4ee6ff4c 100644 --- a/snowflake/ml/model/_client/model/model_version_impl.py +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -14,7 +14,7 @@ from snowflake.ml.model._model_composer import model_composer from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.model._packager.model_handlers import snowmlmodel -from snowflake.snowpark import Session, dataframe +from snowflake.snowpark import Session, async_job, dataframe _TELEMETRY_PROJECT = "MLOps" _TELEMETRY_SUBPROJECT = "ModelManagement" @@ -631,7 +631,8 @@ def create_service( max_batch_rows: Optional[int] = None, force_rebuild: bool = False, build_external_access_integration: Optional[str] = None, - ) -> str: + block: bool = True, + ) -> Union[str, async_job.AsyncJob]: """Create an inference service with the given spec. Args: @@ -659,6 +660,9 @@ def create_service( force_rebuild: Whether to force a model inference image rebuild. build_external_access_integration: (Deprecated) The external access integration for image build. This is usually permitting access to conda & PyPI repositories. + block: A bool value indicating whether this function will wait until the service is available. + When it is ``False``, this function executes the underlying service creation asynchronously + and returns an :class:`AsyncJob`. """ ... @@ -679,7 +683,8 @@ def create_service( max_batch_rows: Optional[int] = None, force_rebuild: bool = False, build_external_access_integrations: Optional[List[str]] = None, - ) -> str: + block: bool = True, + ) -> Union[str, async_job.AsyncJob]: """Create an inference service with the given spec. Args: @@ -707,6 +712,9 @@ def create_service( force_rebuild: Whether to force a model inference image rebuild. build_external_access_integrations: The external access integrations for image build. This is usually permitting access to conda & PyPI repositories. + block: A bool value indicating whether this function will wait until the service is available. + When it is ``False``, this function executes the underlying service creation asynchronously + and returns an :class:`AsyncJob`. """ ... @@ -742,7 +750,8 @@ def create_service( force_rebuild: bool = False, build_external_access_integration: Optional[str] = None, build_external_access_integrations: Optional[List[str]] = None, - ) -> str: + block: bool = True, + ) -> Union[str, async_job.AsyncJob]: """Create an inference service with the given spec. Args: @@ -772,12 +781,16 @@ def create_service( usually permitting access to conda & PyPI repositories. build_external_access_integrations: The external access integrations for image build. This is usually permitting access to conda & PyPI repositories. + block: A bool value indicating whether this function will wait until the service is available. + When it is False, this function executes the underlying service creation asynchronously + and returns an AsyncJob. Raises: ValueError: Illegal external access integration arguments. Returns: - Result information about service creation from server. + If `block=True`, return result information about service creation from server. + Otherwise, return the service creation AsyncJob. """ statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, @@ -829,6 +842,7 @@ def create_service( if build_external_access_integrations is None else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations] ), + block=block, statement_params=statement_params, ) diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py index 30eef9a5..db71e464 100644 --- a/snowflake/ml/model/_client/model/model_version_impl_test.py +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -738,6 +738,7 @@ def test_create_service(self) -> None: max_batch_rows=1024, force_rebuild=True, build_external_access_integrations=["EAI"], + block=True, ) mock_create_service.assert_called_once_with( database_name=None, @@ -761,6 +762,7 @@ def test_create_service(self) -> None: max_batch_rows=1024, force_rebuild=True, build_external_access_integrations=[sql_identifier.SqlIdentifier("EAI")], + block=True, statement_params=mock.ANY, ) @@ -778,6 +780,7 @@ def test_create_service_same_pool(self) -> None: max_batch_rows=1024, force_rebuild=True, build_external_access_integrations=["EAI"], + block=True, ) mock_create_service.assert_called_once_with( database_name=None, @@ -801,6 +804,7 @@ def test_create_service_same_pool(self) -> None: max_batch_rows=1024, force_rebuild=True, build_external_access_integrations=[sql_identifier.SqlIdentifier("EAI")], + block=True, statement_params=mock.ANY, ) @@ -818,6 +822,7 @@ def test_create_service_no_eai(self) -> None: num_workers=1, max_batch_rows=1024, force_rebuild=True, + block=True, ) mock_create_service.assert_called_once_with( database_name=None, @@ -841,6 +846,50 @@ def test_create_service_no_eai(self) -> None: max_batch_rows=1024, force_rebuild=True, build_external_access_integrations=None, + block=True, + statement_params=mock.ANY, + ) + + def test_create_service_async_job(self) -> None: + with mock.patch.object(self.m_mv._service_ops, "create_service") as mock_create_service: + self.m_mv.create_service( + service_name="SERVICE", + image_build_compute_pool="IMAGE_BUILD_COMPUTE_POOL", + service_compute_pool="SERVICE_COMPUTE_POOL", + image_repo="IMAGE_REPO", + max_instances=3, + cpu_requests="CPU", + memory_requests="MEMORY", + gpu_requests="GPU", + num_workers=1, + max_batch_rows=1024, + force_rebuild=True, + build_external_access_integrations=["EAI"], + block=False, + ) + mock_create_service.assert_called_once_with( + database_name=None, + schema_name=None, + model_name=sql_identifier.SqlIdentifier(self.m_mv.model_name), + version_name=sql_identifier.SqlIdentifier(self.m_mv.version_name), + service_database_name=None, + service_schema_name=None, + service_name=sql_identifier.SqlIdentifier("SERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_database_name=None, + image_repo_schema_name=None, + image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + ingress_enabled=False, + max_instances=3, + cpu_requests="CPU", + memory_requests="MEMORY", + gpu_requests="GPU", + num_workers=1, + max_batch_rows=1024, + force_rebuild=True, + build_external_access_integrations=[sql_identifier.SqlIdentifier("EAI")], + block=False, statement_params=mock.ANY, ) diff --git a/snowflake/ml/model/_client/ops/service_ops.py b/snowflake/ml/model/_client/ops/service_ops.py index 06771725..6f58f9a5 100644 --- a/snowflake/ml/model/_client/ops/service_ops.py +++ b/snowflake/ml/model/_client/ops/service_ops.py @@ -6,7 +6,7 @@ import tempfile import threading import time -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast from packaging import version @@ -15,7 +15,7 @@ from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier from snowflake.ml.model._client.service import model_deployment_spec from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql -from snowflake.snowpark import exceptions, row, session +from snowflake.snowpark import async_job, exceptions, row, session from snowflake.snowpark._internal import utils as snowpark_utils module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY) @@ -107,8 +107,9 @@ def create_service( max_batch_rows: Optional[int], force_rebuild: bool, build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]], + block: bool, statement_params: Optional[Dict[str, Any]] = None, - ) -> str: + ) -> Union[str, async_job.AsyncJob]: # Fall back to the registry's database and schema if not provided database_name = database_name or self._database_name @@ -204,11 +205,15 @@ def create_service( log_thread = self._start_service_log_streaming( async_job, services, model_inference_service_exists, force_rebuild, statement_params ) - log_thread.join() - res = cast(str, cast(List[row.Row], async_job.result())[0][0]) - module_logger.info(f"Inference service {service_name} deployment complete: {res}") - return res + if block: + log_thread.join() + + res = cast(str, cast(List[row.Row], async_job.result())[0][0]) + module_logger.info(f"Inference service {service_name} deployment complete: {res}") + return res + else: + return async_job def _start_service_log_streaming( self, diff --git a/snowflake/ml/model/_client/ops/service_ops_test.py b/snowflake/ml/model/_client/ops/service_ops_test.py index a6043456..c6988164 100644 --- a/snowflake/ml/model/_client/ops/service_ops_test.py +++ b/snowflake/ml/model/_client/ops/service_ops_test.py @@ -67,6 +67,7 @@ def test_create_service(self) -> None: max_batch_rows=1024, force_rebuild=True, build_external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], + block=True, statement_params=self.m_statement_params, ) mock_create_stage.assert_called_once_with( @@ -159,6 +160,7 @@ def test_create_service_model_db_and_schema(self) -> None: max_batch_rows=1024, force_rebuild=True, build_external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], + block=True, statement_params=self.m_statement_params, ) mock_create_stage.assert_called_once_with( @@ -251,6 +253,7 @@ def test_create_service_default_db_and_schema(self) -> None: max_batch_rows=1024, force_rebuild=True, build_external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], + block=True, statement_params=self.m_statement_params, ) mock_create_stage.assert_called_once_with( @@ -307,6 +310,47 @@ def test_create_service_default_db_and_schema(self) -> None: statement_params=self.m_statement_params, ) + def test_create_service_async_job(self) -> None: + with mock.patch.object(self.m_ops._stage_client, "create_tmp_stage",), mock.patch.object( + snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_STAGE_ABCDEF0123" + ), mock.patch.object(self.m_ops._model_deployment_spec, "save",), mock.patch.object( + file_utils, "upload_directory_to_stage", return_value=None + ), mock.patch.object( + self.m_ops._service_client, + "deploy_model", + return_value=(str(uuid.uuid4()), mock.MagicMock(spec=snowpark.AsyncJob)), + ), mock.patch.object( + self.m_ops._service_client, + "get_service_status", + return_value=(service_sql.ServiceStatus.PENDING, None), + ): + res = self.m_ops.create_service( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + service_database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), + service_schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_database_name=sql_identifier.SqlIdentifier("IMAGE_REPO_DB"), + image_repo_schema_name=sql_identifier.SqlIdentifier("IMAGE_REPO_SCHEMA"), + image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + ingress_enabled=True, + max_instances=1, + cpu_requests="1", + memory_requests="6GiB", + gpu_requests="1", + num_workers=1, + max_batch_rows=1024, + force_rebuild=True, + build_external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], + block=False, + statement_params=self.m_statement_params, + ) + self.assertIsInstance(res, snowpark.AsyncJob) + def test_get_model_build_service_name(self) -> None: query_id = "01b6fc10-0002-c121-0000-6ed10736311e" """ diff --git a/snowflake/ml/model/_client/sql/stage.py b/snowflake/ml/model/_client/sql/stage.py index b993f142..c645448e 100644 --- a/snowflake/ml/model/_client/sql/stage.py +++ b/snowflake/ml/model/_client/sql/stage.py @@ -15,6 +15,6 @@ def create_tmp_stage( ) -> None: query_result_checker.SqlResultValidator( self._session, - f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}", + f"CREATE SCOPED TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}", statement_params=statement_params, ).has_dimensions(expected_rows=1, expected_cols=1).validate() diff --git a/snowflake/ml/model/_client/sql/stage_test.py b/snowflake/ml/model/_client/sql/stage_test.py index 6bd3ccf3..55a13fca 100644 --- a/snowflake/ml/model/_client/sql/stage_test.py +++ b/snowflake/ml/model/_client/sql/stage_test.py @@ -18,7 +18,7 @@ def test_create_tmp_stage(self) -> None: m_df = mock_data_frame.MockDataFrame( collect_result=[Row("Stage MODEL successfully created.")], collect_statement_params=m_statement_params ) - self.m_session.add_mock_sql("""CREATE TEMPORARY STAGE TEMP."test".MODEL""", copy.deepcopy(m_df)) + self.m_session.add_mock_sql("""CREATE SCOPED TEMPORARY STAGE TEMP."test".MODEL""", copy.deepcopy(m_df)) c_session = cast(Session, self.m_session) stage_sql.StageSQLClient( c_session, @@ -31,7 +31,7 @@ def test_create_tmp_stage(self) -> None: statement_params=m_statement_params, ) - self.m_session.add_mock_sql("""CREATE TEMPORARY STAGE TEMP."test".MODEL""", copy.deepcopy(m_df)) + self.m_session.add_mock_sql("""CREATE SCOPED TEMPORARY STAGE TEMP."test".MODEL""", copy.deepcopy(m_df)) c_session = cast(Session, self.m_session) stage_sql.StageSQLClient( c_session, diff --git a/snowflake/ml/model/_packager/model_handlers/sklearn.py b/snowflake/ml/model/_packager/model_handlers/sklearn.py index 9e814b7b..673a45a1 100644 --- a/snowflake/ml/model/_packager/model_handlers/sklearn.py +++ b/snowflake/ml/model/_packager/model_handlers/sklearn.py @@ -164,6 +164,8 @@ def get_prediction( stacklevel=1, ) enable_explainability = False + elif model_meta.task == model_types.Task.UNKNOWN: + enable_explainability = False else: enable_explainability = True if enable_explainability: diff --git a/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py b/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py index f2f7987e..00fb903d 100644 --- a/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py +++ b/snowflake/ml/model/_packager/model_handlers_test/sklearn_test.py @@ -348,6 +348,33 @@ def test_skl_no_default_explain_without_background_data(self) -> None: assert callable(predict_method) self.assertEqual(explain_method, None) + def test_skl_no_default_explain_sklearn_pipeline(self) -> None: + iris_X, iris_y = datasets.load_iris(return_X_y=True) + regr = linear_model.LinearRegression() + pipe = Pipeline([("regr", regr)]) + # The pipeline can be used as any other estimator + # and avoids leaking the test set into the train set + pipe.fit(iris_X, iris_y) + with tempfile.TemporaryDirectory() as tmpdir: + model_packager.ModelPackager(os.path.join(tmpdir, "model1")).save( + name="model1", + model=pipe, + sample_input_data=iris_X, + metadata={"author": "halu", "version": "1"}, + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + + pk = model_packager.ModelPackager(os.path.join(tmpdir, "model1")) + pk.load(as_custom_model=True) + assert pk.model + assert pk.meta + predict_method = getattr(pk.model, "predict", None) + explain_method = getattr(pk.model, "explain", None) + assert callable(predict_method) + self.assertEqual(explain_method, None) + def test_skl_with_cr_estimator(self) -> None: class SecondMockEstimator: ... diff --git a/snowflake/ml/model/_signatures/utils.py b/snowflake/ml/model/_signatures/utils.py index 8ff11308..e2e93210 100644 --- a/snowflake/ml/model/_signatures/utils.py +++ b/snowflake/ml/model/_signatures/utils.py @@ -118,7 +118,6 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any]) category=DeprecationWarning, stacklevel=1, ) - return core.ModelSignature( inputs=[ core.FeatureSpec(name="user_inputs", dtype=core.DataType.STRING, shape=(-1,)), diff --git a/snowflake/ml/monitoring/BUILD.bazel b/snowflake/ml/monitoring/BUILD.bazel index 931aa8d0..bc2ec8ba 100644 --- a/snowflake/ml/monitoring/BUILD.bazel +++ b/snowflake/ml/monitoring/BUILD.bazel @@ -32,6 +32,7 @@ py_library( "model_monitor.py", ], deps = [ + ":model_monitor_version", "//snowflake/ml/_internal:telemetry", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/monitoring/_client:model_monitor_sql_client", diff --git a/snowflake/ml/monitoring/_client/model_monitor_sql_client.py b/snowflake/ml/monitoring/_client/model_monitor_sql_client.py index 29d3faf8..cf9460d7 100644 --- a/snowflake/ml/monitoring/_client/model_monitor_sql_client.py +++ b/snowflake/ml/monitoring/_client/model_monitor_sql_client.py @@ -1,6 +1,4 @@ -import typing -from collections import Counter -from typing import Any, Dict, List, Mapping, Optional, Set +from typing import Any, Dict, List, Mapping, Optional from snowflake import snowpark from snowflake.ml._internal.utils import ( @@ -10,27 +8,12 @@ table_manager, ) from snowflake.ml.model._client.sql import _base -from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.snowpark import session, types -SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA" - MODEL_JSON_COL_NAME = "model" MODEL_JSON_MODEL_NAME_FIELD = "model_name" MODEL_JSON_VERSION_NAME_FIELD = "version_name" -MONITOR_NAME_COL_NAME = "MONITOR_NAME" -SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME" -FQ_MODEL_NAME_COL_NAME = "FULLY_QUALIFIED_MODEL_NAME" -VERSION_NAME_COL_NAME = "MODEL_VERSION_NAME" -FUNCTION_NAME_COL_NAME = "FUNCTION_NAME" -TASK_COL_NAME = "TASK" -MONITORING_ENABLED_COL_NAME = "IS_ENABLED" -TIMESTAMP_COL_NAME_COL_NAME = "TIMESTAMP_COLUMN_NAME" -PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES" -LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES" -ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES" - def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str: sql_list = ", ".join([f"'{column}'" for column in columns]) @@ -146,19 +129,6 @@ def show_model_monitors( .validate() ) - def _validate_unique_columns( - self, - timestamp_column: sql_identifier.SqlIdentifier, - id_columns: List[sql_identifier.SqlIdentifier], - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], - ) -> None: - all_columns = [*id_columns, *prediction_columns, *label_columns, timestamp_column] - num_all_columns = len(all_columns) - num_unique_columns = len(set(all_columns)) - if num_all_columns != num_unique_columns: - raise ValueError("Column names must be unique across id, timestamp, prediction, and label columns.") - def validate_existence_by_name( self, *, @@ -244,125 +214,6 @@ def _validate_columns_exist_in_source( if not all([column_name in source_column_schema for column_name in id_columns]): raise ValueError(f"ID column(s): {id_columns} do not exist in source.") - def _validate_timestamp_column_type( - self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier - ) -> None: - """Ensures columns have the same type. - - Args: - table_schema: Dictionary of column names and types in the source table. - timestamp_column: Name of the timestamp column. - - Raises: - ValueError: If the timestamp column is not of type TimestampType. - """ - if not isinstance(table_schema[timestamp_column], types.TimestampType): - raise ValueError( - f"Timestamp column: {timestamp_column} must be TimestampType. " - f"Found: {table_schema[timestamp_column]}" - ) - - def _validate_id_columns_types( - self, table_schema: Mapping[str, types.DataType], id_columns: List[sql_identifier.SqlIdentifier] - ) -> None: - """Ensures id columns have the correct type. - - Args: - table_schema: Dictionary of column names and types in the source table. - id_columns: List of id column names. - - Raises: - ValueError: If the id column is not of type StringType. - """ - id_column_types = list({table_schema[column_name] for column_name in id_columns}) - all_id_columns_string = all([isinstance(column_type, types.StringType) for column_type in id_column_types]) - if not all_id_columns_string: - raise ValueError(f"Id columns must all be StringType. Found: {id_column_types}") - - def _validate_prediction_columns_types( - self, table_schema: Mapping[str, types.DataType], prediction_columns: List[sql_identifier.SqlIdentifier] - ) -> None: - """Ensures prediction columns have the same type. - - Args: - table_schema: Dictionary of column names and types in the source table. - prediction_columns: List of prediction column names. - - Raises: - ValueError: If the prediction columns do not share the same type. - """ - - prediction_column_types = {table_schema[column_name] for column_name in prediction_columns} - if len(prediction_column_types) > 1: - raise ValueError(f"Prediction column types must be the same. Found: {prediction_column_types}") - - def _validate_label_columns_types( - self, - table_schema: Mapping[str, types.DataType], - label_columns: List[sql_identifier.SqlIdentifier], - ) -> None: - """Ensures label columns have the same type, and the correct type for the score type. - - Args: - table_schema: Dictionary of column names and types in the source table. - label_columns: List of label column names. - - Raises: - ValueError: If the label columns do not share the same type. - """ - label_column_types = {table_schema[column_name] for column_name in label_columns} - if len(label_column_types) > 1: - raise ValueError(f"Label column types must be the same. Found: {label_column_types}") - - def _validate_column_types( - self, - *, - table_schema: Mapping[str, types.DataType], - timestamp_column: sql_identifier.SqlIdentifier, - id_columns: List[sql_identifier.SqlIdentifier], - prediction_columns: List[sql_identifier.SqlIdentifier], - label_columns: List[sql_identifier.SqlIdentifier], - ) -> None: - """Ensures columns have the expected type. - - Args: - table_schema: Dictionary of column names and types in the source table. - timestamp_column: Name of the timestamp column. - id_columns: List of id column names. - prediction_columns: List of prediction column names. - label_columns: List of label column names. - """ - self._validate_timestamp_column_type(table_schema, timestamp_column) - self._validate_id_columns_types(table_schema, id_columns) - self._validate_prediction_columns_types(table_schema, prediction_columns) - self._validate_label_columns_types(table_schema, label_columns) - # TODO(SNOW-1646693): Validate label makes sense with model task - - def _validate_source_table_features_shape( - self, - table_schema: Mapping[str, types.DataType], - special_columns: Set[sql_identifier.SqlIdentifier], - model_function: model_manifest_schema.ModelFunctionInfo, - ) -> None: - table_schema_without_special_columns = { - k: v for k, v in table_schema.items() if sql_identifier.SqlIdentifier(k) not in special_columns - } - schema_column_types_to_count: typing.Counter[types.DataType] = Counter() - for column_type in table_schema_without_special_columns.values(): - schema_column_types_to_count[column_type] += 1 - - inputs = model_function["signature"].inputs - function_input_types = [input.as_snowpark_type() for input in inputs] - function_input_types_to_count: typing.Counter[types.DataType] = Counter() - for function_input_type in function_input_types: - function_input_types_to_count[function_input_type] += 1 - - if function_input_types_to_count != schema_column_types_to_count: - raise ValueError( - "Model function input types do not match the source table input columns types. " - f"Model function expected: {inputs} but got {table_schema_without_special_columns}" - ) - def validate_source( self, *, @@ -395,22 +246,6 @@ def validate_source( id_columns=id_columns, ) - def delete_monitor_metadata( - self, - name: str, - statement_params: Optional[Dict[str, Any]] = None, - ) -> None: - """Delete the row in the metadata table corresponding to the given monitor name. - - Args: - name: Name of the model monitor whose metadata should be deleted. - statement_params: Optional set of statement_params to include with query. - """ - self._sql_client._session.sql( - f"""DELETE FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME} - WHERE {MONITOR_NAME_COL_NAME} = '{name}'""", - ).collect(statement_params=statement_params) - def _alter_monitor( self, operation: str, diff --git a/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py b/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py index 7580e610..04aa58ba 100644 --- a/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py +++ b/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py @@ -59,34 +59,6 @@ def test_validate_source_table(self) -> None: ) self.m_session.finalize() - def test_validate_source_table_shape(self) -> None: - mocked_table_out = mock.MagicMock(name="schema") - self.m_session.table = mock.MagicMock(name="table", return_value=mocked_table_out) - mocked_table_out.schema = mock.MagicMock(name="schema") - mocked_table_out.schema.fields = [ - types.StructField(self.test_timestamp_column, types.TimestampType()), - types.StructField(self.test_prediction_column_name, types.DoubleType()), - types.StructField(self.test_label_column_name, types.DoubleType()), - types.StructField(self.test_id_column_name, types.StringType()), - types.StructField("feature1", types.StringType()), - ] - - self.monitor_sql_client.validate_source( - source_database=None, - source_schema=None, - source=self.test_source_table_name, - timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - id_columns=[sql_identifier.SqlIdentifier("ID")], - prediction_class_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - prediction_score_columns=[], - actual_score_columns=[sql_identifier.SqlIdentifier("LABEL")], - actual_class_columns=[], - ) - self.m_session.table.assert_called_once_with( - f"{self.test_db_name}.{self.test_schema_name}.{self.test_source_table_name}" - ) - self.m_session.finalize() - def test_validate_monitor_warehouse(self) -> None: self.m_session.add_mock_sql( query=f"""SHOW WAREHOUSES LIKE '{self.test_wh_name}'""", @@ -194,95 +166,6 @@ def test_validate_columns_exist_in_source_table(self) -> None: id_columns=[sql_identifier.SqlIdentifier("ID")], ) - def test_validate_column_types(self) -> None: - self.monitor_sql_client._validate_column_types( - table_schema={ - "PREDICTION1": types.DoubleType(), - "PREDICTION2": types.DoubleType(), - "LABEL1": types.DoubleType(), - "LABEL2": types.DoubleType(), - "ID": types.StringType(), - "TIMESTAMP": types.TimestampType(types.TimestampTimeZone("ltz")), - }, - timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - prediction_columns=[ - sql_identifier.SqlIdentifier("PREDICTION1"), - sql_identifier.SqlIdentifier("PREDICTION2"), - ], - id_columns=[sql_identifier.SqlIdentifier("ID")], - label_columns=[sql_identifier.SqlIdentifier("LABEL1"), sql_identifier.SqlIdentifier("LABEL2")], - ) - - def test_validate_prediction_column_types(self) -> None: - with self.assertRaisesRegex(ValueError, "Prediction column types must be the same. Found: .*"): - self.monitor_sql_client._validate_prediction_columns_types( - table_schema={ - "PREDICTION1": types.DoubleType(), - "PREDICTION2": types.StringType(), - }, - prediction_columns=[ - sql_identifier.SqlIdentifier("PREDICTION1"), - sql_identifier.SqlIdentifier("PREDICTION2"), - ], - ) - - def test_validate_label_column_types(self) -> None: - with self.assertRaisesRegex(ValueError, "Label column types must be the same. Found:"): - self.monitor_sql_client._validate_label_columns_types( - table_schema={ - "LABEL1": types.DoubleType(), - "LABEL2": types.StringType(), - }, - label_columns=[sql_identifier.SqlIdentifier("LABEL1"), sql_identifier.SqlIdentifier("LABEL2")], - ) - - def test_validate_timestamp_column_type(self) -> None: - with self.assertRaisesRegex(ValueError, "Timestamp column: TIMESTAMP must be TimestampType"): - self.monitor_sql_client._validate_timestamp_column_type( - table_schema={ - "TIMESTAMP": types.StringType(), - }, - timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - ) - - def test_validate_id_columns_types(self) -> None: - with self.assertRaisesRegex(ValueError, "Id columns must all be StringType"): - self.monitor_sql_client._validate_id_columns_types( - table_schema={ - "ID": types.DoubleType(), - }, - id_columns=[ - sql_identifier.SqlIdentifier("ID"), - ], - ) - - def test_validate_multiple_id_columns_types(self) -> None: - with self.assertRaisesRegex(ValueError, "Id columns must all be StringType. Found"): - self.monitor_sql_client._validate_id_columns_types( - table_schema={ - "ID1": types.StringType(), - "ID2": types.DecimalType(), - }, - id_columns=[ - sql_identifier.SqlIdentifier("ID1"), - sql_identifier.SqlIdentifier("ID2"), - ], - ) - - def test_validate_id_columns_types_all_string(self) -> None: - self.monitor_sql_client._validate_id_columns_types( - table_schema={ - "ID1": types.StringType(36), - "ID2": types.StringType(64), - "ID3": types.StringType(), - }, - id_columns=[ - sql_identifier.SqlIdentifier("ID1"), - sql_identifier.SqlIdentifier("ID2"), - sql_identifier.SqlIdentifier("ID3"), - ], - ) - def test_validate_existence_by_name(self) -> None: self.m_session.add_mock_sql( query=f"SHOW MODEL MONITORS LIKE '{self.test_monitor_name}' IN {self.test_db_name}.{self.test_schema_name}", @@ -314,29 +197,6 @@ def test_validate_existence_by_name(self) -> None: self.assertTrue(res) self.m_session.finalize() - def test_validate_unique_columns(self) -> None: - self.monitor_sql_client._validate_unique_columns( - id_columns=[sql_identifier.SqlIdentifier("ID")], - timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - prediction_columns=[sql_identifier.SqlIdentifier("PREDICTION")], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - ) - - def test_validate_unique_columns_column_used_twice(self) -> None: - with self.assertRaisesRegex( - ValueError, "Column names must be unique across id, timestamp, prediction, and label columns." - ): - self.monitor_sql_client._validate_unique_columns( - id_columns=[sql_identifier.SqlIdentifier("ID")], - timestamp_column=sql_identifier.SqlIdentifier("TIMESTAMP"), - prediction_columns=[ - sql_identifier.SqlIdentifier("PREDICTION"), - # This is a duplicate with the id column - sql_identifier.SqlIdentifier("ID"), - ], - label_columns=[sql_identifier.SqlIdentifier("LABEL")], - ) - def test_suspend_monitor(self) -> None: self.m_session.add_mock_sql( f"""ALTER MODEL MONITOR {self.test_db_name}.{self.test_schema_name}.{self.test_monitor_name} SUSPEND""", @@ -353,7 +213,6 @@ def test_resume_monitor(self) -> None: self.monitor_sql_client.resume_monitor(self.test_monitor_name) self.m_session.finalize() - # TODO: Move to new test class def test_drop_model_monitor(self) -> None: self.m_session.add_mock_sql( f"""DROP MODEL MONITOR {self.test_db_name}.{self.test_schema_name}.{self.test_monitor_name}""", diff --git a/snowflake/ml/monitoring/_manager/model_monitor_manager.py b/snowflake/ml/monitoring/_manager/model_monitor_manager.py index 1a8d8e4a..2abfefa9 100644 --- a/snowflake/ml/monitoring/_manager/model_monitor_manager.py +++ b/snowflake/ml/monitoring/_manager/model_monitor_manager.py @@ -14,15 +14,6 @@ class ModelMonitorManager: """Class to manage internal operations for Model Monitor workflows.""" - def _validate_task_from_model_version( - self, - model_version: model_version_impl.ModelVersion, - ) -> type_hints.Task: - task = model_version.get_model_task() - if task == type_hints.Task.UNKNOWN: - raise ValueError("Registry model must be logged with task in order to be monitored.") - return task - def __init__( self, session: session.Session, @@ -51,6 +42,15 @@ def __init__( schema_name=self._schema_name, ) + def _validate_task_from_model_version( + self, + model_version: model_version_impl.ModelVersion, + ) -> type_hints.Task: + task = model_version.get_model_task() + if task == type_hints.Task.UNKNOWN: + raise ValueError("Registry model must be logged with task in order to be monitored.") + return task + def _validate_model_function_from_model_version( self, function: str, model_version: model_version_impl.ModelVersion ) -> None: diff --git a/snowflake/ml/monitoring/entities/BUILD.bazel b/snowflake/ml/monitoring/entities/BUILD.bazel index 99ee6c9d..ca08431d 100644 --- a/snowflake/ml/monitoring/entities/BUILD.bazel +++ b/snowflake/ml/monitoring/entities/BUILD.bazel @@ -1,4 +1,4 @@ -load("//bazel:py_rules.bzl", "py_library", "py_test") +load("//bazel:py_rules.bzl", "py_library") package(default_visibility = ["//visibility:public"]) @@ -6,20 +6,9 @@ py_library( name = "entities_lib", srcs = [ "model_monitor_config.py", - "output_score_type.py", ], deps = [ "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model:type_hints", ], ) - -py_test( - name = "output_score_type_test", - srcs = [ - "output_score_type_test.py", - ], - deps = [ - ":entities_lib", - ], -) diff --git a/snowflake/ml/monitoring/entities/model_monitor_config.py b/snowflake/ml/monitoring/entities/model_monitor_config.py index d030fe16..9475517e 100644 --- a/snowflake/ml/monitoring/entities/model_monitor_config.py +++ b/snowflake/ml/monitoring/entities/model_monitor_config.py @@ -6,23 +6,49 @@ @dataclass class ModelMonitorSourceConfig: + """Configuration for the source of data to be monitored.""" + source: str + """Name of table or view containing monitoring data.""" + timestamp_column: str + """Name of column in the source containing timestamp.""" + id_columns: List[str] + """List of columns in the source containing unique identifiers.""" + prediction_score_columns: Optional[List[str]] = None + """List of columns in the source containing prediction scores. + Can be regression scores for regression models and probability scores for classification models.""" + prediction_class_columns: Optional[List[str]] = None + """List of columns in the source containing prediction classes for classification models.""" + actual_score_columns: Optional[List[str]] = None + """List of columns in the source containing actual scores.""" + actual_class_columns: Optional[List[str]] = None + """List of columns in the source containing actual classes for classification models.""" + baseline: Optional[str] = None + """Name of table containing the baseline data.""" @dataclass class ModelMonitorConfig: + """Configuration for the Model Monitor.""" + model_version: model_version_impl.ModelVersion + """Model version to monitor.""" - # Python model function name model_function_name: str + """Function name in the model to monitor.""" + background_compute_warehouse_name: str - # TODO: Add support for pythonic notion of time. + """Name of the warehouse to use for background compute.""" + refresh_interval: str = "1 hour" + """Interval at which to refresh the monitoring data.""" + aggregation_window: str = "1 day" + """Window for aggregating monitoring data.""" diff --git a/snowflake/ml/monitoring/entities/output_score_type.py b/snowflake/ml/monitoring/entities/output_score_type.py deleted file mode 100644 index a34eca24..00000000 --- a/snowflake/ml/monitoring/entities/output_score_type.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import annotations - -from enum import Enum -from typing import List, Mapping - -from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model import type_hints -from snowflake.snowpark import types - -# Accepted data types for each OutputScoreType. -REGRESSION_DATA_TYPES = ( - types.ByteType, - types.ShortType, - types.IntegerType, - types.LongType, - types.FloatType, - types.DoubleType, - types.DecimalType, -) -CLASSIFICATION_DATA_TYPES = ( - types.ByteType, - types.ShortType, - types.IntegerType, - types.BooleanType, - types.BinaryType, -) -PROBITS_DATA_TYPES = ( - types.ByteType, - types.ShortType, - types.IntegerType, - types.LongType, - types.FloatType, - types.DoubleType, - types.DecimalType, -) - - -# OutputScoreType enum -class OutputScoreType(Enum): - UNKNOWN = "UNKNOWN" - REGRESSION = "REGRESSION" - CLASSIFICATION = "CLASSIFICATION" - PROBITS = "PROBITS" - - @classmethod - def deduce_score_type( - cls, - table_schema: Mapping[str, types.DataType], - prediction_columns: List[sql_identifier.SqlIdentifier], - task: type_hints.Task, - ) -> OutputScoreType: - """Find the score type for monitoring given a table schema and the task. - - Args: - table_schema: Dictionary of column names and types in the source table. - prediction_columns: List of prediction columns. - task: Enum value for the task of the model. - - Returns: - Enum value for the score type, informing monitoring table set up. - - Raises: - ValueError: If prediction type fails to align with task. - """ - # Already validated we have just one prediction column type - prediction_column_type = {table_schema[column_name] for column_name in prediction_columns}.pop() - - if task == type_hints.Task.TABULAR_REGRESSION: - if isinstance(prediction_column_type, REGRESSION_DATA_TYPES): - return OutputScoreType.REGRESSION - else: - raise ValueError( - f"Expected prediction column type to be one of {REGRESSION_DATA_TYPES} " - f"for REGRESSION task. Found: {prediction_column_type}." - ) - - elif task == type_hints.Task.TABULAR_BINARY_CLASSIFICATION: - if isinstance(prediction_column_type, CLASSIFICATION_DATA_TYPES): - return OutputScoreType.CLASSIFICATION - elif isinstance(prediction_column_type, PROBITS_DATA_TYPES): - return OutputScoreType.PROBITS - else: - raise ValueError( - f"Expected prediction column type to be one of {CLASSIFICATION_DATA_TYPES} " - f"or one of {PROBITS_DATA_TYPES} for CLASSIFICATION task. " - f"Found: {prediction_column_type}." - ) - - else: - raise ValueError(f"Received unsupported task for model monitoring: {task}.") diff --git a/snowflake/ml/monitoring/entities/output_score_type_test.py b/snowflake/ml/monitoring/entities/output_score_type_test.py deleted file mode 100644 index 13f1c54e..00000000 --- a/snowflake/ml/monitoring/entities/output_score_type_test.py +++ /dev/null @@ -1,93 +0,0 @@ -import re -from typing import List, Mapping, Tuple - -from absl.testing import absltest - -from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model import type_hints -from snowflake.ml.monitoring.entities import output_score_type -from snowflake.snowpark import types - -DEDUCE_SCORE_TYPE_ACCEPTED_COMBINATIONS: List[ - Tuple[ - Mapping[str, types.DataType], - List[sql_identifier.SqlIdentifier], - type_hints.Task, - output_score_type.OutputScoreType, - ] -] = [ - ( - {"PREDICTION1": types.FloatType()}, - [sql_identifier.SqlIdentifier("PREDICTION1")], - type_hints.Task.TABULAR_REGRESSION, - output_score_type.OutputScoreType.REGRESSION, - ), - ( - {"PREDICTION1": types.DecimalType(38, 1)}, - [sql_identifier.SqlIdentifier("PREDICTION1")], - type_hints.Task.TABULAR_BINARY_CLASSIFICATION, - output_score_type.OutputScoreType.PROBITS, - ), - ( - {"PREDICTION1": types.BinaryType()}, - [sql_identifier.SqlIdentifier("PREDICTION1")], - type_hints.Task.TABULAR_BINARY_CLASSIFICATION, - output_score_type.OutputScoreType.CLASSIFICATION, - ), -] - - -DEDUCE_SCORE_TYPE_FAILURE_COMBINATIONS: List[ - Tuple[Mapping[str, types.DataType], List[sql_identifier.SqlIdentifier], type_hints.Task, str] -] = [ - ( - {"PREDICTION1": types.BinaryType()}, - [sql_identifier.SqlIdentifier("PREDICTION1")], - type_hints.Task.TABULAR_REGRESSION, - f"Expected prediction column type to be one of {output_score_type.REGRESSION_DATA_TYPES} " - f"for REGRESSION task. Found: {types.BinaryType()}.", - ), - ( - {"PREDICTION1": types.StringType()}, - [sql_identifier.SqlIdentifier("PREDICTION1")], - type_hints.Task.TABULAR_BINARY_CLASSIFICATION, - f"Expected prediction column type to be one of {output_score_type.CLASSIFICATION_DATA_TYPES} " - f"or one of {output_score_type.PROBITS_DATA_TYPES} for CLASSIFICATION task. " - f"Found: {types.StringType()}.", - ), - ( - {"PREDICTION1": types.BinaryType()}, - [sql_identifier.SqlIdentifier("PREDICTION1")], - type_hints.Task.UNKNOWN, - f"Received unsupported task for model monitoring: {type_hints.Task.UNKNOWN}.", - ), -] - - -class OutputScoreTypeTest(absltest.TestCase): - def test_deduce_score_type(self) -> None: - # Success cases - for ( - table_schema, - prediction_column_names, - task, - expected_score_type, - ) in DEDUCE_SCORE_TYPE_ACCEPTED_COMBINATIONS: - actual_score_type = output_score_type.OutputScoreType.deduce_score_type( - table_schema, prediction_column_names, task - ) - self.assertEqual(actual_score_type, expected_score_type) - - # Failure cases - for ( - table_schema, - prediction_column_names, - task, - expected_error, - ) in DEDUCE_SCORE_TYPE_FAILURE_COMBINATIONS: - with self.assertRaisesRegex(ValueError, re.escape(expected_error)): - output_score_type.OutputScoreType.deduce_score_type(table_schema, prediction_column_names, task) - - -if __name__ == "__main__": - absltest.main() diff --git a/snowflake/ml/monitoring/model_monitor.py b/snowflake/ml/monitoring/model_monitor.py index 869280eb..4c8b4860 100644 --- a/snowflake/ml/monitoring/model_monitor.py +++ b/snowflake/ml/monitoring/model_monitor.py @@ -1,5 +1,7 @@ +from snowflake import snowpark from snowflake.ml._internal import telemetry from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.monitoring import model_monitor_version from snowflake.ml.monitoring._client import model_monitor_sql_client @@ -9,13 +11,8 @@ class ModelMonitor: name: sql_identifier.SqlIdentifier _model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient - statement_params = telemetry.get_statement_params( - telemetry.TelemetryProject.MLOPS.value, - telemetry.TelemetrySubProject.MONITORING.value, - ) - def __init__(self) -> None: - raise RuntimeError("ModelMonitor's initializer is not meant to be used.") + raise RuntimeError("Model Monitor's initializer is not meant to be used.") @classmethod def _ref( @@ -28,10 +25,28 @@ def _ref( self._model_monitor_client = model_monitor_client return self + @telemetry.send_api_usage_telemetry( + project=telemetry.TelemetryProject.MLOPS.value, + subproject=telemetry.TelemetrySubProject.MONITORING.value, + ) + @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION) def suspend(self) -> None: - """Suspend pipeline for ModelMonitor""" - self._model_monitor_client.suspend_monitor(self.name, statement_params=self.statement_params) - + """Suspend the Model Monitor""" + statement_params = telemetry.get_statement_params( + telemetry.TelemetryProject.MLOPS.value, + telemetry.TelemetrySubProject.MONITORING.value, + ) + self._model_monitor_client.suspend_monitor(self.name, statement_params=statement_params) + + @telemetry.send_api_usage_telemetry( + project=telemetry.TelemetryProject.MLOPS.value, + subproject=telemetry.TelemetrySubProject.MONITORING.value, + ) + @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION) def resume(self) -> None: - """Resume pipeline for ModelMonitor""" - self._model_monitor_client.resume_monitor(self.name, statement_params=self.statement_params) + """Resume the Model Monitor""" + statement_params = telemetry.get_statement_params( + telemetry.TelemetryProject.MLOPS.value, + telemetry.TelemetrySubProject.MONITORING.value, + ) + self._model_monitor_client.resume_monitor(self.name, statement_params=statement_params) diff --git a/snowflake/ml/registry/registry.py b/snowflake/ml/registry/registry.py index f920dc9d..89cbad13 100644 --- a/snowflake/ml/registry/registry.py +++ b/snowflake/ml/registry/registry.py @@ -388,15 +388,15 @@ def add_monitor( source_config: model_monitor_config.ModelMonitorSourceConfig, model_monitor_config: model_monitor_config.ModelMonitorConfig, ) -> model_monitor.ModelMonitor: - """Add a Model Monitor to the Registry + """Add a Model Monitor to the Registry. Args: - name: Name of Model Monitor to create - source_config: Configuration options of table for ModelMonitor. - model_monitor_config: Configuration options of ModelMonitor. + name: Name of Model Monitor to create. + source_config: Configuration options of table for Model Monitor. + model_monitor_config: Configuration options of Model Monitor. Returns: - The newly added ModelMonitor object. + The newly added Model Monitor object. Raises: ValueError: If monitoring is not enabled in the Registry. @@ -407,16 +407,16 @@ def add_monitor( @overload def get_monitor(self, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor: - """Get a Model Monitor on a ModelVersion from the Registry + """Get a Model Monitor on a Model Version from the Registry. Args: - model_version: ModelVersion for which to retrieve the ModelMonitor. + model_version: Model Version for which to retrieve the Model Monitor. """ ... @overload def get_monitor(self, name: str) -> model_monitor.ModelMonitor: - """Get a Model Monitor from the Registry + """Get a Model Monitor by name from the Registry. Args: name: Name of Model Monitor to retrieve. @@ -431,14 +431,14 @@ def get_monitor(self, name: str) -> model_monitor.ModelMonitor: def get_monitor( self, *, name: Optional[str] = None, model_version: Optional[model_version_impl.ModelVersion] = None ) -> model_monitor.ModelMonitor: - """Get a Model Monitor from the Registry + """Get a Model Monitor from the Registry. Args: name: Name of Model Monitor to retrieve. - model_version: ModelVersion for which to retrieve the ModelMonitor. + model_version: Model Version for which to retrieve the Model Monitor. Returns: - The fetched ModelMonitor. + The fetched Model Monitor. Raises: ValueError: If monitoring is not enabled in the Registry. @@ -476,7 +476,7 @@ def show_model_monitors(self) -> List[snowpark.Row]: ) @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION) def delete_monitor(self, name: str) -> None: - """Delete a Model Monitor from the Registry + """Delete a Model Monitor by name from the Registry. Args: name: Name of the Model Monitor to delete. diff --git a/snowflake/ml/utils/BUILD.bazel b/snowflake/ml/utils/BUILD.bazel index 8f7a0452..95632b41 100644 --- a/snowflake/ml/utils/BUILD.bazel +++ b/snowflake/ml/utils/BUILD.bazel @@ -2,6 +2,22 @@ load("//bazel:py_rules.bzl", "py_library", "py_package", "py_test") package(default_visibility = ["//visibility:public"]) +py_library( + name = "authentication", + srcs = ["authentication.py"], + deps = [ + "//snowflake/ml/_internal/utils:jwt_generator", + ], +) + +py_test( + name = "authentication_test", + srcs = ["authentication_test.py"], + deps = [ + ":authentication", + ], +) + py_library( name = "connection_params", srcs = ["connection_params.py"], @@ -46,8 +62,10 @@ py_package( name = "utils_pkg", packages = ["snowflake.ml"], deps = [ + ":authentication", ":connection_params", ":sparse", + ":sql_client", "//snowflake/ml/_internal/utils:snowflake_env", # Mitigate BuildSnowML failure ], ) diff --git a/snowflake/ml/utils/authentication.py b/snowflake/ml/utils/authentication.py new file mode 100644 index 00000000..b560900e --- /dev/null +++ b/snowflake/ml/utils/authentication.py @@ -0,0 +1,75 @@ +import http +import logging +from datetime import timedelta +from typing import Dict, Optional + +import requests +from cryptography.hazmat.primitives.asymmetric import types +from requests import auth + +from snowflake.ml._internal.utils import jwt_generator + +logger = logging.getLogger(__name__) +_JWT_TOKEN_CACHE: Dict[str, Dict[int, str]] = {} + + +def get_jwt_token_generator( + account: str, + user: str, + private_key: types.PRIVATE_KEY_TYPES, + lifetime: Optional[timedelta] = None, + renewal_delay: Optional[timedelta] = None, +) -> jwt_generator.JWTGenerator: + return jwt_generator.JWTGenerator(account, user, private_key, lifetime=lifetime, renewal_delay=renewal_delay) + + +def _get_snowflake_token_by_jwt( + jwt_token_generator: jwt_generator.JWTGenerator, + account: Optional[str] = None, + role: Optional[str] = None, + endpoint: Optional[str] = None, + snowflake_account_url: Optional[str] = None, +) -> str: + scope_role = f"session:role:{role}" if role is not None else None + scope = " ".join(filter(None, [scope_role, endpoint])) + data = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": scope or None, + "assertion": jwt_token_generator.get_token(), + } + account = account or jwt_token_generator.account + url = f"https://{account}.snowflakecomputing.com/oauth/token" + if snowflake_account_url: + url = f"{snowflake_account_url}/oauth/token" + + cache_key = hash(frozenset(data.items())) + if url in _JWT_TOKEN_CACHE: + if cache_key in _JWT_TOKEN_CACHE[url]: + return _JWT_TOKEN_CACHE[url][cache_key] + else: + _JWT_TOKEN_CACHE[url] = {} + + response = requests.post(url, data=data) + if response.status_code != http.HTTPStatus.OK: + raise RuntimeError(f"Failed to get snowflake token: {response.status_code} {response.content!r}") + auth_token = response.text + _JWT_TOKEN_CACHE[url][cache_key] = auth_token + return auth_token + + +class SnowflakeJWTTokenAuth(auth.AuthBase): + def __init__( + self, + jwt_token_generator: jwt_generator.JWTGenerator, + account: Optional[str] = None, + role: Optional[str] = None, + endpoint: Optional[str] = None, + snowflake_account_url: Optional[str] = None, + ) -> None: + self.snowflake_token = _get_snowflake_token_by_jwt( + jwt_token_generator, account, role, endpoint, snowflake_account_url + ) + + def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + r.headers["Authorization"] = f'Snowflake Token="{self.snowflake_token}"' + return r diff --git a/snowflake/ml/utils/authentication_test.py b/snowflake/ml/utils/authentication_test.py new file mode 100644 index 00000000..dca9040b --- /dev/null +++ b/snowflake/ml/utils/authentication_test.py @@ -0,0 +1,199 @@ +from unittest import mock + +import requests +from absl.testing import absltest + +from snowflake.ml.utils import authentication + + +class AuthenticationTest(absltest.TestCase): + def setUp(self) -> None: + self.m_jwt_generator = mock.MagicMock() + self.m_jwt_generator.get_token.return_value = "jwt_token" + authentication._JWT_TOKEN_CACHE = {} + + def test_get_jwt_token_default_account(self) -> None: + self.m_jwt_generator.account = "account" + m_response = mock.MagicMock() + m_response.status_code = 200 + m_response.text = "auth_token" + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt(self.m_jwt_generator) + self.assertEqual(token, "auth_token") + m_post.assert_called_once_with( + "https://account.snowflakecomputing.com/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": None, + "assertion": "jwt_token", + }, + ) + + def test_get_jwt_token_error(self) -> None: + self.m_jwt_generator.account = "account" + m_response = mock.MagicMock() + m_response.status_code = 404 + m_response.text = "auth_token" + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + with self.assertRaisesRegex(RuntimeError, "Failed to get snowflake token"): + authentication._get_snowflake_token_by_jwt(self.m_jwt_generator) + m_post.assert_called_once_with( + "https://account.snowflakecomputing.com/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": None, + "assertion": "jwt_token", + }, + ) + + def test_get_jwt_token_overridden_account(self) -> None: + m_response = mock.MagicMock() + m_response.status_code = 200 + m_response.text = "auth_token" + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt(self.m_jwt_generator, account="account") + self.assertEqual(token, "auth_token") + m_post.assert_called_once_with( + "https://account.snowflakecomputing.com/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": None, + "assertion": "jwt_token", + }, + ) + + def test_get_jwt_token_role(self) -> None: + self.m_jwt_generator.account = "account" + m_response = mock.MagicMock() + m_response.status_code = 200 + m_response.text = "auth_token" + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt(self.m_jwt_generator, role="role") + self.assertEqual(token, "auth_token") + m_post.assert_called_once_with( + "https://account.snowflakecomputing.com/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": "session:role:role", + "assertion": "jwt_token", + }, + ) + + def test_get_jwt_token_endpoint(self) -> None: + self.m_jwt_generator.account = "account" + m_response = mock.MagicMock() + m_response.status_code = 200 + m_response.text = "auth_token" + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt(self.m_jwt_generator, endpoint="endpoint") + self.assertEqual(token, "auth_token") + m_post.assert_called_once_with( + "https://account.snowflakecomputing.com/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": "endpoint", + "assertion": "jwt_token", + }, + ) + + def test_get_jwt_token_role_and_endpoint(self) -> None: + self.m_jwt_generator.account = "account" + m_response = mock.MagicMock() + m_response.status_code = 200 + m_response.text = "auth_token" + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt(self.m_jwt_generator, role="role", endpoint="endpoint") + self.assertEqual(token, "auth_token") + m_post.assert_called_once_with( + "https://account.snowflakecomputing.com/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": "session:role:role endpoint", + "assertion": "jwt_token", + }, + ) + + def test_get_jwt_token_account_url(self) -> None: + self.m_jwt_generator.account = "account" + m_response = mock.MagicMock() + m_response.status_code = 200 + m_response.text = "auth_token" + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt( + self.m_jwt_generator, snowflake_account_url="https://account.url" + ) + self.assertEqual(token, "auth_token") + m_post.assert_called_once_with( + "https://account.url/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": None, + "assertion": "jwt_token", + }, + ) + + def test_get_jwt_token_cache_hit(self) -> None: + self.m_jwt_generator.account = "account" + m_response = mock.MagicMock() + m_response.status_code = 200 + m_response.text = "auth_token" + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt(self.m_jwt_generator) + self.assertEqual(token, "auth_token") + m_post.assert_called_once_with( + "https://account.snowflakecomputing.com/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": None, + "assertion": "jwt_token", + }, + ) + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt(self.m_jwt_generator) + self.assertEqual(token, "auth_token") + m_post.assert_not_called() + + def test_get_jwt_token_cache_miss(self) -> None: + self.m_jwt_generator.account = "account" + m_response = mock.MagicMock() + m_response.status_code = 200 + m_response.text = "auth_token" + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt(self.m_jwt_generator) + self.assertEqual(token, "auth_token") + m_post.assert_called_once_with( + "https://account.snowflakecomputing.com/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": None, + "assertion": "jwt_token", + }, + ) + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt(self.m_jwt_generator, role="role") + self.assertEqual(token, "auth_token") + m_post.assert_called_once_with( + "https://account.snowflakecomputing.com/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": "session:role:role", + "assertion": "jwt_token", + }, + ) + with mock.patch.object(requests, "post", return_value=m_response) as m_post: + token = authentication._get_snowflake_token_by_jwt( + self.m_jwt_generator, snowflake_account_url="https://account.url" + ) + self.assertEqual(token, "auth_token") + m_post.assert_called_once_with( + "https://account.url/oauth/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "scope": None, + "assertion": "jwt_token", + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/version.bzl b/snowflake/ml/version.bzl index 20d98634..b604d743 100644 --- a/snowflake/ml/version.bzl +++ b/snowflake/ml/version.bzl @@ -1,2 +1,2 @@ # This is parsed by regex in conda reciper meta file. Make sure not to break it. -VERSION = "1.7.1" +VERSION = "1.7.2" diff --git a/tests/integ/snowflake/cortex/complete_test.py b/tests/integ/snowflake/cortex/complete_test.py index e808cddc..b0be849c 100644 --- a/tests/integ/snowflake/cortex/complete_test.py +++ b/tests/integ/snowflake/cortex/complete_test.py @@ -91,6 +91,10 @@ def test_immediate_mode_empty_options(self) -> None: self.assertTrue(res) +@absltest.skipUnless( + test_env_utils.get_current_snowflake_cloud_type() == snowflake_env.SnowflakeCloudType.AWS, + "Complete SQL only available in AWS", +) class CompleteRestTest(absltest.TestCase): def setUp(self) -> None: self._session = Session.builder.configs(connection_params.SnowflakeLoginOptions()).create() diff --git a/tests/integ/snowflake/ml/registry/services/BUILD.bazel b/tests/integ/snowflake/ml/registry/services/BUILD.bazel index 663dbde4..34465e20 100644 --- a/tests/integ/snowflake/ml/registry/services/BUILD.bazel +++ b/tests/integ/snowflake/ml/registry/services/BUILD.bazel @@ -17,6 +17,7 @@ py_library( "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_client/model:model_version_impl", "//snowflake/ml/registry", + "//snowflake/ml/utils:authentication", "//snowflake/ml/utils:connection_params", "//tests/integ/snowflake/ml/test_utils:common_test_base", "//tests/integ/snowflake/ml/test_utils:db_manager", @@ -98,7 +99,6 @@ py_test( name = "registry_model_deployment_test", timeout = "eternal", srcs = ["registry_model_deployment_test.py"], - shard_count = 2, deps = [ ":registry_model_deployment_test_base", ], diff --git a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py index af1cd966..4a832a6d 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py +++ b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py @@ -1,3 +1,4 @@ +import http import inspect import logging import os @@ -6,17 +7,29 @@ import uuid from typing import Any, Callable, Dict, List, Optional, Tuple, cast +import numpy as np +import pandas as pd import pytest +import requests +import retrying import yaml from absl.testing import absltest +from cryptography.hazmat import backends +from cryptography.hazmat.primitives import serialization from packaging import version from snowflake.ml._internal import file_utils -from snowflake.ml._internal.utils import snowflake_env, sql_identifier -from snowflake.ml.model import ModelVersion, type_hints as model_types +from snowflake.ml._internal.utils import ( + identifier, + jwt_generator, + snowflake_env, + sql_identifier, +) +from snowflake.ml.model import ModelVersion, model_signature, type_hints as model_types from snowflake.ml.model._client.ops import service_ops from snowflake.ml.model._client.service import model_deployment_spec from snowflake.ml.registry import registry +from snowflake.ml.utils import authentication from snowflake.snowpark import row from snowflake.snowpark._internal import utils as snowpark_utils from tests.integ.snowflake.ml.test_utils import ( @@ -34,7 +47,6 @@ class RegistryModelDeploymentTestBase(common_test_base.CommonTestBase): _TEST_CPU_COMPUTE_POOL = "REGTEST_INFERENCE_CPU_POOL" _TEST_GPU_COMPUTE_POOL = "REGTEST_INFERENCE_GPU_POOL" - _SPCS_EAI = "SPCS_EGRESS_ACCESS_INTEGRATION" _TEST_SPCS_WH = "REGTEST_ML_SMALL" BUILDER_IMAGE_PATH = os.getenv("BUILDER_IMAGE_PATH", None) @@ -45,6 +57,11 @@ def setUp(self) -> None: """Creates Snowpark and Snowflake environments for testing.""" super().setUp() + with open(self.session._conn._lower_case_parameters["private_key_path"], "rb") as f: + self.private_key = serialization.load_pem_private_key( + f.read(), password=None, backend=backends.default_backend() + ) + self._run_id = uuid.uuid4().hex[:2] self._test_db = db_manager.TestObjectNameGenerator.get_snowml_test_object_name(self._run_id, "db").upper() self._test_schema = "PUBLIC" @@ -104,11 +121,11 @@ def _deploy_model_with_image_override( service_schema_name=schema_name, service_name=sql_identifier.SqlIdentifier(service_name), image_build_compute_pool_name=build_compute_pool, - service_compute_pool_name=service_compute_pool, + service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool), image_repo_database_name=database_name, image_repo_schema_name=schema_name, image_repo_name=image_repo_name, - ingress_enabled=False, + ingress_enabled=True, max_instances=max_instances, num_workers=num_workers, max_batch_rows=max_batch_rows, @@ -116,7 +133,6 @@ def _deploy_model_with_image_override( memory=None, gpu=gpu_requests, force_rebuild=force_rebuild, - external_access_integrations=[sql_identifier.SqlIdentifier(self._SPCS_EAI)], ) with (mv._service_ops.workspace_path / deploy_spec_file_rel_path).open("r", encoding="utf-8") as f: @@ -166,7 +182,6 @@ def _deploy_model_with_image_override( res = cast(str, cast(List[row.Row], async_job.result())[0][0]) logging.info(f"Inference service {service_name} deployment complete: {res}") - return res def _test_registry_model_deployment( self, @@ -241,11 +256,77 @@ def _test_registry_model_deployment( num_workers=num_workers, max_instances=max_instances, max_batch_rows=max_batch_rows, - build_external_access_integrations=[self._SPCS_EAI], + ingress_enabled=True, ) for target_method, (test_input, check_func) in prediction_assert_fns.items(): res = mv.run(test_input, function_name=target_method, service_name=service_name) check_func(res) + endpoint = RegistryModelDeploymentTestBase._ensure_ingress_url(mv) + jwt_token_generator = self._get_jwt_token_generator() + + for target_method, (test_input, check_func) in prediction_assert_fns.items(): + res_df = self._inference_using_rest_api( + test_input, endpoint=endpoint, jwt_token_generator=jwt_token_generator, target_method=target_method + ) + check_func(res_df) + return mv + + @staticmethod + def retry_if_result_status_retriable(result: requests.Response) -> bool: + if result.status_code in [ + http.HTTPStatus.SERVICE_UNAVAILABLE, + http.HTTPStatus.TOO_MANY_REQUESTS, + http.HTTPStatus.GATEWAY_TIMEOUT, + ]: + return True + return False + + @staticmethod + def _ensure_ingress_url(mv: ModelVersion) -> str: + while True: + endpoint = mv.list_services().loc[0, "inference_endpoint"] + if endpoint is not None: + break + time.sleep(10) + return endpoint + + def _get_jwt_token_generator(self) -> jwt_generator.JWTGenerator: + account = identifier.get_unescaped_names(self.session.get_current_account()) + user = identifier.get_unescaped_names(self.session.get_current_user()) + if not account or not user: + raise ValueError("Account and user must be set.") + + return authentication.get_jwt_token_generator( + account, + user, + self.private_key, + ) + + def _inference_using_rest_api( + self, + test_input: pd.DataFrame, + *, + endpoint: str, + jwt_token_generator: jwt_generator.JWTGenerator, + target_method: str, + ) -> pd.DataFrame: + test_input_arr = model_signature._convert_local_data_to_df(test_input).values + test_input_arr = np.column_stack([range(test_input_arr.shape[0]), test_input_arr]) + res = retrying.retry( + wait_exponential_multiplier=100, + wait_exponential_max=4000, + retry_on_result=RegistryModelDeploymentTestBase.retry_if_result_status_retriable, + )(requests.post)( + f"https://{endpoint}/{target_method.replace('_', '-')}", + json={"data": test_input_arr.tolist()}, + auth=authentication.SnowflakeJWTTokenAuth( + jwt_token_generator=jwt_token_generator, + role=identifier.get_unescaped_names(self.session.get_current_role()), + endpoint=endpoint, + ), + ) + res.raise_for_status() + return pd.DataFrame([x[1] for x in res.json()["data"]]) diff --git a/tests/integ/snowflake/ml/test_utils/spcs_integ_test_base.py b/tests/integ/snowflake/ml/test_utils/spcs_integ_test_base.py index e4b3fc94..256aa67a 100644 --- a/tests/integ/snowflake/ml/test_utils/spcs_integ_test_base.py +++ b/tests/integ/snowflake/ml/test_utils/spcs_integ_test_base.py @@ -16,7 +16,6 @@ class SpcsIntegTestBase(absltest.TestCase): _TEST_CPU_COMPUTE_POOL = "REGTEST_INFERENCE_CPU_POOL" _TEST_GPU_COMPUTE_POOL = "REGTEST_INFERENCE_GPU_POOL" - _SPCS_EAIS = ["SPCS_EGRESS_ACCESS_INTEGRATION"] def setUp(self) -> None: """Creates Snowpark and Snowflake environments for testing."""