From 2a901cf55e396fc43347d26087106688a15be077 Mon Sep 17 00:00:00 2001 From: boon Date: Wed, 22 Nov 2023 16:31:26 +0800 Subject: [PATCH] extend init kwargs to bedrock --- llms/providers/bedrock_anthropic.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/llms/providers/bedrock_anthropic.py b/llms/providers/bedrock_anthropic.py index 1adb073..3eaf9b6 100644 --- a/llms/providers/bedrock_anthropic.py +++ b/llms/providers/bedrock_anthropic.py @@ -1,12 +1,13 @@ # llms/providers/bedrock_anthropic.py import os -from typing import Optional +from typing import Optional, Union import anthropic_bedrock from .anthropic import AnthropicProvider + class BedrockAnthropicProvider(AnthropicProvider): MODEL_INFO = { "anthropic.claude-instant-v1": {"prompt": 1.63, "completion": 5.51, "token_limit": 9000}, @@ -19,11 +20,13 @@ class BedrockAnthropicProvider(AnthropicProvider): } def __init__( - self, - model: Optional[str] = None, - aws_access_key: Optional[str] = None, - aws_secret_key: Optional[str] = None, - aws_region: Optional[str] = None + self, + model: Union[str, None] = None, + aws_access_key: Union[str, None] = None, + aws_secret_key: Union[str, None] = None, + aws_region: Union[str, None] = None, + client_kwargs: Union[dict, None] = None, + async_client_kwargs: Union[dict, None] = None, ): if model is None: model = list(self.MODEL_INFO.keys())[0] @@ -33,14 +36,21 @@ def __init__( aws_access_key = os.getenv("AWS_ACCESS_KEY_ID") if aws_secret_key is None: aws_access_key = os.getenv("AWS_SECRET_ACCESS_KEY") + + if client_kwargs is None: + client_kwargs = {} self.client = anthropic_bedrock.AnthropicBedrock( aws_access_key=aws_access_key, aws_secret_key=aws_secret_key, - aws_region=aws_region + aws_region=aws_region, + **client_kwargs, ) + + if async_client_kwargs is None: + async_client_kwargs = {} self.async_client = anthropic_bedrock.AsyncAnthropicBedrock( aws_access_key=aws_access_key, aws_secret_key=aws_secret_key, - aws_region=aws_region + aws_region=aws_region, + **async_client_kwargs, ) -