Skip to content

Commit

Permalink
feat(vllm): Allow for API key and corresponding usage plan when API i…
Browse files Browse the repository at this point in the history
…s not public
  • Loading branch information
bramelfrink committed Oct 8, 2024
1 parent dd1f86b commit b06e422
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 27 deletions.
105 changes: 87 additions & 18 deletions src/damavand/cloud/aws/resources/vllm_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,23 @@ def __init__(

_ = self.api
_ = self.api_resource_v1
_ = self.api_resource_v1
_ = self.api_resource_chat
_ = self.api_resource_completions

# Only create API key if public internet access is set to False
if not self.args.public_internet_access:
_ = self.api_authorizer
print(">>> Hello there: no public internet access so creating API key etc.")
_ = self.admin_api_key
_ = self.default_usage_plan
_ = self.api_key_usage_plan
_ = self.api_key_secret
_ = self.api_key_secret_version

_ = self.api_method
_ = self.api_integration
_ = self.api_integration_response
_ = self.api_method_response
_ = self.api_deploy
_ = self.api_deployment

def get_service_assume_policy(self, service: str) -> dict[str, Any]:
"""Return the assume role policy for the requested service.
Expand Down Expand Up @@ -309,11 +315,6 @@ def api_resource_chat(self) -> aws.apigateway.Resource:
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_resource`is only available when public_internet_access is True"
)

return aws.apigateway.Resource(
resource_name=f"{self._name}-api-resource-chat",
opts=ResourceOptions(parent=self),
Expand All @@ -334,11 +335,6 @@ def api_resource_completions(self) -> aws.apigateway.Resource:
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_resource`is only available when public_internet_access is True"
)

return aws.apigateway.Resource(
resource_name=f"{self._name}-api-resource-completions",
opts=ResourceOptions(parent=self),
Expand All @@ -361,7 +357,7 @@ def api_method(self) -> aws.apigateway.Method:
resource_name=f"{self._name}-api-method",
opts=ResourceOptions(parent=self),
rest_api=self.api.id,
resource_id=self.api_resource.id,
resource_id=self.api_resource_completions.id,
http_method="POST",
authorization="NONE",
)
Expand All @@ -370,12 +366,85 @@ def api_method(self) -> aws.apigateway.Method:
resource_name=f"{self._name}-api-method",
opts=ResourceOptions(parent=self),
rest_api=self.api.id,
resource_id=self.api_resource.id,
resource_id=self.api_resource_completions.id,
http_method="POST",
authorization="COGNITO_USER_POOLS",
authorizer_id=self.api_authorizer.id,
authorization="NONE",
api_key_required=True,
)

@property
@cache
def admin_api_key(self) -> aws.apigateway.ApiKey:
    """
    Return the admin API key for the API Gateway.

    The key is only created by the component when the API is not public
    (``public_internet_access`` is False).

    Returns
    -------
    aws.apigateway.ApiKey
        The API key resource, parented to this component.

    Raises
    ------
    AttributeError
        When public_internet_access is True.
    """

    # A public API has no key-based access control, so refuse to create one.
    if self.args.public_internet_access:
        raise AttributeError(
            "`admin_api_key` is only available when public_internet_access is False"
        )

    return aws.apigateway.ApiKey(
        resource_name=f"{self._name}-api-key",
        opts=ResourceOptions(parent=self),
    )

@property
@cache
def api_key_secret(self) -> aws.secretsmanager.Secret:
    """
    Return the Secrets Manager secret that holds the API key.

    Only the secret container is created here; its content is written by
    a separate secret version resource.

    Returns
    -------
    aws.secretsmanager.Secret
        The secret resource, parented to this component.

    Raises
    ------
    AttributeError
        When public_internet_access is True.
    """

    # No API key exists for a public API, so there is nothing to store.
    if self.args.public_internet_access:
        raise AttributeError(
            "`api_key_secret` is only available when public_internet_access is False"
        )

    return aws.secretsmanager.Secret(
        resource_name=f"{self._name}-api-key-secret",
        opts=ResourceOptions(parent=self),
    )

@property
@cache
def api_key_secret_version(self) -> aws.secretsmanager.SecretVersion:
    """
    Return the secret version that stores the admin API key value.

    Returns
    -------
    aws.secretsmanager.SecretVersion
        The secret version resource, parented to this component.

    Raises
    ------
    AttributeError
        When public_internet_access is True.
    """

    # No API key exists for a public API, so there is nothing to store.
    if self.args.public_internet_access:
        raise AttributeError(
            "`api_key_secret_version` is only available when "
            "public_internet_access is False"
        )

    return aws.secretsmanager.SecretVersion(
        resource_name=f"{self._name}-api-key-secret-version",
        opts=ResourceOptions(parent=self, depends_on=[self.api_key_secret]),
        secret_id=self.api_key_secret.id,
        # Store the key's value (what clients send in the `x-api-key` header),
        # not its resource identifier, which cannot be used to authenticate.
        secret_string=self.admin_api_key.value,
    )


@property
@cache
def default_usage_plan(self) -> aws.apigateway.UsagePlan:
    """
    Return a default usage plan for the API Gateway.

    No throttle or quota settings are passed, so the plan itself does not
    limit usage; it is attached to the API stage named by
    ``args.api_env_name``.

    Returns
    -------
    aws.apigateway.UsagePlan
        The usage plan resource, parented to this component.

    Raises
    ------
    AttributeError
        When public_internet_access is True.
    """

    # Usage plans are only tied to API keys, which a public API does not have.
    if self.args.public_internet_access:
        raise AttributeError(
            "`default_usage_plan` is only available when "
            "public_internet_access is False"
        )

    return aws.apigateway.UsagePlan(
        resource_name=f"{self._name}-api-usage-plan",
        opts=ResourceOptions(parent=self),
        api_stages=[
            aws.apigateway.UsagePlanApiStageArgs(
                api_id=self.api.id,
                # NOTE: How do we want to deal with API stages vs. AWS environments?
                stage=self.args.api_env_name,
            )
        ],
    )


@property
@cache
def api_key_usage_plan(self) -> aws.apigateway.UsagePlanKey:
    """
    Return the usage plan key that links the admin API key to the usage plan.

    Returns
    -------
    aws.apigateway.UsagePlanKey
        The usage plan key resource, parented to this component.

    Raises
    ------
    AttributeError
        When public_internet_access is True.
    """

    # Neither the API key nor the usage plan exists for a public API.
    if self.args.public_internet_access:
        raise AttributeError(
            "`api_key_usage_plan` is only available when "
            "public_internet_access is False"
        )

    return aws.apigateway.UsagePlanKey(
        resource_name=f"{self._name}-api-usage-plan-key",
        opts=ResourceOptions(parent=self),
        key_id=self.admin_api_key.id,
        key_type="API_KEY",
        usage_plan_id=self.default_usage_plan.id,
    )

@property
def api_sagemaker_integration_uri(self) -> pulumi.Output[str]:
"""
Expand Down Expand Up @@ -462,7 +531,7 @@ def api_method_response(self) -> aws.apigateway.MethodResponse:
opts=ResourceOptions(parent=self),
rest_api=self.api.id,
resource_id=self.api_resource_completions.id,
http_method="POST",
http_method=self.api_method.http_method,
status_code="200",
)

