forked from basetenlabs/truss
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding .trussrc support (basetenlabs#436)
- Loading branch information
Showing
3 changed files
with
221 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import configparser | ||
import inspect | ||
from pathlib import Path | ||
from typing import Dict, Type | ||
|
||
from truss.remote.baseten import BasetenRemote | ||
from truss.remote.truss_remote import TrussRemote | ||
|
||
|
||
class RemoteFactory: | ||
""" | ||
A factory for instantiating a TrussRemote from a .trussrc file and a user-specified remote config name | ||
""" | ||
|
||
REGISTRY: Dict[str, Type[TrussRemote]] = {"baseten": BasetenRemote} | ||
|
||
@staticmethod | ||
def load_remote_config(remote_name: str) -> Dict: | ||
""" | ||
Load and validate a remote config from the .trussrc file | ||
""" | ||
config_path = Path("~/.trussrc").expanduser() | ||
|
||
if not config_path.exists(): | ||
raise FileNotFoundError(f"No .trussrc file found at {config_path}") | ||
|
||
config = configparser.ConfigParser() | ||
config.read(config_path) | ||
|
||
if remote_name not in config: | ||
raise ValueError(f"Service provider {remote_name} not found in .trussrc") | ||
|
||
return dict(config[remote_name]) | ||
|
||
@staticmethod | ||
def validate_remote_config(remote_config: Dict, remote_name: str): | ||
""" | ||
Validates remote config by checking | ||
1. the 'remote' field exists | ||
2. all required parameters for the 'remote' class are provided | ||
""" | ||
if "remote_provider" not in remote_config: | ||
raise ValueError( | ||
f"Missing 'remote_provider' field for remote {remote_name} in .trussrc" | ||
) | ||
|
||
if remote_config["remote_provider"] not in RemoteFactory.REGISTRY: | ||
raise ValueError( | ||
f"Remote provider {remote_config['remote_provider']} not found in registry" | ||
) | ||
|
||
remote = RemoteFactory.REGISTRY.get(remote_config["remote_provider"]) | ||
if remote: | ||
required_params = RemoteFactory.required_params(remote) | ||
missing_params = required_params - set(remote_config.keys()) | ||
if missing_params: | ||
raise ValueError( | ||
f"Missing required parameter(s) {missing_params} for remote {remote_name} in .trussrc" | ||
) | ||
|
||
@staticmethod | ||
def required_params(remote: Type[TrussRemote]) -> set: | ||
""" | ||
Get the required parameters for a remote by inspecting its __init__ method | ||
""" | ||
init_signature = inspect.signature(remote.__init__) | ||
params = init_signature.parameters | ||
required_params = { | ||
name | ||
for name, param in params.items() | ||
if param.default == inspect.Parameter.empty | ||
and name not in {"self", "args", "kwargs"} | ||
} | ||
return required_params | ||
|
||
@classmethod | ||
def create(cls, remote: str) -> TrussRemote: | ||
remote_config = cls.load_remote_config(remote) | ||
cls.validate_remote_config(remote_config, remote) | ||
|
||
remote_class = cls.REGISTRY[remote_config.pop("remote_provider")] | ||
remote_params = { | ||
param: remote_config.get(param) | ||
for param in cls.required_params(remote_class) | ||
} | ||
|
||
# Add any additional params provided by the user in their .trussrc | ||
additional_params = set(remote_config.keys()) - set(remote_params.keys()) | ||
for param in additional_params: | ||
remote_params[param] = remote_config.get(param) | ||
|
||
return remote_class(**remote_params) # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
from unittest import mock | ||
|
||
import pytest | ||
from truss.remote.remote_factory import RemoteFactory | ||
from truss.remote.truss_remote import TrussRemote | ||
|
||
SAMPLE_CONFIG = {"api_key": "test_key", "remote_url": "http://test.com"} | ||
|
||
SAMPLE_TRUSSRC = """ | ||
[test] | ||
remote_provider=test_remote | ||
api_key=test_key | ||
remote_url=http://test.com | ||
""" | ||
|
||
SAMPLE_TRUSSRC_NO_REMOTE = """ | ||
[test] | ||
api_key=test_key | ||
remote_url=http://test.com | ||
""" | ||
|
||
SAMPLE_TRUSSRC_NO_PARAMS = """ | ||
[test] | ||
remote_provider=test_remote | ||
""" | ||
|
||
|
||
class TestRemote(TrussRemote): | ||
def __init__(self, api_key, remote_url): | ||
self.api_key = api_key | ||
self.remote_url = remote_url | ||
|
||
def authenticate(self): | ||
return {"Authorization": self.api_key} | ||
|
||
def push(self): | ||
return {"status": "success"} | ||
|
||
|
||
def mock_service_config(): | ||
return {"remote_provider": "test_remote", **SAMPLE_CONFIG} | ||
|
||
|
||
def mock_incorrect_service_config(): | ||
return {"remote_provider": "nonexistent_remote", **SAMPLE_CONFIG} | ||
|
||
|
||
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) | ||
@mock.patch( | ||
"truss.remote.remote_factory.RemoteFactory.load_remote_config", | ||
return_value=mock_service_config(), | ||
) | ||
def test_create(mock_load_remote_config): | ||
service_name = "test_service" | ||
remote = RemoteFactory.create(service_name) | ||
mock_load_remote_config.assert_called_once_with(service_name) | ||
assert isinstance(remote, TestRemote) | ||
|
||
|
||
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) | ||
@mock.patch( | ||
"truss.remote.remote_factory.RemoteFactory.load_remote_config", | ||
return_value=mock_incorrect_service_config(), | ||
) | ||
def test_create_no_service(mock_load_remote_config): | ||
service_name = "nonexistent_service" | ||
with pytest.raises(ValueError): | ||
RemoteFactory.create(service_name) | ||
|
||
|
||
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) | ||
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC) | ||
@mock.patch("pathlib.Path.exists", return_value=True) | ||
def test_load_remote_config(mock_exists, mock_open): | ||
service = RemoteFactory.load_remote_config("test") | ||
assert service == {"remote_provider": "test_remote", **SAMPLE_CONFIG} | ||
|
||
|
||
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) | ||
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC) | ||
@mock.patch("pathlib.Path.exists", return_value=False) | ||
def test_load_remote_config_no_file(mock_exists, mock_open): | ||
with pytest.raises(FileNotFoundError): | ||
RemoteFactory.load_remote_config("test") | ||
|
||
|
||
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) | ||
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC) | ||
@mock.patch("pathlib.Path.exists", return_value=True) | ||
def test_load_remote_config_no_service(mock_exists, mock_open): | ||
with pytest.raises(ValueError): | ||
RemoteFactory.load_remote_config("nonexistent_service") | ||
|
||
|
||
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) | ||
def test_required_params(): | ||
required_params = RemoteFactory.required_params(TestRemote) | ||
assert required_params == {"api_key", "remote_url"} | ||
|
||
|
||
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) | ||
@mock.patch( | ||
"builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC_NO_REMOTE | ||
) | ||
@mock.patch("pathlib.Path.exists", return_value=True) | ||
def test_validate_remote_config_no_remote(mock_exists, mock_open): | ||
with pytest.raises(ValueError): | ||
service = RemoteFactory.load_remote_config("test") | ||
RemoteFactory.validate_remote_config(service, "test") | ||
|
||
|
||
@mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True) | ||
@mock.patch( | ||
"builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC_NO_PARAMS | ||
) | ||
@mock.patch("pathlib.Path.exists", return_value=True) | ||
def test_load_remote_config_no_params(mock_exists, mock_open): | ||
with pytest.raises(ValueError): | ||
service = RemoteFactory.load_remote_config("test") | ||
RemoteFactory.validate_remote_config(service, "test") |