Skip to content

Commit

Permalink
py: verify overload signatures at runtime
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella Basso do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Mar 19, 2024
1 parent 5f6e7e9 commit 354f12d
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
109 changes: 109 additions & 0 deletions clients/python/src/model_registry/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

import functools
import inspect
from collections.abc import Sequence
from typing import Any, Callable, TypeVar

CallableT = TypeVar("CallableT", bound=Callable[..., Any])


# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
size = len(seq)
if size == 0:
return ""

if size == 1:
return seq[0]

if size == 2:
return f"{seq[0]} {final} {seq[1]}"

return delim.join(seq[:-1]) + f" {final} {seq[-1]}"


def quote(string: str) -> str:
"""Add single quotation marks around the given string. Does *not* do any escaping."""
return f"'{string}'"


# copied from https://github.com/openai/openai-python
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: # noqa: C901
"""Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
Useful for enforcing runtime validation of overloaded functions.
Example usage:
```py
@overload
def foo(*, a: str) -> str:
...
@overload
def foo(*, b: bool) -> str:
...
# This enforces the same constraints that a static type checker would
# i.e. that either a or b must be passed to the function
@required_args(["a"], ["b"])
def foo(*, a: str | None = None, b: bool | None = None) -> str:
...
```
"""

def inner(func: CallableT) -> CallableT: # noqa: C901
params = inspect.signature(func).parameters
positional = [
name
for name, param in params.items()
if param.kind
in {
param.POSITIONAL_ONLY,
param.POSITIONAL_OR_KEYWORD,
}
]

@functools.wraps(func)
def wrapper(*args: object, **kwargs: object) -> object:
given_params: set[str] = set()
for i, _ in enumerate(args):
try:
given_params.add(positional[i])
except IndexError:
msg = f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
raise TypeError(msg) from None

for key in kwargs:
given_params.add(key)

for variant in variants:
matches = all(param in given_params for param in variant)
if matches:
break
else: # no break
if len(variants) > 1:
variations = human_join(
[
"("
+ human_join([quote(arg) for arg in variant], final="and")
+ ")"
for variant in variants
]
)
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
else:
# TODO: this error message is not deterministic
missing = list(set(variants[0]) - given_params)
if len(missing) > 1:
msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
else:
msg = f"Missing required argument: {quote(missing[0])}"
raise TypeError(msg)
return func(*args, **kwargs)

return wrapper # type: ignore

return inner
12 changes: 12 additions & 0 deletions clients/python/src/model_registry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing_extensions import overload

from ._utils import required_args
from .exceptions import StoreException


Expand All @@ -32,6 +33,17 @@ def s3_uri_from(
) -> str: ...


@required_args(
(),
( # pre-configured env
"bucket",
),
( # custom env or non-default bucket
"bucket",
"endpoint",
"region",
),
)
def s3_uri_from(
path: str,
bucket: str | None = None,
Expand Down

0 comments on commit 354f12d

Please sign in to comment.