diff --git a/clients/python/README.md b/clients/python/README.md index 6af08cf10..ddff13e5b 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -10,7 +10,7 @@ This library provides a high level interface for interacting with a model regist ```py from model_registry import ModelRegistry -registry = ModelRegistry(server_address="server-address", port=9090, author="author") +registry = ModelRegistry("server-address", author="Ada Lovelace") # Defaults to port 443 model = registry.register_model( "my-model", # model name diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index 5a3074e95..b59b1ec96 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -17,27 +17,25 @@ class ModelRegistry: def __init__( self, server_address: str, - port: int, + port: int = 443, + *, author: str, - client_key: str | None = None, - server_cert: str | None = None, - custom_ca: str | None = None, + custom_ca: bytes | None = None, ): """Constructor. Args: server_address: Server address. - port: Server port. + port: Server port. Defaults to 443. + custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT. + + Keyword Args: author: Name of the author. - client_key: The PEM-encoded private key as a byte string. - server_cert: The PEM-encoded certificate as a byte string. - custom_ca: The PEM-encoded root certificates as a byte string. + custom_ca: The PEM-encoded root certificates as a byte string. Defaults to contents of path on envvar CERT. """ # TODO: get args from env self._author = author - self._api = ModelRegistryAPIClient( - server_address, port, client_key, server_cert, custom_ca - ) + self._api = ModelRegistryAPIClient(server_address, port, custom_ca) def _register_model(self, name: str) -> RegisteredModel: if rm := self._api.get_registered_model_by_params(name): diff --git a/clients/python/src/model_registry/core.py b/clients/python/src/model_registry/core.py index 9e57da09c..5f83fc5f9 100644 --- a/clients/python/src/model_registry/core.py +++ b/clients/python/src/model_registry/core.py @@ -2,6 +2,9 @@ from __future__ import annotations +import os +from pathlib import Path + from ml_metadata.proto import MetadataStoreClientConfig from .exceptions import StoreException @@ -17,29 +20,30 @@ class ModelRegistryAPIClient: def __init__( self, server_address: str, - port: int, - client_key: str | None = None, - server_cert: str | None = None, - custom_ca: str | None = None, + port: int = 443, + custom_ca: bytes | None = None, ): """Constructor. Args: server_address: Server address. - port: Server port. - client_key: The PEM-encoded private key as a byte string. - server_cert: The PEM-encoded certificate as a byte string. - custom_ca: The PEM-encoded root certificates as a byte string. + custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT. + port: Server port. Defaults to 443. """ config = MetadataStoreClientConfig() + if port == 443: + if not custom_ca: + ca_cert = os.environ.get("CERT") + if not ca_cert: + msg = "CA certificate must be provided" + raise StoreException(msg) + root_certs = Path(ca_cert).read_bytes() + else: + root_certs = custom_ca + + config.ssl_config.custom_ca = root_certs config.host = server_address config.port = port - if client_key is not None: - config.ssl_config.client_key = client_key - if server_cert is not None: - config.ssl_config.server_cert = server_cert - if custom_ca is not None: - config.ssl_config.custom_ca = custom_ca self._store = MLMDStore(config) def _map(self, py_obj: ProtoBase) -> ProtoType: