diff --git a/src/damavand/cloud/aws/controllers/llm.py b/src/damavand/cloud/aws/controllers/llm.py index cc08b2a..2dbd12a 100644 --- a/src/damavand/cloud/aws/controllers/llm.py +++ b/src/damavand/cloud/aws/controllers/llm.py @@ -108,7 +108,7 @@ def resource(self) -> PulumiResource: name=self.name, args=AwsVllmComponentArgs( region=self._region, - public_internet_access=True, + api_key_required=True, endpoint_ssm_parameter_name=self._base_url_ssm_name, ), ) diff --git a/src/damavand/cloud/aws/resources/vllm_component.py b/src/damavand/cloud/aws/resources/vllm_component.py index 78138c7..05feeb2 100644 --- a/src/damavand/cloud/aws/resources/vllm_component.py +++ b/src/damavand/cloud/aws/resources/vllm_component.py @@ -29,8 +29,8 @@ class AwsVllmComponentArgs: number of instances to deploy the model. instance_type : str type of instance to deploy the model. - public_internet_access : bool - whether to deploy a public API for the model. + api_key_required : bool + whether an API key is required for interacting with the API. api_env_name : str the name of the API environment. endpoint_ssm_parameter_name : str @@ -42,7 +42,7 @@ class AwsVllmComponentArgs: model_name: str = "microsoft/Phi-3-mini-4k-instruct" instance_initial_count: int = 1 instance_type: str = "ml.g4dn.xlarge" - public_internet_access: bool = False + api_key_required: bool = True api_env_name: str = "prod" endpoint_ssm_parameter_name: str = "/Vllm/endpoint/url" @@ -90,6 +90,22 @@ class AwsVllmComponent(PulumiComponentResource): Return a resource for completions routing. api_method() Return openai chat completions compatible method. + admin_api_key() + Return an admin API key for the API Gateway. + api_key_secret() + Return the Secret Manager secret for storing the API key. + api_key_secret_version() + Return the Secret Manager secret version (value) for storing the API key. + default_usage_plan() + Return a default usage plan for the API Gateway. + tier_1_usage_plan() + Return a tier 1 usage plan for the API Gateway. + tier_2_usage_plan() + Return a tier 2 usage plan for the API Gateway. + tier_3_usage_plan() + Return a tier 3 usage plan for the API Gateway. + api_key_usage_plan() + Return the UsagePlanKey where the default usage plan is associated with the API and admin API key. api_sagemaker_integration_uri() Return the SageMaker model integration URI for the API Gateway. apigateway_access_policies() @@ -125,19 +141,32 @@ def __init__( ) self.args = args + _ = self.model _ = self.endpoint_config _ = self.endpoint - if self.args.public_internet_access: - _ = self.api - _ = self.api_resource_completions - _ = self.api_method - _ = self.api_integration - _ = self.api_integration_response - _ = self.api_method_response - _ = self.api_deployment - _ = self.endpoint_ssm_parameter + _ = self.api + _ = self.api_resource_v1 + _ = self.api_resource_chat + _ = self.api_resource_completions + + # Only create API key if public internet access is set to False + if self.args.api_key_required: + _ = self.admin_api_key + _ = self.default_usage_plan + _ = self.tier_1_usage_plan + _ = self.tier_2_usage_plan + _ = self.tier_3_usage_plan + _ = self.api_key_usage_plan + _ = self.api_key_secret + _ = self.api_key_secret_version + + _ = self.api_method + _ = self.api_integration + _ = self.api_integration_response + _ = self.api_method_response + _ = self.api_deployment def get_service_assume_policy(self, service: str) -> dict[str, Any]: """Return the assume role policy for the requested service. @@ -264,18 +293,8 @@ def endpoint(self) -> aws.sagemaker.Endpoint: def api(self) -> aws.apigateway.RestApi: """ Return a public API for the SageMaker endpoint. - - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api` is only available when public_internet_access is True" - ) - return aws.apigateway.RestApi( resource_name=f"{self._name}-api", opts=ResourceOptions(parent=self), @@ -290,17 +309,8 @@ def api_resource_v1(self) -> aws.apigateway.Resource: """ Return a resource for the API Gateway. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_resource`is only available when public_internet_access is True" - ) - return aws.apigateway.Resource( resource_name=f"{self._name}-api-resource-v1", opts=ResourceOptions(parent=self), @@ -318,14 +328,9 @@ def api_resource_chat(self) -> aws.apigateway.Resource: Raises ------ AttributeError - When public_internet_access is False. + When api_key_required is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_resource`is only available when public_internet_access is True" - ) - return aws.apigateway.Resource( resource_name=f"{self._name}-api-resource-chat", opts=ResourceOptions(parent=self), @@ -339,18 +344,8 @@ def api_resource_chat(self) -> aws.apigateway.Resource: def api_resource_completions(self) -> aws.apigateway.Resource: """ Return a resource for the API Gateway. - - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_resource`is only available when public_internet_access is True" - ) - return aws.apigateway.Resource( resource_name=f"{self._name}-api-resource-completions", opts=ResourceOptions(parent=self), @@ -359,23 +354,14 @@ def api_resource_completions(self) -> aws.apigateway.Resource: path_part="completions", ) + @property @cache def api_method(self) -> aws.apigateway.Method: """ Return a method for the API Gateway. - - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_method`is only available when public_internet_access is True" - ) - return aws.apigateway.Method( resource_name=f"{self._name}-api-method", opts=ResourceOptions(parent=self), @@ -383,17 +369,211 @@ def api_method(self) -> aws.apigateway.Method: resource_id=self.api_resource_completions.id, http_method="POST", authorization="NONE", + api_key_required=self.args.api_key_required, ) @property - def api_sagemaker_integration_uri(self) -> pulumi.Output[str]: + @cache + def admin_api_key(self) -> aws.apigateway.ApiKey: """ - Return the SageMaker model integration URI for the API Gateway + Return the admin API key for the API Gateway + Raises ------ AttributeError - When public_internet_access is False. + When api_key_required is False. + """ + if not self.args.api_key_required: + raise AttributeError("`admin_api_key` is only available when api_key_required is False") + + return aws.apigateway.ApiKey( + resource_name=f"{self._name}-api-key", + opts=ResourceOptions(parent=self), + ) + + @property + @cache + def api_key_secret(self) -> aws.secretsmanager.Secret: + """ + Return the secret for the API key + + + Raises + ------ + AttributeError + When api_key_required is False. + """ + if not self.args.api_key_required: + raise AttributeError("`admin_api_secret` is only available when api_key_required is False") + + return aws.secretsmanager.Secret( + resource_name=f"{self._name}-api-key-secret", + opts=ResourceOptions(parent=self), + ) + + @property + @cache + def api_key_secret_version(self) -> aws.secretsmanager.SecretVersion: + """ + Return the secret version for the API key + + Raises + ------ + AttributeError + When api_key_required is False. + """ + if not self.args.api_key_required: + raise AttributeError("`api_key_secret_version` is only available when api_key_required is False") + + return aws.secretsmanager.SecretVersion( + resource_name=f"{self._name}-api-key-secret-version", + opts=ResourceOptions(parent=self, depends_on=[self.api_key_secret]), + secret_id=self.api_key_secret.id, + secret_string=self.admin_api_key.id, + ) + + + @property + @cache + def default_usage_plan(self) -> aws.apigateway.UsagePlan: + """ + Return a default usage plan for the API Gateway, that does not limit the usage. + + Raises + ------ + AttributeError + When api_key_required is False. + """ + if not self.args.api_key_required: + raise AttributeError("`default_usage_plan` is only available when api_key_required is False") + + return aws.apigateway.UsagePlan( + resource_name=f"{self._name}-default-api-usage-plan", + opts=ResourceOptions(parent=self), + api_stages=[ + aws.apigateway.UsagePlanApiStageArgs( + api_id=self.api.id, + # NOTE: How do we want to deal with API stages vs. AWS environments? + stage=self.args.api_env_name, + ) + ], + ) + + @property + @cache + def tier_1_usage_plan(self) -> aws.apigateway.UsagePlan: + """ + Return a tier 1 usage plan for the API Gateway, with the following limits: + - requests per minute: 500 + + Raises + ------ + AttributeError + When api_key_required is False. + """ + if not self.args.api_key_required: + raise AttributeError("`default_usage_plan` is only available when api_key_required is False") + + return aws.apigateway.UsagePlan( + resource_name=f"{self._name}-tier-1-api-usage-plan", + opts=ResourceOptions(parent=self), + api_stages=[ + aws.apigateway.UsagePlanApiStageArgs( + api_id=self.api.id, + stage=self.args.api_env_name, + ) + ], + throttle_settings=aws.apigateway.UsagePlanThrottleSettingsArgs( + rate_limit=500 + ) + ) + + @property + @cache + def tier_2_usage_plan(self) -> aws.apigateway.UsagePlan: + """ + Return a tier 2 usage plan for the API Gateway, with the following limits: + - requests per minute: 5000 + + Raises + ------ + AttributeError + When api_key_required is False. + """ + if not self.args.api_key_required: + raise AttributeError("`default_usage_plan` is only available when api_key_required is False") + + return aws.apigateway.UsagePlan( + resource_name=f"{self._name}-tier-2-api-usage-plan", + opts=ResourceOptions(parent=self), + api_stages=[ + aws.apigateway.UsagePlanApiStageArgs( + api_id=self.api.id, + stage=self.args.api_env_name, + ) + ], + throttle_settings=aws.apigateway.UsagePlanThrottleSettingsArgs( + rate_limit=5000 + ) + ) + + @property + @cache + def tier_3_usage_plan(self) -> aws.apigateway.UsagePlan: + """ + Return a tier 3 usage plan for the API Gateway, with the following limits: + - requests per minute: 10000 + + Raises + ------ + AttributeError + When api_key_required is False. + """ + if not self.args.api_key_required: + raise AttributeError("`default_usage_plan` is only available when api_key_required is False") + + return aws.apigateway.UsagePlan( + resource_name=f"{self._name}-tier-2-api-usage-plan", + opts=ResourceOptions(parent=self), + api_stages=[ + aws.apigateway.UsagePlanApiStageArgs( + api_id=self.api.id, + stage=self.args.api_env_name, + ) + ], + throttle_settings=aws.apigateway.UsagePlanThrottleSettingsArgs( + rate_limit=10000 + ) + ) + + @property + @cache + def api_key_usage_plan(self) -> aws.apigateway.UsagePlanKey: + """ + Return the usage plan key for the API Gateway + + Raises + ------ + AttributeError + When api_key_required is False. + """ + if not self.args.api_key_required: + raise AttributeError("`api_key_usage_plan` is only available when api_key_required is False") + + return aws.apigateway.UsagePlanKey( + resource_name=f"{self._name}-api-usage-plan-key", + opts=ResourceOptions(parent=self), + key_id=self.admin_api_key.id, + key_type="API_KEY", + usage_plan_id=self.default_usage_plan.id, + ) + + @property + def api_sagemaker_integration_uri(self) -> pulumi.Output[str]: + """ + Return the SageMaker model integration URI for the API Gateway + """ return self.endpoint.name.apply( @@ -414,17 +594,8 @@ def api_access_sagemaker_role(self) -> aws.iam.Role: """ Return an execution role for APIGateway to access SageMaker endpoints. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_access_sagemaker_rol`is only available when public_internet_access is True" - ) - return aws.iam.Role( resource_name=f"{self._name}-api-sagemaker-access-role", opts=ResourceOptions(parent=self), @@ -440,17 +611,8 @@ def api_integration(self) -> aws.apigateway.Integration: """ Return a sagemaker integration for the API Gateway. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_integration`is only available when public_internet_access is True" - ) - return aws.apigateway.Integration( resource_name=f"{self._name}-api-integration", opts=ResourceOptions(parent=self), @@ -469,17 +631,8 @@ def api_integration_response(self) -> aws.apigateway.IntegrationResponse: """ Return a sagemaker integration response for the API Gateway. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_integration_response`is only available when public_internet_access is True" - ) - return aws.apigateway.IntegrationResponse( resource_name=f"{self._name}-api-integration-response", opts=ResourceOptions(parent=self, depends_on=[self.api_integration]), @@ -495,23 +648,14 @@ def api_method_response(self) -> aws.apigateway.MethodResponse: """ Return a sagemaker method response for the API Gateway. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_method_response`is only available when public_internet_access is True" - ) - return aws.apigateway.MethodResponse( resource_name=f"{self._name}-api-method-response", opts=ResourceOptions(parent=self), rest_api=self.api.id, resource_id=self.api_resource_completions.id, - http_method="POST", + http_method=self.api_method.http_method, status_code="200", ) @@ -521,17 +665,8 @@ def api_deployment(self) -> aws.apigateway.Deployment: """ Return an API deployment for the API Gateway. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_deploy`is only available when public_internet_access is True" - ) - return aws.apigateway.Deployment( resource_name=f"{self._name}-api-deploy", opts=ResourceOptions( @@ -551,17 +686,8 @@ def endpoint_base_url(self) -> pulumi.Output[str]: """ Return the base URL for the deployed endpoint. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`endpoint_base_url` is only available when public_internet_access is True" - ) - return pulumi.Output.all( self.api_deployment.invoke_url, self.api_resource_v1.path_part ).apply(lambda args: f"{args[0]}/{args[1]}") @@ -572,23 +698,14 @@ def endpoint_ssm_parameter(self) -> aws.ssm.Parameter: """ Return an SSM parameter that stores the deployed endpoint URL. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`endpoint_ssm_parameter`is only available when public_internet_access is True" - ) - return aws.ssm.Parameter( resource_name=f"{self._name}-endpoint-ssm-parameter", opts=ResourceOptions(parent=self), name=( - self.args.endpoint_ssm_parameter_name - if self.args.public_internet_access + self.args.api_key_required + if self.args.api_key_required else self.endpoint.endpoint_config_name ), type=aws.ssm.ParameterType.STRING, diff --git a/tests/clouds/aws/resources/test_vllm_component.py b/tests/clouds/aws/resources/test_vllm_component.py index 783bc9f..a65bedb 100644 --- a/tests/clouds/aws/resources/test_vllm_component.py +++ b/tests/clouds/aws/resources/test_vllm_component.py @@ -27,28 +27,32 @@ def call(self, args: MockCallArgs) -> Tuple[dict, Optional[List[Tuple[str, str]] ) -def test_private_internet_access(): +def test_require_api_key(): vllm = AwsVllmComponent( name="test", args=AwsVllmComponentArgs(), ) - with pytest.raises(AttributeError): - vllm.api - vllm.api_resource_completions - vllm.api_method - vllm.api_access_sagemaker_role - vllm.api_integration - vllm.api_integration_response - vllm.api_method_response - vllm.api_deployment + assert isinstance(vllm.api, aws.apigateway.RestApi) + assert isinstance(vllm.api_resource_completions, aws.apigateway.Resource) + assert isinstance(vllm.api_method, aws.apigateway.Method) + assert isinstance(vllm.api_access_sagemaker_role, aws.iam.Role) + assert isinstance(vllm.api_integration, aws.apigateway.Integration) + assert isinstance(vllm.api_integration_response, aws.apigateway.IntegrationResponse) + assert isinstance(vllm.api_method_response, aws.apigateway.MethodResponse) + assert isinstance(vllm.api_deployment, aws.apigateway.Deployment) + assert isinstance(vllm.admin_api_key, aws.apigateway.ApiKey) + assert isinstance(vllm.default_usage_plan, aws.apigateway.UsagePlan) + assert isinstance(vllm.api_key_usage_plan, aws.apigateway.UsagePlanKey) + assert isinstance(vllm.api_key_secret, aws.secretsmanager.Secret) + assert isinstance(vllm.api_key_secret_version, aws.secretsmanager.SecretVersion) def test_public_internet_access(): vllm = AwsVllmComponent( name="test", args=AwsVllmComponentArgs( - public_internet_access=True, + api_key_required=False, ), ) @@ -61,12 +65,20 @@ def test_public_internet_access(): assert isinstance(vllm.api_method_response, aws.apigateway.MethodResponse) assert isinstance(vllm.api_deployment, aws.apigateway.Deployment) + with pytest.raises(AttributeError): + vllm.admin_api_key + vllm.default_usage_plan + vllm.api_key_usage_plan + vllm.api_key_secret + vllm.api_key_secret_version + def test_model_image_version(): vllm = AwsVllmComponent( name="test", args=AwsVllmComponentArgs( model_image_version="0.29.0", + api_key_required=True, ), ) @@ -78,6 +90,7 @@ def test_model_image_config(): name="test", args=AwsVllmComponentArgs( model_name="microsoft/Phi-3-mini-4k-instruct", + api_key_required=True, ), )