diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index b59b1ec96..c38f1cc9a 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -20,6 +20,7 @@ def __init__( port: int = 443, *, author: str, + user_token: bytes | None = None, custom_ca: bytes | None = None, ): """Constructor. @@ -27,15 +28,15 @@ def __init__( Args: server_address: Server address. port: Server port. Defaults to 443. - custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT. Keyword Args: author: Name of the author. + user_token: The PEM-encoded user token as a byte string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH. custom_ca: The PEM-encoded root certificates as a byte string. Defaults to contents of path on envvar CERT. """ # TODO: get args from env self._author = author - self._api = ModelRegistryAPIClient(server_address, port, custom_ca) + self._api = ModelRegistryAPIClient(server_address, port, user_token, custom_ca) def _register_model(self, name: str) -> RegisteredModel: if rm := self._api.get_registered_model_by_params(name): diff --git a/clients/python/src/model_registry/core.py b/clients/python/src/model_registry/core.py index 5f83fc5f9..4d9579aeb 100644 --- a/clients/python/src/model_registry/core.py +++ b/clients/python/src/model_registry/core.py @@ -4,14 +4,16 @@ import os from pathlib import Path +from warnings import warn -from ml_metadata.proto import MetadataStoreClientConfig +import grpc from .exceptions import StoreException from .store import MLMDStore, ProtoType from .types import ListOptions, ModelArtifact, ModelVersion, RegisteredModel from .types.base import ProtoBase from .types.options import MLMDListOptions +from .utils import header_adder_interceptor class ModelRegistryAPIClient: @@ -21,16 +23,25 @@ def __init__( self, server_address: str, port: int = 443, + user_token: bytes | None = None, custom_ca: bytes | None = None, ): """Constructor. Args: server_address: Server address. - custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT. port: Server port. Defaults to 443. + user_token: The PEM-encoded user token as a byte string. Defaults to envvar KF_PIPELINES_SA_TOKEN_PATH. + custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT. """ - config = MetadataStoreClientConfig() + if not user_token: + # /var/run/secrets/kubernetes.io/serviceaccount/token + sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH") + if sa_token: + user_token = Path(sa_token).read_bytes() + else: + warn("User access token is missing", stacklevel=2) + if port == 443: if not custom_ca: ca_cert = os.environ.get("CERT") @@ -40,11 +51,29 @@ def __init__( root_certs = Path(ca_cert).read_bytes() else: root_certs = custom_ca + chan_creds = grpc.ssl_channel_credentials(root_certs) + + if user_token: + call_creds = grpc.access_token_call_credentials(user_token) + chan_creds = grpc.composite_channel_credentials( + chan_creds, + call_creds, + ) + + chan = grpc.secure_channel( + f"{server_address}:443", + chan_creds, + ) + elif user_token: + chan = grpc.intercept_channel( + grpc.insecure_channel(f"{server_address}:{port}"), + # header key has to be lowercase + header_adder_interceptor("authorization", f"Bearer {user_token}"), + ) + else: + chan = grpc.insecure_channel(f"{server_address}:{port}") - config.ssl_config.custom_ca = root_certs - config.host = server_address - config.port = port - self._store = MLMDStore(config) + self._store = MLMDStore.from_channel(chan) def _map(self, py_obj: ProtoBase) -> ProtoType: """Map a Python object to a proto object. diff --git a/clients/python/src/model_registry/store/wrapper.py b/clients/python/src/model_registry/store/wrapper.py index 750d0642d..dd291eeb8 100644 --- a/clients/python/src/model_registry/store/wrapper.py +++ b/clients/python/src/model_registry/store/wrapper.py @@ -3,8 +3,10 @@ from __future__ import annotations from collections.abc import Sequence +from dataclasses import dataclass from typing import ClassVar +from grpc import Channel from ml_metadata import errors from ml_metadata.metadata_store import ListOptions, MetadataStore from ml_metadata.proto import ( @@ -14,6 +16,7 @@ MetadataStoreClientConfig, ParentContext, ) +from ml_metadata.proto.metadata_store_service_pb2_grpc import MetadataStoreServiceStub from model_registry.exceptions import ( DuplicateException, @@ -25,19 +28,43 @@ from .base import ProtoType +@dataclass class MLMDStore: """MLMD storage backend.""" + store: MetadataStore # cache for MLMD type IDs _type_ids: ClassVar[dict[str, int]] = {} - def __init__(self, config: MetadataStoreClientConfig): + @classmethod + def from_config(cls, host: str, port: int): """Constructor. Args: - config: MLMD config. + host: MLMD store server host. + port: MLMD store server port. """ - self._mlmd_store = MetadataStore(config) + return cls( + MetadataStore( + MetadataStoreClientConfig( + host=host, + port=port, + ) + ) + ) + + @classmethod + def from_channel(cls, chan: Channel): + """Constructor. + + Args: + chan: gRPC channel to the MLMD store. + """ + store = MetadataStore( + MetadataStoreClientConfig(host="localhost", port=8080), + ) + store._metadata_store_stub = MetadataStoreServiceStub(chan) + return cls(store) def get_type_id(self, pt: type[ProtoType], type_name: str) -> int: """Get backend ID for a type. @@ -59,7 +86,7 @@ def get_type_id(self, pt: type[ProtoType], type_name: str) -> int: pt_name = pt.__name__.lower() try: - _type = getattr(self._mlmd_store, f"get_{pt_name}_type")(type_name) + _type = getattr(self.store, f"get_{pt_name}_type")(type_name) except errors.NotFoundError as e: msg = f"{pt_name} type {type_name} does not exist" raise TypeNotFoundException(msg) from e @@ -85,7 +112,7 @@ def put_artifact(self, artifact: Artifact) -> int: StoreException: If the artifact isn't properly formed. """ try: - return self._mlmd_store.put_artifacts([artifact])[0] + return self.store.put_artifacts([artifact])[0] except errors.AlreadyExistsError as e: msg = f"Artifact {artifact.name} already exists" raise DuplicateException(msg) from e @@ -111,7 +138,7 @@ def put_context(self, context: Context) -> int: StoreException: If the context isn't propertly formed. """ try: - return self._mlmd_store.put_contexts([context])[0] + return self.store.put_contexts([context])[0] except errors.AlreadyExistsError as e: msg = f"Context {context.name} already exists" raise DuplicateException(msg) from e @@ -152,12 +179,12 @@ def get_context( StoreException: Invalid arguments. """ if name is not None: - return self._mlmd_store.get_context_by_type_and_name(ctx_type_name, name) + return self.store.get_context_by_type_and_name(ctx_type_name, name) if id is not None: - contexts = self._mlmd_store.get_contexts_by_id([id]) + contexts = self.store.get_contexts_by_id([id]) elif external_id is not None: - contexts = self._mlmd_store.get_contexts_by_external_ids([external_id]) + contexts = self.store.get_contexts_by_external_ids([external_id]) else: msg = "Either id, name or external_id must be provided" raise StoreException(msg) @@ -186,7 +213,7 @@ def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context # TODO: should we make options optional? # if options is not None: try: - contexts = self._mlmd_store.get_contexts(options) + contexts = self.store.get_contexts(options) except errors.InvalidArgumentError as e: msg = f"Invalid arguments for get_contexts: {e}" raise StoreException(msg) from e @@ -196,10 +223,10 @@ def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context contexts = self._filter_type(ctx_type_name, contexts) # else: - # contexts = self._mlmd_store.get_contexts_by_type(ctx_type_name) + # contexts = self.store.get_contexts_by_type(ctx_type_name) if not contexts and ctx_type_name not in [ - t.name for t in self._mlmd_store.get_context_types() + t.name for t in self.store.get_context_types() ]: msg = f"Context type {ctx_type_name} does not exist" raise TypeNotFoundException(msg) @@ -218,7 +245,7 @@ def put_context_parent(self, parent_id: int, child_id: int): ServerException: If there was an error putting the parent context. """ try: - self._mlmd_store.put_parent_contexts( + self.store.put_parent_contexts( [ParentContext(parent_id=parent_id, child_id=child_id)] ) except errors.AlreadyExistsError as e: @@ -240,7 +267,7 @@ def put_attribution(self, context_id: int, artifact_id: int): """ attribution = Attribution(context_id=context_id, artifact_id=artifact_id) try: - self._mlmd_store.put_attributions_and_associations([attribution], []) + self.store.put_attributions_and_associations([attribution], []) except errors.InvalidArgumentError as e: if "artifact" in str(e).lower(): msg = f"Artifact with ID {artifact_id} does not exist" @@ -277,12 +304,12 @@ def get_artifact( StoreException: Invalid arguments. """ if name is not None: - return self._mlmd_store.get_artifact_by_type_and_name(art_type_name, name) + return self.store.get_artifact_by_type_and_name(art_type_name, name) if id is not None: - artifacts = self._mlmd_store.get_artifacts_by_id([id]) + artifacts = self.store.get_artifacts_by_id([id]) elif external_id is not None: - artifacts = self._mlmd_store.get_artifacts_by_external_ids([external_id]) + artifacts = self.store.get_artifacts_by_external_ids([external_id]) else: msg = "Either id, name or external_id must be provided" raise StoreException(msg) @@ -304,7 +331,7 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact: Artifact. """ try: - artifacts = self._mlmd_store.get_artifacts_by_context(ctx_id) + artifacts = self.store.get_artifacts_by_context(ctx_id) except errors.InternalError as e: msg = f"Couldn't get artifacts by context {ctx_id}" raise ServerException(msg) from e @@ -330,7 +357,7 @@ def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifa StoreException: Invalid arguments. """ try: - artifacts = self._mlmd_store.get_artifacts(options) + artifacts = self.store.get_artifacts(options) except errors.InvalidArgumentError as e: msg = f"Invalid arguments for get_artifacts: {e}" raise StoreException(msg) from e @@ -340,7 +367,7 @@ def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifa artifacts = self._filter_type(art_type_name, artifacts) if not artifacts and art_type_name not in [ - t.name for t in self._mlmd_store.get_artifact_types() + t.name for t in self.store.get_artifact_types() ]: msg = f"Artifact type {art_type_name} does not exist" raise TypeNotFoundException(msg) diff --git a/clients/python/src/model_registry/utils.py b/clients/python/src/model_registry/utils.py index e60dcf5dd..1deb67a62 100644 --- a/clients/python/src/model_registry/utils.py +++ b/clients/python/src/model_registry/utils.py @@ -3,7 +3,11 @@ from __future__ import annotations import os +from collections import namedtuple +from typing import Callable +import grpc +from attr import dataclass from typing_extensions import overload from ._utils import required_args @@ -90,3 +94,85 @@ def s3_uri_from( # https://alexwlchan.net/2020/s3-keys-are-not-file-paths/ nor do they resolve to valid URls # FIXME: is this safe? return f"s3://{bucket}/{path}?endpoint={endpoint}&defaultRegion={region}" + + +# https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/generic_client_interceptor.py +@dataclass +class GenericClientInterceptor( # noqa: D101 + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): + fn: Callable + + def intercept_unary_unary(self, continuation, client_call_details, request): # noqa: D102 + new_details, new_request_iterator, postprocess = self.fn( + client_call_details, iter((request,)), False, False + ) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_unary_stream(self, continuation, client_call_details, request): # noqa: D102 + new_details, new_request_iterator, postprocess = self.fn( + client_call_details, iter((request,)), False, True + ) + response_it = continuation(new_details, next(new_request_iterator)) + return postprocess(response_it) if postprocess else response_it + + def intercept_stream_unary( # noqa: D102 + self, continuation, client_call_details, request_iterator + ): + new_details, new_request_iterator, postprocess = self.fn( + client_call_details, request_iterator, True, False + ) + response = continuation(new_details, new_request_iterator) + return postprocess(response) if postprocess else response + + def intercept_stream_stream( # noqa: D102 + self, continuation, client_call_details, request_iterator + ): + new_details, new_request_iterator, postprocess = self.fn( + client_call_details, request_iterator, True, True + ) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + +# https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/header_manipulator_client_interceptor.py +# we need to subclass ClientCallDetails to add a constructor (it's ABC) +class ClientCallDetails( # noqa: D101 + namedtuple("ClientCallDetails", ("method", "timeout", "metadata", "credentials")), + grpc.ClientCallDetails, +): + pass + + +def header_adder_interceptor(header, value): + """Create a client interceptor that adds a header to requests.""" + + def intercept_call( + client_call_details, + request_iterator, + request_streaming, + response_streaming, + ): + metadata = list(client_call_details.metadata or []) + metadata.append( + ( + header, + value, + ) + ) + return ( + ClientCallDetails( + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + ), + request_iterator, + None, + ) + + return GenericClientInterceptor(intercept_call) diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py index 8abb31a7b..ccc364d67 100644 --- a/clients/python/tests/conftest.py +++ b/clients/python/tests/conftest.py @@ -21,7 +21,7 @@ # ruff: noqa: PT021 supported @pytest.fixture(scope="session") -def mlmd_conn(request) -> MetadataStoreClientConfig: +def mlmd_port(request) -> int: model_registry_root_dir = model_registry_root(request) print( "Assuming this is the Model Registry root directory:", model_registry_root_dir @@ -46,10 +46,8 @@ def mlmd_conn(request) -> MetadataStoreClientConfig: wait_for_logs(container, "Server listening on") os.system('docker container ls --format "table {{.ID}}\t{{.Names}}\t{{.Ports}}" -a') # noqa governed test print("waited for logs and port") - cfg = MetadataStoreClientConfig( - host="localhost", port=int(container.get_exposed_port(8080)) - ) - print(cfg) + port = int(container.get_exposed_port(8080)) + print("port:", port) # this callback is needed in order to perform the container.stop() # removing this callback might result in mlmd container shutting down before the tests had chance to fully run, @@ -63,10 +61,12 @@ def teardown(): time.sleep( 3 ) # allowing some time for mlmd grpc to fully stabilize (is "spent" once per pytest session anyway) - _throwaway_store = metadata_store.MetadataStore(cfg) + _throwaway_store = metadata_store.MetadataStore( + MetadataStoreClientConfig(host="localhost", port=port) + ) wait_for_grpc(container, _throwaway_store) - return cfg + return port def model_registry_root(request): @@ -74,7 +74,7 @@ def model_registry_root(request): @pytest.fixture() -def plain_wrapper(request, mlmd_conn: MetadataStoreClientConfig) -> MLMDStore: +def plain_wrapper(request, mlmd_port: int) -> MLMDStore: sqlite_db_file = ( model_registry_root(request) / "test/config/ml-metadata/metadata.sqlite.db" ) @@ -89,7 +89,7 @@ def teardown(): request.addfinalizer(teardown) - to_return = MLMDStore(mlmd_conn) + to_return = MLMDStore.from_config("localhost", mlmd_port) sanity_check_mlmd_connection_to_db(to_return) return to_return @@ -109,10 +109,13 @@ def sanity_check_mlmd_connection_to_db(overview: MLMDStore): while retry_count < 3: retry_count += 1 try: - overview._mlmd_store.get_artifact_types() + overview.store.get_artifact_types() return except Exception as e: - if str(e) == "Cannot connect sqlite3 database: unable to open database file": + if ( + str(e) + == "Cannot connect sqlite3 database: unable to open database file" + ): time.sleep(1) else: msg = "Failed to sanity check before each test, another type of error detected." @@ -136,7 +139,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore: ], ) - plain_wrapper._mlmd_store.put_artifact_type(ma_type) + plain_wrapper.store.put_artifact_type(ma_type) mv_type = set_type_attrs( ContextType(), @@ -149,7 +152,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore: ], ) - plain_wrapper._mlmd_store.put_context_type(mv_type) + plain_wrapper.store.put_context_type(mv_type) rm_type = set_type_attrs( ContextType(), @@ -160,7 +163,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore: ], ) - plain_wrapper._mlmd_store.put_context_type(rm_type) + plain_wrapper.store.put_context_type(rm_type) return plain_wrapper diff --git a/clients/python/tests/store/test_wrapper.py b/clients/python/tests/store/test_wrapper.py index c6f4dbe2d..1ff1d2239 100644 --- a/clients/python/tests/store/test_wrapper.py +++ b/clients/python/tests/store/test_wrapper.py @@ -27,7 +27,7 @@ def artifact(plain_wrapper: MLMDStore) -> Artifact: art = Artifact() art.name = "test_artifact" - art.type_id = plain_wrapper._mlmd_store.put_artifact_type(art_type) + art.type_id = plain_wrapper.store.put_artifact_type(art_type) return art @@ -39,7 +39,7 @@ def context(plain_wrapper: MLMDStore) -> Context: ctx = Context() ctx.name = "test_context" - ctx.type_id = plain_wrapper._mlmd_store.put_context_type(ctx_type) + ctx.type_id = plain_wrapper.store.put_context_type(ctx_type) return ctx @@ -84,7 +84,7 @@ def test_put_invalid_artifact(plain_wrapper: MLMDStore, artifact: Artifact): def test_put_duplicate_artifact(plain_wrapper: MLMDStore, artifact: Artifact): - plain_wrapper._mlmd_store.put_artifacts([artifact]) + plain_wrapper.store.put_artifacts([artifact]) with pytest.raises(DuplicateException): plain_wrapper.put_artifact(artifact) @@ -97,7 +97,7 @@ def test_put_invalid_context(plain_wrapper: MLMDStore, context: Context): def test_put_duplicate_context(plain_wrapper: MLMDStore, context: Context): - plain_wrapper._mlmd_store.put_contexts([context]) + plain_wrapper.store.put_contexts([context]) with pytest.raises(DuplicateException): plain_wrapper.put_context(context) @@ -106,7 +106,7 @@ def test_put_duplicate_context(plain_wrapper: MLMDStore, context: Context): def test_put_attribution_with_invalid_context( plain_wrapper: MLMDStore, artifact: Artifact ): - art_id = plain_wrapper._mlmd_store.put_artifacts([artifact])[0] + art_id = plain_wrapper.store.put_artifacts([artifact])[0] with pytest.raises(StoreException) as store_error: plain_wrapper.put_attribution(0, art_id) @@ -117,7 +117,7 @@ def test_put_attribution_with_invalid_context( def test_put_attribution_with_invalid_artifact( plain_wrapper: MLMDStore, context: Context ): - ctx_id = plain_wrapper._mlmd_store.put_contexts([context])[0] + ctx_id = plain_wrapper.store.put_contexts([context])[0] with pytest.raises(StoreException) as store_error: plain_wrapper.put_attribution(ctx_id, 0) diff --git a/clients/python/tests/test_core.py b/clients/python/tests/test_core.py index bdb3ba144..2e4469d81 100644 --- a/clients/python/tests/test_core.py +++ b/clients/python/tests/test_core.py @@ -61,7 +61,7 @@ def test_upsert_registered_model( ): mr_api.upsert_registered_model(registered_model.py) - rm_proto = mr_api._store._mlmd_store.get_context_by_type_and_name( + rm_proto = mr_api._store.store.get_context_by_type_and_name( RegisteredModel.get_proto_type_name(), registered_model.proto.name ) assert rm_proto is not None @@ -73,7 +73,7 @@ def test_get_registered_model_by_id( mr_api: ModelRegistryAPIClient, registered_model: Mapped, ): - rm_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] assert (mlmd_rm := mr_api.get_registered_model_by_id(str(rm_id))) assert mlmd_rm.id == str(rm_id) @@ -85,7 +85,7 @@ def test_get_registered_model_by_name( mr_api: ModelRegistryAPIClient, registered_model: Mapped, ): - rm_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] assert ( mlmd_rm := mr_api.get_registered_model_by_params(name=registered_model.py.name) @@ -101,7 +101,7 @@ def test_get_registered_model_by_external_id( ): registered_model.py.external_id = "external_id" registered_model.proto.external_id = "external_id" - rm_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] assert ( mlmd_rm := mr_api.get_registered_model_by_params( @@ -116,9 +116,9 @@ def test_get_registered_model_by_external_id( def test_get_registered_models( mr_api: ModelRegistryAPIClient, registered_model: Mapped ): - rm1_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm1_id = mr_api._store.store.put_contexts([registered_model.proto])[0] registered_model.proto.name = "model2" - rm2_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm2_id = mr_api._store.store.put_contexts([registered_model.proto])[0] mlmd_rms = mr_api.get_registered_models() assert len(mlmd_rms) == 2 @@ -130,12 +130,12 @@ def test_upsert_model_version( model_version: Mapped, registered_model: Mapped, ): - rm_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] rm_id = str(rm_id) mr_api.upsert_model_version(model_version.py, rm_id) - mv_proto = mr_api._store._mlmd_store.get_context_by_type_and_name( + mv_proto = mr_api._store.store.get_context_by_type_and_name( ModelVersion.get_proto_type_name(), f"{rm_id}:{model_version.proto.name}" ) assert mv_proto is not None @@ -145,7 +145,7 @@ def test_upsert_model_version( def test_get_model_version_by_id(mr_api: ModelRegistryAPIClient, model_version: Mapped): model_version.proto.name = f"1:{model_version.proto.name}" - ctx_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + ctx_id = mr_api._store.store.put_contexts([model_version.proto])[0] id = str(ctx_id) assert (mlmd_mv := mr_api.get_model_version_by_id(id)) @@ -158,7 +158,7 @@ def test_get_model_version_by_name( mr_api: ModelRegistryAPIClient, model_version: Mapped ): model_version.proto.name = f"1:{model_version.proto.name}" - mv_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] assert ( mlmd_mv := mr_api.get_model_version_by_params( @@ -176,7 +176,7 @@ def test_get_model_version_by_external_id( model_version.proto.name = f"1:{model_version.proto.name}" model_version.proto.external_id = "external_id" model_version.py.external_id = "external_id" - mv_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] assert ( mlmd_mv := mr_api.get_model_version_by_params( @@ -193,14 +193,14 @@ def test_get_model_versions( model_version: Mapped, registered_model: Mapped, ): - rm_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] model_version.proto.name = f"{rm_id}:version" - mv1_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv1_id = mr_api._store.store.put_contexts([model_version.proto])[0] model_version.proto.name = f"{rm_id}:version2" - mv2_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv2_id = mr_api._store.store.put_contexts([model_version.proto])[0] - mr_api._store._mlmd_store.put_parent_contexts( + mr_api._store.store.put_parent_contexts( [ ParentContext(parent_id=rm_id, child_id=mv1_id), ParentContext(parent_id=rm_id, child_id=mv2_id), @@ -220,12 +220,12 @@ def test_upsert_model_artifact( ): monkeypatch.setattr(ModelArtifact, "mlmd_name_prefix", "test_prefix") - mv_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] mv_id = str(mv_id) mr_api.upsert_model_artifact(model.py, mv_id) - ma_proto = mr_api._store._mlmd_store.get_artifact_by_type_and_name( + ma_proto = mr_api._store.store.get_artifact_by_type_and_name( ModelArtifact.get_proto_type_name(), f"test_prefix:{model.proto.name}" ) assert ma_proto is not None @@ -236,11 +236,11 @@ def test_upsert_model_artifact( def test_upsert_duplicate_model_artifact_with_different_version( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv1_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv1_id = mr_api._store.store.put_contexts([model_version.proto])[0] mv1_id = str(mv1_id) model_version.proto.name = "version2" - mv2_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv2_id = mr_api._store.store.put_contexts([model_version.proto])[0] mv2_id = str(mv2_id) ma1 = evolve(model.py) @@ -248,9 +248,7 @@ def test_upsert_duplicate_model_artifact_with_different_version( ma2 = evolve(model.py) mr_api.upsert_model_artifact(ma2, mv2_id) - ma_protos = mr_api._store._mlmd_store.get_artifacts_by_id( - [int(ma1.id), int(ma2.id)] - ) + ma_protos = mr_api._store.store.get_artifacts_by_id([int(ma1.id), int(ma2.id)]) assert ma1.name == ma2.name assert ma1.name != str(ma_protos[0].name) assert ma2.name != str(ma_protos[1].name) @@ -259,7 +257,7 @@ def test_upsert_duplicate_model_artifact_with_different_version( def test_upsert_duplicate_model_artifact_with_same_version( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] mv_id = str(mv_id) ma1 = evolve(model.py) @@ -271,7 +269,7 @@ def test_upsert_duplicate_model_artifact_with_same_version( def test_get_model_artifact_by_id(mr_api: ModelRegistryAPIClient, model: Mapped): model.proto.name = f"test_prefix:{model.proto.name}" - id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + id = mr_api._store.store.put_artifacts([model.proto])[0] id = str(id) assert (mlmd_ma := mr_api.get_model_artifact_by_id(id)) @@ -283,12 +281,12 @@ def test_get_model_artifact_by_id(mr_api: ModelRegistryAPIClient, model: Mapped) def test_get_model_artifact_by_model_version_id( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] model.proto.name = f"test_prefix:{model.proto.name}" - ma_id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + ma_id = mr_api._store.store.put_artifacts([model.proto])[0] - mr_api._store._mlmd_store.put_attributions_and_associations( + mr_api._store.store.put_attributions_and_associations( [Attribution(context_id=mv_id, artifact_id=ma_id)], [] ) @@ -305,7 +303,7 @@ def test_get_model_artifact_by_external_id( model.proto.external_id = "external_id" model.py.external_id = "external_id" - id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + id = mr_api._store.store.put_artifacts([model.proto])[0] id = str(id) assert ( @@ -318,9 +316,9 @@ def test_get_model_artifact_by_external_id( def test_get_all_model_artifacts(mr_api: ModelRegistryAPIClient, model: Mapped): model.proto.name = "test_prefix:model1" - ma1_id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + ma1_id = mr_api._store.store.put_artifacts([model.proto])[0] model.proto.name = "test_prefix:model2" - ma2_id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + ma2_id = mr_api._store.store.put_artifacts([model.proto])[0] mlmd_mas = mr_api.get_model_artifacts() assert len(mlmd_mas) == 2 @@ -330,17 +328,17 @@ def test_get_all_model_artifacts(mr_api: ModelRegistryAPIClient, model: Mapped): def test_get_model_artifacts_by_mv_id( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv1_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv1_id = mr_api._store.store.put_contexts([model_version.proto])[0] model_version.proto.name = "version2" - mv2_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv2_id = mr_api._store.store.put_contexts([model_version.proto])[0] model.proto.name = "test_prefix:model1" - ma1_id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + ma1_id = mr_api._store.store.put_artifacts([model.proto])[0] model.proto.name = "test_prefix:model2" - ma2_id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + ma2_id = mr_api._store.store.put_artifacts([model.proto])[0] - mr_api._store._mlmd_store.put_attributions_and_associations( + mr_api._store.store.put_attributions_and_associations( [ Attribution(context_id=mv1_id, artifact_id=ma1_id), Attribution(context_id=mv2_id, artifact_id=ma2_id),