Expand Down
30 changes: 21 additions & 9 deletions tests/clouds/aws/resources/test_vllm_component.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from typing import Optional, Tuple, List

import pulumi
Expand Down Expand Up @@ -32,15 +33,19 @@ def test_private_internet_access():
args=AwsVllmComponentArgs(),
)

with pytest.raises(AttributeError):
vllm.api
vllm.api_resource_completions
vllm.api_method
vllm.api_access_sagemaker_role
vllm.api_integration
vllm.api_integration_response
vllm.api_method_response
vllm.api_deployment
assert isinstance(vllm.api, aws.apigateway.RestApi)
assert isinstance(vllm.api_resource_completions, aws.apigateway.Resource)
assert isinstance(vllm.api_method, aws.apigateway.Method)
assert isinstance(vllm.api_access_sagemaker_role, aws.iam.Role)
assert isinstance(vllm.api_integration, aws.apigateway.Integration)
assert isinstance(vllm.api_integration_response, aws.apigateway.IntegrationResponse)
assert isinstance(vllm.api_method_response, aws.apigateway.MethodResponse)
assert isinstance(vllm.api_deployment, aws.apigateway.Deployment)
assert isinstance(vllm.admin_api_key, aws.apigateway.ApiKey)
assert isinstance(vllm.default_usage_plan, aws.apigateway.UsagePlan)
assert isinstance(vllm.api_key_usage_plan, aws.apigateway.UsagePlanKey)
assert isinstance(vllm.api_key_secret, aws.secretsmanager.Secret)
assert isinstance(vllm.api_key_secret_version, aws.secretsmanager.SecretVersion)


def test_public_internet_access():
Expand All @@ -60,6 +65,13 @@ def test_public_internet_access():
assert isinstance(vllm.api_method_response, aws.apigateway.MethodResponse)
assert isinstance(vllm.api_deployment, aws.apigateway.Deployment)

with pytest.raises(AttributeError):
vllm.admin_api_key
vllm.default_usage_plan
vllm.api_key_usage_plan
vllm.api_key_secret
vllm.api_key_secret_version


def test_model_image_version():
vllm = AwsVllmComponent(
Expand Down

0 comments on commit b06e422

Please sign in to comment.