From 4c66394e4d87dee1a891fa6e994b465935668521 Mon Sep 17 00:00:00 2001 From: Isabella do Amaral Date: Thu, 2 May 2024 13:46:52 -0300 Subject: [PATCH] py: provide API builders for secure and insecure connections Signed-off-by: Isabella do Amaral --- clients/python/src/model_registry/_client.py | 33 +++++- clients/python/src/model_registry/core.py | 117 ++++++++++--------- clients/python/tests/conftest.py | 2 +- clients/python/tests/test_core.py | 66 +++++------ test/robot/ModelRegistry.py | 2 +- 5 files changed, 128 insertions(+), 92 deletions(-) diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index c38f1cc9a..e26209c62 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -2,6 +2,8 @@ from __future__ import annotations +import os +from pathlib import Path from typing import get_args from warnings import warn @@ -36,7 +38,36 @@ def __init__( """ # TODO: get args from env self._author = author - self._api = ModelRegistryAPIClient(server_address, port, user_token, custom_ca) + + if not user_token: + # /var/run/secrets/kubernetes.io/serviceaccount/token + sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH") + if sa_token: + user_token = Path(sa_token).read_bytes() + else: + warn("User access token is missing", stacklevel=2) + + root_certs = None + if not custom_ca: + ca_cert = os.environ.get("CERT") + if ca_cert: + root_certs = Path(ca_cert).read_bytes() + elif port == 443: + warn( + "missing CA certificate, which is required for a secure connection", + stacklevel=2, + ) + else: + root_certs = custom_ca + + if root_certs: + self._api = ModelRegistryAPIClient.secure_connection( + server_address, port, user_token, custom_ca + ) + else: + self._api = ModelRegistryAPIClient.insecure_connection( + server_address, port, user_token + ) def _register_model(self, name: str) -> RegisteredModel: if rm := self._api.get_registered_model_by_params(name): diff --git a/clients/python/src/model_registry/core.py b/clients/python/src/model_registry/core.py index 4d9579aeb..6909a3004 100644 --- a/clients/python/src/model_registry/core.py +++ b/clients/python/src/model_registry/core.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from dataclasses import dataclass from pathlib import Path from warnings import warn @@ -16,16 +17,20 @@ from .utils import header_adder_interceptor +@dataclass class ModelRegistryAPIClient: """Model registry API.""" - def __init__( - self, + store: MLMDStore + + @classmethod + def secure_connection( + cls, server_address: str, port: int = 443, user_token: bytes | None = None, custom_ca: bytes | None = None, - ): + ) -> ModelRegistryAPIClient: """Constructor. Args: @@ -34,37 +39,39 @@ def __init__( user_token: The PEM-encoded user token as a byte string. Defaults to envvar KF_PIPELINES_SA_TOKEN_PATH. custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT. """ - if not user_token: - # /var/run/secrets/kubernetes.io/serviceaccount/token - sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH") - if sa_token: - user_token = Path(sa_token).read_bytes() - else: - warn("User access token is missing", stacklevel=2) - - if port == 443: - if not custom_ca: - ca_cert = os.environ.get("CERT") - if not ca_cert: - msg = "CA certificate must be provided" - raise StoreException(msg) - root_certs = Path(ca_cert).read_bytes() - else: - root_certs = custom_ca - chan_creds = grpc.ssl_channel_credentials(root_certs) - - if user_token: - call_creds = grpc.access_token_call_credentials(user_token) - chan_creds = grpc.composite_channel_credentials( - chan_creds, - call_creds, - ) - - chan = grpc.secure_channel( - f"{server_address}:443", + chan_creds = grpc.ssl_channel_credentials(custom_ca) + + if user_token: + chan_creds = grpc.composite_channel_credentials( chan_creds, + grpc.access_token_call_credentials(user_token), ) - elif user_token: + + if port != 443: + warn(f"Using non-standard port for TLS connection {port}", stacklevel=2) + + chan = grpc.secure_channel( + f"{server_address}:{port}", + chan_creds, + ) + + return cls(MLMDStore.from_channel(chan)) + + @classmethod + def insecure_connection( + cls, + server_address: str, + port: int, + user_token: bytes | None = None, + ) -> ModelRegistryAPIClient: + """Constructor. + + Args: + server_address: Server address. + port: Server port. + user_token: The PEM-encoded user token as a byte string. Defaults to envvar KF_PIPELINES_SA_TOKEN_PATH. + """ + if user_token: chan = grpc.intercept_channel( grpc.insecure_channel(f"{server_address}:{port}"), # header key has to be lowercase @@ -73,7 +80,7 @@ def __init__( else: chan = grpc.insecure_channel(f"{server_address}:{port}") - self._store = MLMDStore.from_channel(chan) + return cls(MLMDStore.from_channel(chan)) def _map(self, py_obj: ProtoBase) -> ProtoType: """Map a Python object to a proto object. @@ -86,7 +93,7 @@ def _map(self, py_obj: ProtoBase) -> ProtoType: Returns: Proto object. """ - type_id = self._store.get_type_id( + type_id = self.store.get_type_id( py_obj.get_proto_type(), py_obj.get_proto_type_name() ) return py_obj.map(type_id) @@ -103,9 +110,9 @@ def upsert_registered_model(self, registered_model: RegisteredModel) -> str: Returns: ID of the registered model. """ - id = self._store.put_context(self._map(registered_model)) + id = self.store.put_context(self._map(registered_model)) new_py_rm = RegisteredModel.unmap( - self._store.get_context(RegisteredModel.get_proto_type_name(), id) + self.store.get_context(RegisteredModel.get_proto_type_name(), id) ) id = str(id) registered_model.id = id @@ -124,7 +131,7 @@ def get_registered_model_by_id(self, id: str) -> RegisteredModel | None: Returns: Registered model. """ - proto_rm = self._store.get_context( + proto_rm = self.store.get_context( RegisteredModel.get_proto_type_name(), id=int(id) ) if proto_rm is not None: @@ -150,7 +157,7 @@ def get_registered_model_by_params( if name is None and external_id is None: msg = "Either name or external_id must be provided" raise StoreException(msg) - proto_rm = self._store.get_context( + proto_rm = self.store.get_context( RegisteredModel.get_proto_type_name(), name=name, external_id=external_id, @@ -172,7 +179,7 @@ def get_registered_models( Registered models. """ mlmd_options = options.as_mlmd_list_options() if options else MLMDListOptions() - proto_rms = self._store.get_contexts( + proto_rms = self.store.get_contexts( RegisteredModel.get_proto_type_name(), mlmd_options ) return [RegisteredModel.unmap(proto_rm) for proto_rm in proto_rms] @@ -194,10 +201,10 @@ def upsert_model_version( """ # this is not ideal but we need this info for the prefix model_version._registered_model_id = registered_model_id - id = self._store.put_context(self._map(model_version)) - self._store.put_context_parent(int(registered_model_id), id) + id = self.store.put_context(self._map(model_version)) + self.store.put_context_parent(int(registered_model_id), id) new_py_mv = ModelVersion.unmap( - self._store.get_context(ModelVersion.get_proto_type_name(), id) + self.store.get_context(ModelVersion.get_proto_type_name(), id) ) id = str(id) model_version.id = id @@ -216,7 +223,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None: Returns: Model version. """ - proto_mv = self._store.get_context( + proto_mv = self.store.get_context( ModelVersion.get_proto_type_name(), id=int(model_version_id) ) if proto_mv is not None: @@ -240,7 +247,7 @@ def get_model_versions( mlmd_options.filter_query = f"parent_contexts_a.id = {registered_model_id}" return [ ModelVersion.unmap(proto_mv) - for proto_mv in self._store.get_contexts( + for proto_mv in self.store.get_contexts( ModelVersion.get_proto_type_name(), mlmd_options ) ] @@ -267,7 +274,7 @@ def get_model_version_by_params( StoreException: If neither external ID nor registered model ID and version is provided. """ if external_id is not None: - proto_mv = self._store.get_context( + proto_mv = self.store.get_context( ModelVersion.get_proto_type_name(), external_id=external_id ) elif registered_model_id is None or version is None: @@ -276,7 +283,7 @@ def get_model_version_by_params( ) raise StoreException(msg) else: - proto_mv = self._store.get_context( + proto_mv = self.store.get_context( ModelVersion.get_proto_type_name(), name=f"{registered_model_id}:{version}", ) @@ -304,17 +311,17 @@ def upsert_model_artifact( StoreException: If the model version already has a model artifact. """ mv_id = int(model_version_id) - if self._store.get_attributed_artifact( + if self.store.get_attributed_artifact( ModelArtifact.get_proto_type_name(), mv_id ): msg = f"Model version with ID {mv_id} already has a model artifact" raise StoreException(msg) model_artifact._model_version_id = model_version_id - id = self._store.put_artifact(self._map(model_artifact)) - self._store.put_attribution(mv_id, id) + id = self.store.put_artifact(self._map(model_artifact)) + self.store.put_attribution(mv_id, id) new_py_ma = ModelArtifact.unmap( - self._store.get_artifact(ModelArtifact.get_proto_type_name(), id) + self.store.get_artifact(ModelArtifact.get_proto_type_name(), id) ) id = str(id) model_artifact.id = id @@ -333,9 +340,7 @@ def get_model_artifact_by_id(self, id: str) -> ModelArtifact | None: Returns: Model artifact. """ - proto_ma = self._store.get_artifact( - ModelArtifact.get_proto_type_name(), int(id) - ) + proto_ma = self.store.get_artifact(ModelArtifact.get_proto_type_name(), int(id)) if proto_ma is not None: return ModelArtifact.unmap(proto_ma) @@ -357,14 +362,14 @@ def get_model_artifact_by_params( StoreException: If neither external ID nor model version ID is provided. """ if external_id: - proto_ma = self._store.get_artifact( + proto_ma = self.store.get_artifact( ModelArtifact.get_proto_type_name(), external_id=external_id ) elif not model_version_id: msg = "Either model_version_id or external_id must be provided" raise StoreException(msg) else: - proto_ma = self._store.get_attributed_artifact( + proto_ma = self.store.get_attributed_artifact( ModelArtifact.get_proto_type_name(), int(model_version_id) ) if proto_ma is not None: @@ -390,7 +395,7 @@ def get_model_artifacts( if model_version_id is not None: mlmd_options.filter_query = f"contexts_a.id = {model_version_id}" - proto_mas = self._store.get_artifacts( + proto_mas = self.store.get_artifacts( ModelArtifact.get_proto_type_name(), mlmd_options ) return [ModelArtifact.unmap(proto_ma) for proto_ma in proto_mas] diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py index ccc364d67..c821c5a81 100644 --- a/clients/python/tests/conftest.py +++ b/clients/python/tests/conftest.py @@ -171,7 +171,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore: @pytest.fixture() def mr_api(store_wrapper: MLMDStore) -> ModelRegistryAPIClient: mr = object.__new__(ModelRegistryAPIClient) - mr._store = store_wrapper + mr.store = store_wrapper return mr diff --git a/clients/python/tests/test_core.py b/clients/python/tests/test_core.py index 2e4469d81..bfbcf33c4 100644 --- a/clients/python/tests/test_core.py +++ b/clients/python/tests/test_core.py @@ -61,7 +61,7 @@ def test_upsert_registered_model( ): mr_api.upsert_registered_model(registered_model.py) - rm_proto = mr_api._store.store.get_context_by_type_and_name( + rm_proto = mr_api.store.store.get_context_by_type_and_name( RegisteredModel.get_proto_type_name(), registered_model.proto.name ) assert rm_proto is not None @@ -73,7 +73,7 @@ def test_get_registered_model_by_id( mr_api: ModelRegistryAPIClient, registered_model: Mapped, ): - rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] + rm_id = mr_api.store.store.put_contexts([registered_model.proto])[0] assert (mlmd_rm := mr_api.get_registered_model_by_id(str(rm_id))) assert mlmd_rm.id == str(rm_id) @@ -85,7 +85,7 @@ def test_get_registered_model_by_name( mr_api: ModelRegistryAPIClient, registered_model: Mapped, ): - rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] + rm_id = mr_api.store.store.put_contexts([registered_model.proto])[0] assert ( mlmd_rm := mr_api.get_registered_model_by_params(name=registered_model.py.name) @@ -101,7 +101,7 @@ def test_get_registered_model_by_external_id( ): registered_model.py.external_id = "external_id" registered_model.proto.external_id = "external_id" - rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] + rm_id = mr_api.store.store.put_contexts([registered_model.proto])[0] assert ( mlmd_rm := mr_api.get_registered_model_by_params( @@ -116,9 +116,9 @@ def test_get_registered_model_by_external_id( def test_get_registered_models( mr_api: ModelRegistryAPIClient, registered_model: Mapped ): - rm1_id = mr_api._store.store.put_contexts([registered_model.proto])[0] + rm1_id = mr_api.store.store.put_contexts([registered_model.proto])[0] registered_model.proto.name = "model2" - rm2_id = mr_api._store.store.put_contexts([registered_model.proto])[0] + rm2_id = mr_api.store.store.put_contexts([registered_model.proto])[0] mlmd_rms = mr_api.get_registered_models() assert len(mlmd_rms) == 2 @@ -130,12 +130,12 @@ def test_upsert_model_version( model_version: Mapped, registered_model: Mapped, ): - rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] + rm_id = mr_api.store.store.put_contexts([registered_model.proto])[0] rm_id = str(rm_id) mr_api.upsert_model_version(model_version.py, rm_id) - mv_proto = mr_api._store.store.get_context_by_type_and_name( + mv_proto = mr_api.store.store.get_context_by_type_and_name( ModelVersion.get_proto_type_name(), f"{rm_id}:{model_version.proto.name}" ) assert mv_proto is not None @@ -145,7 +145,7 @@ def test_upsert_model_version( def test_get_model_version_by_id(mr_api: ModelRegistryAPIClient, model_version: Mapped): model_version.proto.name = f"1:{model_version.proto.name}" - ctx_id = mr_api._store.store.put_contexts([model_version.proto])[0] + ctx_id = mr_api.store.store.put_contexts([model_version.proto])[0] id = str(ctx_id) assert (mlmd_mv := mr_api.get_model_version_by_id(id)) @@ -158,7 +158,7 @@ def test_get_model_version_by_name( mr_api: ModelRegistryAPIClient, model_version: Mapped ): model_version.proto.name = f"1:{model_version.proto.name}" - mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv_id = mr_api.store.store.put_contexts([model_version.proto])[0] assert ( mlmd_mv := mr_api.get_model_version_by_params( @@ -176,7 +176,7 @@ def test_get_model_version_by_external_id( model_version.proto.name = f"1:{model_version.proto.name}" model_version.proto.external_id = "external_id" model_version.py.external_id = "external_id" - mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv_id = mr_api.store.store.put_contexts([model_version.proto])[0] assert ( mlmd_mv := mr_api.get_model_version_by_params( @@ -193,14 +193,14 @@ def test_get_model_versions( model_version: Mapped, registered_model: Mapped, ): - rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] + rm_id = mr_api.store.store.put_contexts([registered_model.proto])[0] model_version.proto.name = f"{rm_id}:version" - mv1_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv1_id = mr_api.store.store.put_contexts([model_version.proto])[0] model_version.proto.name = f"{rm_id}:version2" - mv2_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv2_id = mr_api.store.store.put_contexts([model_version.proto])[0] - mr_api._store.store.put_parent_contexts( + mr_api.store.store.put_parent_contexts( [ ParentContext(parent_id=rm_id, child_id=mv1_id), ParentContext(parent_id=rm_id, child_id=mv2_id), @@ -220,12 +220,12 @@ def test_upsert_model_artifact( ): monkeypatch.setattr(ModelArtifact, "mlmd_name_prefix", "test_prefix") - mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv_id = mr_api.store.store.put_contexts([model_version.proto])[0] mv_id = str(mv_id) mr_api.upsert_model_artifact(model.py, mv_id) - ma_proto = mr_api._store.store.get_artifact_by_type_and_name( + ma_proto = mr_api.store.store.get_artifact_by_type_and_name( ModelArtifact.get_proto_type_name(), f"test_prefix:{model.proto.name}" ) assert ma_proto is not None @@ -236,11 +236,11 @@ def test_upsert_model_artifact( def test_upsert_duplicate_model_artifact_with_different_version( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv1_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv1_id = mr_api.store.store.put_contexts([model_version.proto])[0] mv1_id = str(mv1_id) model_version.proto.name = "version2" - mv2_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv2_id = mr_api.store.store.put_contexts([model_version.proto])[0] mv2_id = str(mv2_id) ma1 = evolve(model.py) @@ -248,7 +248,7 @@ def test_upsert_duplicate_model_artifact_with_different_version( ma2 = evolve(model.py) mr_api.upsert_model_artifact(ma2, mv2_id) - ma_protos = mr_api._store.store.get_artifacts_by_id([int(ma1.id), int(ma2.id)]) + ma_protos = mr_api.store.store.get_artifacts_by_id([int(ma1.id), int(ma2.id)]) assert ma1.name == ma2.name assert ma1.name != str(ma_protos[0].name) assert ma2.name != str(ma_protos[1].name) @@ -257,7 +257,7 @@ def test_upsert_duplicate_model_artifact_with_different_version( def test_upsert_duplicate_model_artifact_with_same_version( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv_id = mr_api.store.store.put_contexts([model_version.proto])[0] mv_id = str(mv_id) ma1 = evolve(model.py) @@ -269,7 +269,7 @@ def test_upsert_duplicate_model_artifact_with_same_version( def test_get_model_artifact_by_id(mr_api: ModelRegistryAPIClient, model: Mapped): model.proto.name = f"test_prefix:{model.proto.name}" - id = mr_api._store.store.put_artifacts([model.proto])[0] + id = mr_api.store.store.put_artifacts([model.proto])[0] id = str(id) assert (mlmd_ma := mr_api.get_model_artifact_by_id(id)) @@ -281,12 +281,12 @@ def test_get_model_artifact_by_id(mr_api: ModelRegistryAPIClient, model: Mapped) def test_get_model_artifact_by_model_version_id( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv_id = mr_api.store.store.put_contexts([model_version.proto])[0] model.proto.name = f"test_prefix:{model.proto.name}" - ma_id = mr_api._store.store.put_artifacts([model.proto])[0] + ma_id = mr_api.store.store.put_artifacts([model.proto])[0] - mr_api._store.store.put_attributions_and_associations( + mr_api.store.store.put_attributions_and_associations( [Attribution(context_id=mv_id, artifact_id=ma_id)], [] ) @@ -303,7 +303,7 @@ def test_get_model_artifact_by_external_id( model.proto.external_id = "external_id" model.py.external_id = "external_id" - id = mr_api._store.store.put_artifacts([model.proto])[0] + id = mr_api.store.store.put_artifacts([model.proto])[0] id = str(id) assert ( @@ -316,9 +316,9 @@ def test_get_model_artifact_by_external_id( def test_get_all_model_artifacts(mr_api: ModelRegistryAPIClient, model: Mapped): model.proto.name = "test_prefix:model1" - ma1_id = mr_api._store.store.put_artifacts([model.proto])[0] + ma1_id = mr_api.store.store.put_artifacts([model.proto])[0] model.proto.name = "test_prefix:model2" - ma2_id = mr_api._store.store.put_artifacts([model.proto])[0] + ma2_id = mr_api.store.store.put_artifacts([model.proto])[0] mlmd_mas = mr_api.get_model_artifacts() assert len(mlmd_mas) == 2 @@ -328,17 +328,17 @@ def test_get_all_model_artifacts(mr_api: ModelRegistryAPIClient, model: Mapped): def test_get_model_artifacts_by_mv_id( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv1_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv1_id = mr_api.store.store.put_contexts([model_version.proto])[0] model_version.proto.name = "version2" - mv2_id = mr_api._store.store.put_contexts([model_version.proto])[0] + mv2_id = mr_api.store.store.put_contexts([model_version.proto])[0] model.proto.name = "test_prefix:model1" - ma1_id = mr_api._store.store.put_artifacts([model.proto])[0] + ma1_id = mr_api.store.store.put_artifacts([model.proto])[0] model.proto.name = "test_prefix:model2" - ma2_id = mr_api._store.store.put_artifacts([model.proto])[0] + ma2_id = mr_api.store.store.put_artifacts([model.proto])[0] - mr_api._store.store.put_attributions_and_associations( + mr_api.store.store.put_attributions_and_associations( [ Attribution(context_id=mv1_id, artifact_id=ma1_id), Attribution(context_id=mv2_id, artifact_id=ma2_id), diff --git a/test/robot/ModelRegistry.py b/test/robot/ModelRegistry.py index 4b5d1090e..861249532 100644 --- a/test/robot/ModelRegistry.py +++ b/test/robot/ModelRegistry.py @@ -10,7 +10,7 @@ def write_to_console(s): class ModelRegistry(mr.core.ModelRegistryAPIClient): def __init__(self, host: str = "localhost", port: int = 9090): - super().__init__(host, port) + super().__init__(mr.store.MLMDStore.from_config(host, port)) def upsert_registered_model(self, registered_model) -> str: p = RegisteredModel("")