Skip to content

Commit

Permalink
py: enable TLS auth by default
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella Basso do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed May 2, 2024
1 parent ab5ffa9 commit 6f9234a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 26 deletions.
2 changes: 1 addition & 1 deletion clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ This library provides a high level interface for interacting with a model regist
```py
from model_registry import ModelRegistry

registry = ModelRegistry(server_address="server-address", port=9090, author="author")
registry = ModelRegistry("server-address", author="Ada Lovelace") # Defaults to port 443

model = registry.register_model(
"my-model", # model name
Expand Down
20 changes: 9 additions & 11 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,25 @@ class ModelRegistry:
def __init__(
self,
server_address: str,
port: int,
port: int = 443,
*,
author: str,
client_key: str | None = None,
server_cert: str | None = None,
custom_ca: str | None = None,
custom_ca: bytes | None = None,
):
"""Constructor.
Args:
server_address: Server address.
port: Server port.
port: Server port. Defaults to 443.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT.
Keyword Args:
author: Name of the author.
client_key: The PEM-encoded private key as a byte string.
server_cert: The PEM-encoded certificate as a byte string.
custom_ca: The PEM-encoded root certificates as a byte string.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to contents of path on envvar CERT.
"""
# TODO: get args from env
self._author = author
self._api = ModelRegistryAPIClient(
server_address, port, client_key, server_cert, custom_ca
)
self._api = ModelRegistryAPIClient(server_address, port, custom_ca)

def _register_model(self, name: str) -> RegisteredModel:
if rm := self._api.get_registered_model_by_params(name):
Expand Down
32 changes: 18 additions & 14 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from __future__ import annotations

import os
from pathlib import Path

from ml_metadata.proto import MetadataStoreClientConfig

from .exceptions import StoreException
Expand All @@ -17,29 +20,30 @@ class ModelRegistryAPIClient:
def __init__(
self,
server_address: str,
port: int,
client_key: str | None = None,
server_cert: str | None = None,
custom_ca: str | None = None,
port: int = 443,
custom_ca: bytes | None = None,
):
"""Constructor.
Args:
server_address: Server address.
port: Server port.
client_key: The PEM-encoded private key as a byte string.
server_cert: The PEM-encoded certificate as a byte string.
custom_ca: The PEM-encoded root certificates as a byte string.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT.
port: Server port. Defaults to 443.
"""
config = MetadataStoreClientConfig()
if port == 443:
if not custom_ca:
ca_cert = os.environ.get("CERT")
if not ca_cert:
msg = "CA certificate must be provided"
raise StoreException(msg)
root_certs = Path(ca_cert).read_bytes()
else:
root_certs = custom_ca

config.ssl_config.custom_ca = root_certs
config.host = server_address
config.port = port
if client_key is not None:
config.ssl_config.client_key = client_key
if server_cert is not None:
config.ssl_config.server_cert = server_cert
if custom_ca is not None:
config.ssl_config.custom_ca = custom_ca
self._store = MLMDStore(config)

def _map(self, py_obj: ProtoBase) -> ProtoType:
Expand Down

0 comments on commit 6f9234a

Please sign in to comment.