Skip to content

Commit

Permalink
feat(AwsVllmComponent): Add AWS Cognito as authentication service for LLM applications
Browse files Browse the repository at this point in the history
  • Loading branch information
bramelfrink committed Oct 8, 2024
1 parent 299d96f commit dd1f86b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 93 deletions.
126 changes: 34 additions & 92 deletions src/damavand/cloud/aws/resources/vllm_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,25 @@ def __init__(
)

self.args = args

print(">>>> self.args: ", self.args)
_ = self.model
_ = self.endpoint_config
_ = self.endpoint

if self.args.public_internet_access:
_ = self.api
_ = self.api_resource_completions
_ = self.api_method
_ = self.api_integration
_ = self.api_integration_response
_ = self.api_method_response
_ = self.api_deployment
_ = self.endpoint_ssm_parameter
_ = self.api
_ = self.api_resource_v1
_ = self.api_resource_v1
_ = self.api_resource_completions

if not self.args.public_internet_access:
_ = self.api_authorizer

_ = self.api_method
_ = self.api_integration
_ = self.api_integration_response
_ = self.api_method_response
_ = self.api_deploy

def get_service_assume_policy(self, service: str) -> dict[str, Any]:
"""Return the assume role policy for the requested service.
Expand Down Expand Up @@ -265,17 +271,8 @@ def api(self) -> aws.apigateway.RestApi:
"""
Return a public API for the SageMaker endpoint.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api` is only available when public_internet_access is True"
)

return aws.apigateway.RestApi(
resource_name=f"{self._name}-api",
opts=ResourceOptions(parent=self),
Expand All @@ -290,17 +287,8 @@ def api_resource_v1(self) -> aws.apigateway.Resource:
"""
Return a resource for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_resource`is only available when public_internet_access is True"
)

return aws.apigateway.Resource(
resource_name=f"{self._name}-api-resource-v1",
opts=ResourceOptions(parent=self),
Expand Down Expand Up @@ -359,41 +347,40 @@ def api_resource_completions(self) -> aws.apigateway.Resource:
path_part="completions",
)


@property
@cache
def api_method(self) -> aws.apigateway.Method:
    """
    Return the POST method for the completions resource of the API Gateway.

    When `public_internet_access` is True the method is created with
    authorization "NONE". Otherwise it is created with authorization
    "COGNITO_USER_POOLS" and bound to `api_authorizer`.
    """

    # Only the authorization settings differ between the two access modes,
    # so they are selected first and the Method is constructed once.
    if self.args.public_internet_access:
        auth_kwargs = {"authorization": "NONE"}
    else:
        auth_kwargs = {
            "authorization": "COGNITO_USER_POOLS",
            "authorizer_id": self.api_authorizer.id,
        }

    return aws.apigateway.Method(
        resource_name=f"{self._name}-api-method",
        opts=ResourceOptions(parent=self),
        rest_api=self.api.id,
        # Attach to the completions resource, as the previous implementation
        # did; no plain `api_resource` property is visible in this file view.
        resource_id=self.api_resource_completions.id,
        http_method="POST",
        **auth_kwargs,
    )

@property
def api_sagemaker_integration_uri(self) -> pulumi.Output[str]:
"""
Return the SageMaker model integration URI for the API Gateway
Raises
------
AttributeError
When public_internet_access is False.
"""

return self.endpoint.name.apply(
Expand All @@ -414,17 +401,8 @@ def api_access_sagemaker_role(self) -> aws.iam.Role:
"""
Return an execution role for APIGateway to access SageMaker endpoints.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_access_sagemaker_rol`is only available when public_internet_access is True"
)

return aws.iam.Role(
resource_name=f"{self._name}-api-sagemaker-access-role",
opts=ResourceOptions(parent=self),
Expand All @@ -440,17 +418,8 @@ def api_integration(self) -> aws.apigateway.Integration:
"""
Return a sagemaker integration for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_integration`is only available when public_internet_access is True"
)

return aws.apigateway.Integration(
resource_name=f"{self._name}-api-integration",
opts=ResourceOptions(parent=self),
Expand All @@ -469,17 +438,8 @@ def api_integration_response(self) -> aws.apigateway.IntegrationResponse:
"""
Return a sagemaker integration response for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_integration_response`is only available when public_internet_access is True"
)

return aws.apigateway.IntegrationResponse(
resource_name=f"{self._name}-api-integration-response",
opts=ResourceOptions(parent=self, depends_on=[self.api_integration]),
Expand All @@ -495,17 +455,8 @@ def api_method_response(self) -> aws.apigateway.MethodResponse:
"""
Return a sagemaker method response for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_method_response`is only available when public_internet_access is True"
)

return aws.apigateway.MethodResponse(
resource_name=f"{self._name}-api-method-response",
opts=ResourceOptions(parent=self),
Expand All @@ -521,17 +472,8 @@ def api_deployment(self) -> aws.apigateway.Deployment:
"""
Return an API deployment for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_deploy`is only available when public_internet_access is True"
)

return aws.apigateway.Deployment(
resource_name=f"{self._name}-api-deploy",
opts=ResourceOptions(
Expand Down
3 changes: 2 additions & 1 deletion tests/clouds/aws/resources/test_vllm_component.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from typing import Optional, Tuple, List

import pulumi
Expand Down Expand Up @@ -67,6 +66,7 @@ def test_model_image_version():
name="test",
args=AwsVllmComponentArgs(
model_image_version="0.29.0",
public_internet_access=True,
),
)

Expand All @@ -78,6 +78,7 @@ def test_model_image_config():
name="test",
args=AwsVllmComponentArgs(
model_name="microsoft/Phi-3-mini-4k-instruct",
public_internet_access=True,
),
)

Expand Down

0 comments on commit dd1f86b

Please sign in to comment.