Skip to content

Commit

Permalink
extend init kwargs to bedrock
Browse files Browse the repository at this point in the history
  • Loading branch information
bkiat1123 committed Nov 22, 2023
1 parent 712bfaf commit 2a901cf
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions llms/providers/bedrock_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# llms/providers/bedrock_anthropic.py

import os
from typing import Optional
from typing import Optional, Union

import anthropic_bedrock

from .anthropic import AnthropicProvider


class BedrockAnthropicProvider(AnthropicProvider):
MODEL_INFO = {
"anthropic.claude-instant-v1": {"prompt": 1.63, "completion": 5.51, "token_limit": 9000},
Expand All @@ -19,11 +20,13 @@ class BedrockAnthropicProvider(AnthropicProvider):
}

def __init__(
self,
model: Optional[str] = None,
aws_access_key: Optional[str] = None,
aws_secret_key: Optional[str] = None,
aws_region: Optional[str] = None
self,
model: Union[str, None] = None,
aws_access_key: Union[str, None] = None,
aws_secret_key: Union[str, None] = None,
aws_region: Union[str, None] = None,
client_kwargs: Union[dict, None] = None,
async_client_kwargs: Union[dict, None] = None,
):
if model is None:
model = list(self.MODEL_INFO.keys())[0]
Expand All @@ -33,14 +36,21 @@ def __init__(
aws_access_key = os.getenv("AWS_ACCESS_KEY_ID")
if aws_secret_key is None:
aws_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")

if client_kwargs is None:
client_kwargs = {}
self.client = anthropic_bedrock.AnthropicBedrock(
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_region=aws_region
aws_region=aws_region,
**client_kwargs,
)

if async_client_kwargs is None:
async_client_kwargs = {}
self.async_client = anthropic_bedrock.AsyncAnthropicBedrock(
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_region=aws_region
aws_region=aws_region,
**async_client_kwargs,
)

0 comments on commit 2a901cf

Please sign in to comment.