diff --git a/clients/python/README.md b/clients/python/README.md index e542e8e2f..6af08cf10 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -14,12 +14,12 @@ registry = ModelRegistry(server_address="server-address", port=9090, author="aut model = registry.register_model( "my-model", # model name - "s3://path/to/model", # model URI + "https://storage-place.my-company.com", # model URI version="2.0.0", description="lorem ipsum", model_format_name="onnx", model_format_version="1", - storage_key="aws-connection-path", + storage_key="my-data-connection", storage_path="path/to/model", metadata={ # can be one of the following types @@ -37,10 +37,33 @@ version = registry.get_model_version("my-model", "v2.0") experiment = registry.get_model_artifact("my-model", "v2.0") ``` -### Default values for metadata +### Importing from S3 -If not supplied, `metadata` values defaults to a predefined set of conventional values. -Reference the technical documentation in the pydoc of the client. +When registering models stored on S3-compatible object storage, you should use `utils.s3_uri_from` to build an +unambiguous URI for your artifact. 
+ +```py +from model_registry import ModelRegistry, utils + +registry = ModelRegistry(server_address="server-address", port=9090, author="author") + +model = registry.register_model( + "my-model", # model name + uri=utils.s3_uri_from("path/to/model", "my-bucket"), + version="2.0.0", + description="lorem ipsum", + model_format_name="onnx", + model_format_version="1", + storage_key="my-data-connection", + metadata={ + # can be one of the following types + "int_key": 1, + "bool_key": False, + "float_key": 3.14, + "str_key": "str_value", + } +) +``` ### Importing from Hugging Face Hub diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index 826f061f6..5a3074e95 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -1,7 +1,7 @@ """Standard client for the model registry.""" + from __future__ import annotations -import os from typing import get_args from warnings import warn @@ -75,16 +75,22 @@ def register_model( model_format_name: str, model_format_version: str, version: str, - author: str | None = None, - description: str | None = None, storage_key: str | None = None, storage_path: str | None = None, service_account_name: str | None = None, + author: str | None = None, + description: str | None = None, metadata: dict[str, ScalarType] | None = None, ) -> RegisteredModel: """Register a model. - Either `storage_key` and `storage_path`, or `service_account_name` must be provided. + This registers a model in the model registry. The model is not downloaded, and has to be stored prior to + registration. + + Most models can be registered using their URI, along with optional connection-specific parameters, `storage_key` + and `storage_path` or, simply a `service_account_name`. + URI builder utilities are recommended when referring to specialized storage; for example `utils.s3_uri_from` + helper when using S3 object storage data connections. 
import functools
import inspect
import os
from collections.abc import Sequence
from typing import Any, Callable, TypeVar, overload

CallableT = TypeVar("CallableT", bound=Callable[..., Any])


class MissingMetadata(Exception):
    """Not enough metadata to complete operation."""


# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
    """Join strings into a human-readable enumeration.

    Args:
        seq: Items to join.
        delim: Separator placed between every item except the last two.
        final: Word placed before the last item.

    Returns:
        An empty string for no items, the item itself for a single item,
        ``"a or b"`` for two items and ``"a, b or c"`` for more.
    """
    size = len(seq)
    if size == 0:
        return ""

    if size == 1:
        return seq[0]

    if size == 2:
        return f"{seq[0]} {final} {seq[1]}"

    return delim.join(seq[:-1]) + f" {final} {seq[-1]}"


def quote(string: str) -> str:
    """Add single quotation marks around the given string. Does *not* do any escaping."""
    return f"'{string}'"


# copied from https://github.com/openai/openai-python
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
    """Decorator enforcing that at least one variant of arguments is passed to the decorated function.

    Useful for enforcing runtime validation of overloaded functions.

    Example usage:
    ```py
    @overload
    def foo(*, a: str) -> str:
        ...


    @overload
    def foo(*, b: bool) -> str:
        ...


    # This enforces the same constraints that a static type checker would
    # i.e. that either a or b must be passed to the function
    @required_args(["a"], ["b"])
    def foo(*, a: str | None = None, b: bool | None = None) -> str:
        ...
    ```

    Args:
        variants: Alternative groups of argument names; a call is accepted as soon as
            every name of one group was passed (positionally or by keyword). An empty
            group is always satisfied. With no groups at all nothing is enforced.

    Returns:
        A decorator that wraps the function with the check.

    Raises:
        TypeError: Raised by the wrapped function at call time when no variant is
            satisfied, or when more positional arguments are given than the function
            declares.
    """

    def inner(func: CallableT) -> CallableT:
        if not variants:
            # Nothing is required, so there is nothing to validate at call time
            # (indexing variants[0] below would otherwise fail on every call).
            return func

        # Names that positional arguments bind to, in declaration order.
        positional = [
            name
            for name, param in inspect.signature(func).parameters.items()
            if param.kind
            in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            )
        ]

        @functools.wraps(func)
        def wrapper(*args: object, **kwargs: object) -> object:
            if len(args) > len(positional):
                msg = f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
                raise TypeError(msg)

            given_params = set(positional[: len(args)]) | set(kwargs)

            if any(all(param in given_params for param in variant) for variant in variants):
                return func(*args, **kwargs)

            if len(variants) > 1:
                variations = human_join(
                    [
                        "(" + human_join([quote(arg) for arg in variant], final="and") + ")"
                        for variant in variants
                    ]
                )
                msg = f"Missing required arguments; Expected either {variations} arguments to be given"
            else:
                # Keep declaration order (and drop duplicates) so the message is deterministic.
                missing = [arg for arg in dict.fromkeys(variants[0]) if arg not in given_params]
                if len(missing) > 1:
                    msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
                else:
                    msg = f"Missing required argument: {quote(missing[0])}"
            raise TypeError(msg)

        return wrapper  # type: ignore

    return inner


