From b06e422d4808fd1c7d99eb2e08342ea57c111449 Mon Sep 17 00:00:00 2001 From: Bram Elfrink Date: Tue, 8 Oct 2024 12:33:16 +0200 Subject: [PATCH] feat(vllm): Allow for API key and corresponding usage plan when API is not public --- .../cloud/aws/resources/vllm_component.py | 105 +++++++++++++++--- .../aws/resources/test_vllm_component.py | 30 +++-- 2 files changed, 108 insertions(+), 27 deletions(-) diff --git a/src/damavand/cloud/aws/resources/vllm_component.py b/src/damavand/cloud/aws/resources/vllm_component.py index bf672a4..03d1fb2 100644 --- a/src/damavand/cloud/aws/resources/vllm_component.py +++ b/src/damavand/cloud/aws/resources/vllm_component.py @@ -133,17 +133,23 @@ def __init__( _ = self.api _ = self.api_resource_v1 - _ = self.api_resource_v1 + _ = self.api_resource_chat _ = self.api_resource_completions + # Only create API key if public internet access is set to False if not self.args.public_internet_access: - _ = self.api_authorizer + print(">>> Hello there: no public internet access so creating API key etc.") + _ = self.admin_api_key + _ = self.default_usage_plan + _ = self.api_key_usage_plan + _ = self.api_key_secret + _ = self.api_key_secret_version _ = self.api_method _ = self.api_integration _ = self.api_integration_response _ = self.api_method_response - _ = self.api_deploy + _ = self.api_deployment def get_service_assume_policy(self, service: str) -> dict[str, Any]: """Return the assume role policy for the requested service. @@ -309,11 +315,6 @@ def api_resource_chat(self) -> aws.apigateway.Resource: When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_resource`is only available when public_internet_access is True" - ) - return aws.apigateway.Resource( resource_name=f"{self._name}-api-resource-chat", opts=ResourceOptions(parent=self), @@ -334,11 +335,6 @@ def api_resource_completions(self) -> aws.apigateway.Resource: When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_resource`is only available when public_internet_access is True" - ) - return aws.apigateway.Resource( resource_name=f"{self._name}-api-resource-completions", opts=ResourceOptions(parent=self), @@ -361,7 +357,7 @@ def api_method(self) -> aws.apigateway.Method: resource_name=f"{self._name}-api-method", opts=ResourceOptions(parent=self), rest_api=self.api.id, - resource_id=self.api_resource.id, + resource_id=self.api_resource_completions.id, http_method="POST", authorization="NONE", ) @@ -370,12 +366,85 @@ def api_method(self) -> aws.apigateway.Method: resource_name=f"{self._name}-api-method", opts=ResourceOptions(parent=self), rest_api=self.api.id, - resource_id=self.api_resource.id, + resource_id=self.api_resource_completions.id, http_method="POST", - authorization="COGNITO_USER_POOLS", - authorizer_id=self.api_authorizer.id, + authorization="NONE", + api_key_required=True, ) + @property + @cache + def admin_api_key(self) -> aws.apigateway.ApiKey: + """ + Return the admin API key for the API Gateway + """ + return aws.apigateway.ApiKey( + resource_name=f"{self._name}-api-key", + opts=ResourceOptions(parent=self), + ) + + @property + @cache + def api_key_secret(self) -> aws.secretsmanager.Secret: + """ + Return the secret for the API key + """ + + return aws.secretsmanager.Secret( + resource_name=f"{self._name}-api-key-secret", + opts=ResourceOptions(parent=self), + ) + + @property + @cache + def api_key_secret_version(self) -> aws.secretsmanager.SecretVersion: + """ + Return the secret version for the API key + """ + + return aws.secretsmanager.SecretVersion( + resource_name=f"{self._name}-api-key-secret-version", + opts=ResourceOptions(parent=self, depends_on=[self.api_key_secret]), + secret_id=self.api_key_secret.id, + secret_string=self.admin_api_key.id, + ) + + + @property + @cache + def default_usage_plan(self) -> aws.apigateway.UsagePlan: + """ + Return a default usage plan for the API Gateway, that does not limit the usage. + """ + + return aws.apigateway.UsagePlan( + resource_name=f"{self._name}-api-usage-plan", + opts=ResourceOptions(parent=self), + api_stages=[ + aws.apigateway.UsagePlanApiStageArgs( + api_id=self.api.id, + # NOTE: How do we want to deal with API stages vs. AWS environments? + stage=self.args.api_env_name, + ) + ], + ) + + + @property + @cache + def api_key_usage_plan(self) -> aws.apigateway.UsagePlanKey: + """ + Return the usage plan key for the API Gateway + """ + + return aws.apigateway.UsagePlanKey( + resource_name=f"{self._name}-api-usage-plan-key", + opts=ResourceOptions(parent=self), + key_id=self.admin_api_key.id, + key_type="API_KEY", + usage_plan_id=self.default_usage_plan.id, + ) + @property def api_sagemaker_integration_uri(self) -> pulumi.Output[str]: """ @@ -462,7 +531,7 @@ def api_method_response(self) -> aws.apigateway.MethodResponse: opts=ResourceOptions(parent=self), rest_api=self.api.id, resource_id=self.api_resource_completions.id, - http_method="POST", + http_method=self.api_method.http_method, status_code="200", ) diff --git a/tests/clouds/aws/resources/test_vllm_component.py b/tests/clouds/aws/resources/test_vllm_component.py index 0f6150e..681e266 100644 --- a/tests/clouds/aws/resources/test_vllm_component.py +++ b/tests/clouds/aws/resources/test_vllm_component.py @@ -1,3 +1,4 @@ +import pytest from typing import Optional, Tuple, List import pulumi @@ -32,15 +33,19 @@ def test_private_internet_access(): args=AwsVllmComponentArgs(), ) - with pytest.raises(AttributeError): - vllm.api - vllm.api_resource_completions - vllm.api_method - vllm.api_access_sagemaker_role - vllm.api_integration - vllm.api_integration_response - vllm.api_method_response - vllm.api_deployment + assert isinstance(vllm.api, aws.apigateway.RestApi) + assert isinstance(vllm.api_resource_completions, aws.apigateway.Resource) + assert isinstance(vllm.api_method, aws.apigateway.Method) + assert isinstance(vllm.api_access_sagemaker_role, aws.iam.Role) + assert isinstance(vllm.api_integration, aws.apigateway.Integration) + assert isinstance(vllm.api_integration_response, aws.apigateway.IntegrationResponse) + assert isinstance(vllm.api_method_response, aws.apigateway.MethodResponse) + assert isinstance(vllm.api_deployment, aws.apigateway.Deployment) + assert isinstance(vllm.admin_api_key, aws.apigateway.ApiKey) + assert isinstance(vllm.default_usage_plan, aws.apigateway.UsagePlan) + assert isinstance(vllm.api_key_usage_plan, aws.apigateway.UsagePlanKey) + assert isinstance(vllm.api_key_secret, aws.secretsmanager.Secret) + assert isinstance(vllm.api_key_secret_version, aws.secretsmanager.SecretVersion) def test_public_internet_access(): @@ -60,6 +65,13 @@ def test_public_internet_access(): assert isinstance(vllm.api_method_response, aws.apigateway.MethodResponse) assert isinstance(vllm.api_deployment, aws.apigateway.Deployment) + with pytest.raises(AttributeError): + vllm.admin_api_key + vllm.default_usage_plan + vllm.api_key_usage_plan + vllm.api_key_secret + vllm.api_key_secret_version + def test_model_image_version(): vllm = AwsVllmComponent(