Skip to content

Commit

Permalink
py: abstract kc API
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Nov 11, 2024
1 parent ce0c6aa commit c01f397
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 58 deletions.
90 changes: 35 additions & 55 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@
ModelTypes = t.Union[RegisteredModel, ModelVersion, ModelArtifact]
TModel = t.TypeVar("TModel", bound=ModelTypes)

DSC_CRD = "datasciencecluster.opendatahub.io/v1"
DEFAULT_NS = "kubeflow"
DSC_NS_CONFIG = "registriesNamespace"
EXTERNAL_ADDR_ANNOTATION = "routing.opendatahub.io/external-address-rest"


class ModelRegistry:
"""Model registry client."""
Expand Down Expand Up @@ -94,7 +89,14 @@ def __init__(

@classmethod
def from_service(
cls, name: str, author: str, *, ns: str | None = None, is_secure: bool = True
cls,
name: str,
author: str,
*,
ns: str | None = None,
is_secure: bool = True,
user_token: str | None = None,
custom_ca: str | None = None,
) -> ModelRegistry:
"""Create a client from a service name.
Expand All @@ -105,61 +107,39 @@ def from_service(
Keyword Args:
ns: Namespace. Defaults to DSC registriesNamespace, or `kubeflow` if unavailable.
is_secure: Whether to use a secure connection. Defaults to True.
user_token: The PEM-encoded user token as a string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a string. Defaults to path on envvar CERT.
"""
from kubernetes import client, config

config.load_incluster_config()
if not ns:
kcustom = client.CustomObjectsApi()
g, v = DSC_CRD.split("/")
p = f"{g.split('.')[0]}s"
try:
dsc_raw = kcustom.list_cluster_custom_object(
group=g,
version=v,
plural=p,
)
except client.ApiException as e:
msg = f"Failed to list {p}: {e}"
warn(msg, stacklevel=2)
ns = DEFAULT_NS
else:
ns = t.cast(
dict[str, t.Any],
dsc_raw["items"][0],
)["status"]["components"]["modelregistry"][DSC_NS_CONFIG]

kcore = client.CoreV1Api()
serv = t.cast(client.V1Service, kcore.read_namespaced_service(name, ns))
meta = t.cast(client.V1ObjectMeta, serv.metadata)
ext_addr = t.cast(dict[str, str], meta.annotations).get(
EXTERNAL_ADDR_ANNOTATION
)
if ext_addr:
host, port = ext_addr.split(":")
host = f"https://{host}"
port = int(port)
elif not is_secure:
host = f"http://{meta.name}"
port = next(
(
int(str(port.port))
for port in t.cast(
list[client.V1ServicePort],
t.cast(client.V1ServiceSpec, serv.spec).ports,
)
if port.app_protocol == "http"
),
8080,
)
else:
msg = "No external address found for secure connection"
raise StoreError(msg)
from ._utils import Address, Kube

with Kube(user_token) as kc:
if not ns:
res = kc.get_mr_ns()
if e := res.error:
warn(str(e), stacklevel=2)
ns = res.value
assert isinstance(ns, str)

res = kc.get_service_addr(name, ns)
if e := res.error:
if not res.value:
raise e
warn(str(e), stacklevel=2)
addr = res.value
assert isinstance(addr, Address)
if addr.protocol != "https" and is_secure:
msg = "Service does not support secure connection. To proceed with insecure connection, set is_secure=False"
raise StoreError(msg)
host = f"{addr.protocol}://{addr.host}"
port = addr.port

return cls(
host,
port,
author=author,
is_secure=is_secure,
user_token=user_token,
custom_ca=custom_ca,
)

def async_runner(self, coro: t.Any) -> t.Any:
Expand Down
168 changes: 165 additions & 3 deletions clients/python/src/model_registry/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

import functools
import inspect
import typing as t
from collections.abc import Sequence
from typing import Any, Callable, TypeVar
from dataclasses import dataclass

CallableT = TypeVar("CallableT", bound=Callable[..., Any])
from .exceptions import StoreError

CallableT = t.TypeVar("CallableT", bound=t.Callable[..., t.Any])


# copied from https://github.com/Rapptz/RoboDanny
Expand All @@ -29,7 +32,7 @@ def quote(string: str) -> str:


# copied from https://github.com/openai/openai-python
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: # noqa: C901
def required_args(*variants: Sequence[str]) -> t.Callable[[CallableT], CallableT]: # noqa: C901
"""Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
Useful for enforcing runtime validation of overloaded functions.
Expand Down Expand Up @@ -107,3 +110,162 @@ def wrapper(*args: object, **kwargs: object) -> object:
return wrapper # type: ignore

return inner


T = t.TypeVar("T")

E = t.TypeVar("E", bound=Exception)


@dataclass
class Result(t.Generic[T, E]):
value: T | None
error: E | None

@property
def ok(self) -> bool:
return self.error is None

@property
def has_value(self) -> bool:
return self.value is not None


class Address(t.NamedTuple):
protocol: str
host: str
port: int


@dataclass
class Kube:
user_token: str | None = None
from kubernetes import client, config

DEFAULT_NS = "kubeflow"
DSC_CRD = "datasciencecluster.opendatahub.io/v1"
DSC_NS_CONFIG = "registriesNamespace"
EXTERNAL_ADDR_ANNOTATION = "routing.opendatahub.io/external-address-rest"

def __post_init__(self):
self.config.load_incluster_config()
client = Kube.client.ApiClient()
self.sa_token = client.configuration.api_key["authorization"]
self.api_client = client

def __enter__(self) -> Kube:
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.api_client.close()

def try_get(
self, op: t.Callable[[], t.Any], as_user: bool = False
) -> Result[t.Any, client.ApiException]:
if as_user and self.user_token is not None:
# NOTE: even though this config is consumed by the RESTClient, auth is refreshed on every request: https://github.com/kubernetes-client/python/blob/b7ccf179f1b0194a0ed18e39fb063ef8a963fc6b/kubernetes/client/api_client.py#L166
self.api_client.configuration.api_key["authorization"] = self.user_token
try:
return Result(op(), None)
except Kube.client.ApiException as e:
if e.status != 403:
raise e
return Result(None, e)
finally:
self.api_client.configuration.api_key["authorization"] = self.sa_token

def try_get_with_any_token(
self, op: t.Callable[[], t.Any]
) -> Result[t.Any, client.ApiException]:
res = self.try_get(op)
if res.error is not None and self.user_token:
res = self.try_get(op, as_user=True)
return res

def get_default_dsc(self) -> Result[dict[str, t.Any], StoreError]:
kcustom = Kube.client.CustomObjectsApi(self.api_client)

g, v = Kube.DSC_CRD.split("/")
p = f"{g.split('.')[0]}s"

def list_dscs() -> t.Any:
return kcustom.list_cluster_custom_object(
group=g,
version=v,
plural=p,
)

res = self.try_get_with_any_token(list_dscs)
if dscs := res.value:
return Result(
t.cast(
dict[str, t.Any],
dscs["items"][0],
),
None,
)
return Result(None, StoreError(f"Failed to list {p}: {res.error}"))

def get_mr_ns(self) -> Result[str, StoreError]:
res = self.get_default_dsc()
if dsc_raw := res.value:
return Result(
dsc_raw["status"]["components"]["modelregistry"][Kube.DSC_NS_CONFIG],
None,
)
return Result(Kube.DEFAULT_NS, res.error)

def get_namespaced_service(
self, name: str, ns: str
) -> Result[client.V1Service, StoreError]:
kcore = self.client.CoreV1Api(self.api_client)

def get_service() -> t.Any:
return kcore.read_namespaced_service(name, ns)

res = self.try_get_with_any_token(get_service)
if serv := res.value:
return Result(t.cast(Kube.client.V1Service, serv), None)
return Result(None, StoreError(f"Failed to get service {name}: {res.error}"))

def get_service_addr(self, name: str, ns: str) -> Result[Address, StoreError]:
res = self.get_namespaced_service(name, ns)
if res.error:
return Result(None, res.error)

serv = res.value
assert serv is not None
meta = t.cast(Kube.client.V1ObjectMeta, serv.metadata)
ext_addr = t.cast(dict[str, str], meta.annotations).get(
Kube.EXTERNAL_ADDR_ANNOTATION
)
err = None
if not ext_addr:
host = str(meta.name)
port_by_protocol = {
port.app_protocol: port
for port in t.cast(
list[Kube.client.V1ServicePort],
t.cast(Kube.client.V1ServiceSpec, serv.spec).ports,
)
if port.app_protocol in ("http", "https")
}
if p := port_by_protocol.get("https"):
port = int(str(p.port))
protocol = "https"
elif p := port_by_protocol.get("http"):
port = int(str(p.port))
protocol = "http"
else:
err = StoreError(f"Service {name} has no http(s) ports")
port = 8080
protocol = "http"
else:
from urllib.parse import urlparse

parsed = urlparse(ext_addr)
protocol = parsed.scheme
host, port = parsed.netloc.split(":")
port = int(port)

return Result(Address(protocol, host, port), err)

0 comments on commit c01f397

Please sign in to comment.