@overload
def s3_uri_from(
    path: str,
) -> str: ...


@overload
def s3_uri_from(
    path: str,
    bucket: str,
) -> str: ...


@overload
def s3_uri_from(
    path: str,
    bucket: str,
    *,
    endpoint: str,
    region: str,
) -> str: ...


@required_args(
    (),
    (  # pre-configured env
        "bucket",
    ),
    (  # custom env or non-default bucket
        "bucket",
        "endpoint",
        "region",
    ),
)
def s3_uri_from(
    path: str,
    bucket: str | None = None,
    *,
    endpoint: str | None = None,
    region: str | None = None,
) -> str:
    """Build an S3 URI.

    This helper builds an S3 URI from a storage path, assuming the environment is configured
    with a default bucket (AWS_S3_BUCKET), endpoint (AWS_S3_ENDPOINT) and region
    (AWS_DEFAULT_REGION). When no default bucket is configured, or when a bucket other than
    the default one is used, `bucket`, `endpoint` and `region` must all be provided.

    Args:
        path: Storage path.
        bucket: Name of the S3 bucket. Defaults to AWS_S3_BUCKET.
        endpoint: Endpoint of the S3 bucket. Defaults to AWS_S3_ENDPOINT.
        region: Region of the S3 bucket. Defaults to AWS_DEFAULT_REGION.

    Returns:
        S3 URI.

    Raises:
        MissingMetadata: If no bucket is given and AWS_S3_BUCKET is unset, if a non-default
            bucket is given without both `endpoint` and `region`, or if endpoint/region can
            be resolved neither from the arguments nor from the environment.
    """
    default_bucket = os.environ.get("AWS_S3_BUCKET")
    if not bucket:
        if not default_bucket:
            msg = "Custom environment requires all arguments"
            raise MissingMetadata(msg)
        bucket = default_bucket
    elif bucket != default_bucket and not (endpoint and region):
        # The environment defaults describe the default bucket only; borrowing its
        # endpoint or region for another bucket would produce a wrong URI.
        msg = "endpoint and region must be provided for non-default bucket"
        raise MissingMetadata(msg)

    endpoint = endpoint or os.getenv("AWS_S3_ENDPOINT")
    region = region or os.getenv("AWS_DEFAULT_REGION")

    if not (endpoint and region):
        msg = (
            "Missing environment variables: AWS_S3_ENDPOINT and AWS_DEFAULT_REGION are required "
            "when endpoint and region are not provided"
        )
        raise MissingMetadata(msg)

    # https://alexwlchan.net/2020/s3-keys-are-not-file-paths/ nor do they resolve to valid URLs
    # NOTE(review): values are interpolated verbatim (no percent-encoding); confirm that
    # consumers of this URI expect the raw path/endpoint before changing this.
    return f"s3://{bucket}/{path}?endpoint={endpoint}&defaultRegion={region}"
