Skip to content

Commit

Permalink
Adding .trussrc support (basetenlabs#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
aspctu authored Jul 14, 2023
1 parent 1fff764 commit 039fd89
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 24 deletions.
33 changes: 9 additions & 24 deletions truss/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
import click
import truss
import yaml
from truss.remote.baseten import BasetenRemote
from truss.remote.remote_factory import RemoteFactory

logging.basicConfig(level=logging.INFO)

REGISTRY = {"baseten": (BasetenRemote, "https://app.baseten.co")}


def echo_output(f: Callable[..., object]):
@wraps(f)
Expand Down Expand Up @@ -297,43 +295,29 @@ def train(target_directory: str, build_dir, tag, var: List[str], vars_yaml_file,

@cli_group.command()
@click.argument("target_directory", required=False, default=os.getcwd())
@click.option("--api-key", type=str, required=False, help="Your API key")
@click.option("--model-name", type=str, required=False, help="Name of the model")
@click.option(
"--remote-name",
"--remote_name",
type=str,
required=False,
help="Name of the remote",
default="baseten",
required=True,
help="Name of the remote in .trussrc to push to",
)
@click.option("--remote-url", type=str, required=False, help="URL of the remote")
@click.option("--model-name", type=str, required=False, help="Name of the model")
@error_handling
def push(
target_directory: str,
api_key: str,
model_name: str,
remote_name: str,
remote_url: str,
model_name: str,
) -> None:
"""
Pushes a truss to a TrussRemote.
TARGET_DIRECTORY: A Truss directory. If none, use current directory.
"""
if remote_name not in REGISTRY:
raise ValueError(
f"Remote {remote_name} not found. Available remotes: {list(REGISTRY.keys())}"
)
remote = RemoteFactory.create(remote=remote_name)

tr = _get_truss_from_directory(target_directory=target_directory)

# Instantiate remote
# NOTE: This is specific to Baseten, but TODO: generalize
remote_cls = REGISTRY[remote_name][0]
remote_url = remote_url or REGISTRY[remote_name][1]
remote = remote_cls(remote_url, api_key)

# Push
model_name = model_name or tr.spec.config.model_name
if model_name is None:
Expand All @@ -346,7 +330,8 @@ def push(
tr.spec.config.model_name = model_name
tr.spec.config.write_to_yaml_file(tr.spec.config_path)

service = remote.push(tr, model_name)
# TODO(Abu): This needs to be refactored to be more generic
service = remote.push(tr, model_name) # type: ignore

click.echo(f"Model {model_name} was successfully pushed.")
click.echo(f"Service URL: {service._service_url}")
Expand Down
92 changes: 92 additions & 0 deletions truss/remote/remote_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import configparser
import inspect
from pathlib import Path
from typing import Dict, Type

from truss.remote.baseten import BasetenRemote
from truss.remote.truss_remote import TrussRemote


class RemoteFactory:
"""
A factory for instantiating a TrussRemote from a .trussrc file and a user-specified remote config name
"""

REGISTRY: Dict[str, Type[TrussRemote]] = {"baseten": BasetenRemote}

@staticmethod
def load_remote_config(remote_name: str) -> Dict:
"""
Load and validate a remote config from the .trussrc file
"""
config_path = Path("~/.trussrc").expanduser()

if not config_path.exists():
raise FileNotFoundError(f"No .trussrc file found at {config_path}")

config = configparser.ConfigParser()
config.read(config_path)

if remote_name not in config:
raise ValueError(f"Service provider {remote_name} not found in .trussrc")

return dict(config[remote_name])

@staticmethod
def validate_remote_config(remote_config: Dict, remote_name: str):
"""
Validates remote config by checking
1. the 'remote' field exists
2. all required parameters for the 'remote' class are provided
"""
if "remote_provider" not in remote_config:
raise ValueError(
f"Missing 'remote_provider' field for remote {remote_name} in .trussrc"
)

if remote_config["remote_provider"] not in RemoteFactory.REGISTRY:
raise ValueError(
f"Remote provider {remote_config['remote_provider']} not found in registry"
)

remote = RemoteFactory.REGISTRY.get(remote_config["remote_provider"])
if remote:
required_params = RemoteFactory.required_params(remote)
missing_params = required_params - set(remote_config.keys())
if missing_params:
raise ValueError(
f"Missing required parameter(s) {missing_params} for remote {remote_name} in .trussrc"
)

@staticmethod
def required_params(remote: Type[TrussRemote]) -> set:
"""
Get the required parameters for a remote by inspecting its __init__ method
"""
init_signature = inspect.signature(remote.__init__)
params = init_signature.parameters
required_params = {
name
for name, param in params.items()
if param.default == inspect.Parameter.empty
and name not in {"self", "args", "kwargs"}
}
return required_params

@classmethod
def create(cls, remote: str) -> TrussRemote:
remote_config = cls.load_remote_config(remote)
cls.validate_remote_config(remote_config, remote)

remote_class = cls.REGISTRY[remote_config.pop("remote_provider")]
remote_params = {
param: remote_config.get(param)
for param in cls.required_params(remote_class)
}

# Add any additional params provided by the user in their .trussrc
additional_params = set(remote_config.keys()) - set(remote_params.keys())
for param in additional_params:
remote_params[param] = remote_config.get(param)

return remote_class(**remote_params) # type: ignore
120 changes: 120 additions & 0 deletions truss/tests/remote/test_remote_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from unittest import mock

import pytest
from truss.remote.remote_factory import RemoteFactory
from truss.remote.truss_remote import TrussRemote

SAMPLE_CONFIG = {"api_key": "test_key", "remote_url": "http://test.com"}

SAMPLE_TRUSSRC = """
[test]
remote_provider=test_remote
api_key=test_key
remote_url=http://test.com
"""

SAMPLE_TRUSSRC_NO_REMOTE = """
[test]
api_key=test_key
remote_url=http://test.com
"""

SAMPLE_TRUSSRC_NO_PARAMS = """
[test]
remote_provider=test_remote
"""


class TestRemote(TrussRemote):
def __init__(self, api_key, remote_url):
self.api_key = api_key
self.remote_url = remote_url

def authenticate(self):
return {"Authorization": self.api_key}

def push(self):
return {"status": "success"}


def mock_service_config():
return {"remote_provider": "test_remote", **SAMPLE_CONFIG}


def mock_incorrect_service_config():
return {"remote_provider": "nonexistent_remote", **SAMPLE_CONFIG}


@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
@mock.patch(
"truss.remote.remote_factory.RemoteFactory.load_remote_config",
return_value=mock_service_config(),
)
def test_create(mock_load_remote_config):
service_name = "test_service"
remote = RemoteFactory.create(service_name)
mock_load_remote_config.assert_called_once_with(service_name)
assert isinstance(remote, TestRemote)


@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
@mock.patch(
"truss.remote.remote_factory.RemoteFactory.load_remote_config",
return_value=mock_incorrect_service_config(),
)
def test_create_no_service(mock_load_remote_config):
service_name = "nonexistent_service"
with pytest.raises(ValueError):
RemoteFactory.create(service_name)


@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC)
@mock.patch("pathlib.Path.exists", return_value=True)
def test_load_remote_config(mock_exists, mock_open):
service = RemoteFactory.load_remote_config("test")
assert service == {"remote_provider": "test_remote", **SAMPLE_CONFIG}


@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC)
@mock.patch("pathlib.Path.exists", return_value=False)
def test_load_remote_config_no_file(mock_exists, mock_open):
with pytest.raises(FileNotFoundError):
RemoteFactory.load_remote_config("test")


@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC)
@mock.patch("pathlib.Path.exists", return_value=True)
def test_load_remote_config_no_service(mock_exists, mock_open):
with pytest.raises(ValueError):
RemoteFactory.load_remote_config("nonexistent_service")


@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
def test_required_params():
required_params = RemoteFactory.required_params(TestRemote)
assert required_params == {"api_key", "remote_url"}


@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
@mock.patch(
"builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC_NO_REMOTE
)
@mock.patch("pathlib.Path.exists", return_value=True)
def test_validate_remote_config_no_remote(mock_exists, mock_open):
with pytest.raises(ValueError):
service = RemoteFactory.load_remote_config("test")
RemoteFactory.validate_remote_config(service, "test")


@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
@mock.patch(
"builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC_NO_PARAMS
)
@mock.patch("pathlib.Path.exists", return_value=True)
def test_load_remote_config_no_params(mock_exists, mock_open):
with pytest.raises(ValueError):
service = RemoteFactory.load_remote_config("test")
RemoteFactory.validate_remote_config(service, "test")

0 comments on commit 039fd89

Please sign in to comment.