model_format_name="test_format", - model_format_version="test_version", - version=version, - # ensure leave empty metadata - ) - assert (mv := mr_client.get_model_version(name, version)) - assert mv.metadata == env_values - - for k in env_values: - os.environ.pop(k) - - def test_hf_import(mr_client: ModelRegistry): pytest.importorskip("huggingface_hub") name = "openai-community/gpt2" @@ -113,19 +112,25 @@ def test_hf_import(mr_client: ModelRegistry): assert mv.author == author assert mv.metadata["model_author"] == author assert mv.metadata["model_origin"] == "huggingface_hub" - assert mv.metadata["source_uri"] == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx" + assert ( + mv.metadata["source_uri"] + == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx" + ) assert mv.metadata["repo"] == name assert mr_client.get_model_artifact(name, version) def test_hf_import_default_env(mr_client: ModelRegistry): - """Test setting environment variables, hence triggering defaults, does _not_ interfere with HF metadata - """ + """Test setting environment variables, hence triggering defaults, does _not_ interfere with HF metadata""" pytest.importorskip("huggingface_hub") name = "openai-community/gpt2" version = "1.2.3" author = "test author" - env_values = {"AWS_S3_ENDPOINT": "value1", "AWS_S3_BUCKET": "value2", "AWS_DEFAULT_REGION": "value3"} + env_values = { + "AWS_S3_ENDPOINT": "value1", + "AWS_S3_BUCKET": "value2", + "AWS_DEFAULT_REGION": "value3", + } for k, v in env_values.items(): os.environ[k] = v @@ -140,7 +145,10 @@ def test_hf_import_default_env(mr_client: ModelRegistry): assert (mv := mr_client.get_model_version(name, version)) assert mv.metadata["model_author"] == author assert mv.metadata["model_origin"] == "huggingface_hub" - assert mv.metadata["source_uri"] == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx" + assert ( + mv.metadata["source_uri"] + == 
import pytest
from model_registry.exceptions import MissingMetadata
from model_registry.utils import s3_uri_from

# Environment variables read by s3_uri_from.
S3_ENV_VARS = ("AWS_S3_BUCKET", "AWS_S3_ENDPOINT", "AWS_DEFAULT_REGION")


@pytest.fixture()
def s3_env(monkeypatch):
    """Return a setter pinning the AWS_* variables for the current test only.

    Variables not passed to the setter are removed. monkeypatch restores the
    previous environment on teardown, so no state leaks into other tests.
    """

    def configure(**values):
        for key in S3_ENV_VARS:
            if key in values:
                monkeypatch.setenv(key, values[key])
            else:
                monkeypatch.delenv(key, raising=False)

    return configure


def test_s3_uri_builder():
    """All arguments given explicitly: the environment is not consulted."""
    s3_uri = s3_uri_from(
        "test-path",
        "test-bucket",
        endpoint="test-endpoint",
        region="test-region",
    )
    assert (
        s3_uri
        == "s3://test-bucket/test-path?endpoint=test-endpoint&defaultRegion=test-region"
    )


def test_s3_uri_builder_without_env(s3_env):
    """With no AWS_* variables set, partial arguments are rejected."""
    s3_env()
    with pytest.raises(MissingMetadata) as e:
        s3_uri_from(
            "test-path",
        )
    assert "custom environment" in str(e.value).lower()

    with pytest.raises(MissingMetadata) as e:
        s3_uri_from(
            "test-path",
            "test-bucket",
        )
    assert "non-default bucket" in str(e.value).lower()


def test_s3_uri_builder_with_only_default_bucket_env(s3_env):
    """A default bucket alone is not enough: endpoint and region are still needed."""
    s3_env(AWS_S3_BUCKET="test-bucket")
    with pytest.raises(MissingMetadata) as e:
        s3_uri_from(
            "test-path",
        )
    assert "missing environment variables" in str(e.value).lower()


def test_s3_uri_builder_with_other_default_variables(s3_env):
    """Endpoint and region defaults without a default bucket do not help."""
    s3_env(AWS_S3_ENDPOINT="test-endpoint", AWS_DEFAULT_REGION="test-region")
    with pytest.raises(MissingMetadata) as e:
        s3_uri_from(
            "test-path",
        )
    assert "custom environment" in str(e.value).lower()

    with pytest.raises(MissingMetadata) as e:
        s3_uri_from(
            "test-path",
            "test-bucket",
        )
    assert "non-default bucket" in str(e.value).lower()


def test_s3_uri_builder_with_complete_env(s3_env):
    """With a complete environment, the default bucket may be implicit or explicit."""
    s3_env(
        AWS_S3_BUCKET="test-bucket",
        AWS_S3_ENDPOINT="test-endpoint",
        AWS_DEFAULT_REGION="test-region",
    )
    assert s3_uri_from("test-path") == s3_uri_from("test-path", "test-bucket")