diff --git a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
index c4be732f2e5e2..a4cfbd171eb14 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
@@ -4,6 +4,8 @@
 - anthropic.claude-v1
 - anthropic.claude-v2
 - anthropic.claude-v2:1
+- anthropic.claude-3-sonnet-v1:0
+- anthropic.claude-3-haiku-v1:0
 - cohere.command-light-text-v14
 - cohere.command-text-v14
 - meta.llama2-13b-chat-v1
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml
new file mode 100644
index 0000000000000..73fe5567fc266
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml
@@ -0,0 +1,57 @@
+model: anthropic.claude-3-haiku-20240307-v1:0
+label:
+  en_US: Claude 3 Haiku
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 200000
+# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    type: int
+    default: 4096
+    min: 1
+    max: 4096
+    help:
+      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+  # docs: https://docs.anthropic.com/claude/docs/system-prompts
+  - name: temperature
+    use_template: temperature
+    required: false
+    type: float
+    default: 1
+    min: 0.0
+    max: 1.0
+    help:
+      zh_Hans: 生成内容的随机性。
+      en_US: The amount of randomness injected into the response.
+  - name: top_p
+    required: false
+    type: float
+    default: 0.999
+    min: 0.000
+    max: 1.000
+    help:
+      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+  - name: top_k
+    required: false
+    type: int
+    default: 0
+    min: 0
+    # the tip in the AWS docs is wrong; the actual maximum value is 500
+    max: 500
+    help:
+      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
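+# pricing is per 1,000 tokens (unit: '0.001'); the figures below assume Anthropic's published
+# Claude 3 Haiku rates of $0.25 / $1.25 per million input / output tokens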
+pricing:
+  input: '0.00025'
+  output: '0.00125'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml
new file mode 100644
index 0000000000000..cb11df0b60183
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml
@@ -0,0 +1,56 @@
+model: anthropic.claude-3-sonnet-20240229-v1:0
+label:
+  en_US: Claude 3 Sonnet
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 200000
+# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    type: int
+    default: 4096
+    min: 1
+    max: 4096
+    help:
+      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+  - name: temperature
+    use_template: temperature
+    required: false
+    type: float
+    default: 1
+    min: 0.0
+    max: 1.0
+    help:
+      zh_Hans: 生成内容的随机性。
+      en_US: The amount of randomness injected into the response.
+  - name: top_p
+    required: false
+    type: float
+    default: 0.999
+    min: 0.000
+    max: 1.000
+    help:
+      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+  - name: top_k
+    required: false
+    type: int
+    default: 0
+    min: 0
+    # the tip in the AWS docs is wrong; the actual maximum value is 500
+    max: 500
+    help:
+      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
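+# pricing is per 1,000 tokens (unit: '0.001'); the figures below assume Anthropic's published
+# Claude 3 Sonnet rates of $3 / $15 per million input / output tokens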
+pricing:
+  input: '0.003'
+  output: '0.015'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index c6aaa24ade174..5745721ae8d4a 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -1,9 +1,22 @@
+import base64
 import json
 import logging
+import mimetypes
+import time
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast

 import boto3
+import requests
+from anthropic import AnthropicBedrock, Stream
+from anthropic.types import (
+    ContentBlockDeltaEvent,
+    Message,
+    MessageDeltaEvent,
+    MessageStartEvent,
+    MessageStopEvent,
+    MessageStreamEvent,
+)
 from botocore.config import Config
 from botocore.exceptions import (
     ClientError,
@@ -13,14 +26,18 @@
     UnknownServiceError,
 )

-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -54,9 +71,268 @@ def _invoke(self, model: str, credentials: dict,
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
+
+        # invoke claude 3 models via the anthropic official SDK
+        if "anthropic.claude-3" in model:
+            return self._invoke_claude3(model, credentials, prompt_messages, model_parameters, stop, stream)
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

+    def _invoke_claude3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                        stop: Optional[list[str]] = None, stream: bool = True) -> Union[LLMResult, Generator]:
+        """
+        Invoke Claude3 large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :return: full response or stream response chunk generator result
+        """
+        # use the Anthropic official SDK; references:
+        # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
+        # - https://github.com/anthropics/anthropic-sdk-python
+        client = AnthropicBedrock(
+            aws_access_key=credentials["aws_access_key_id"],
+            aws_secret_key=credentials["aws_secret_access_key"],
+            aws_region=credentials["aws_region"],
+        )
+
+        system, prompt_message_dicts = self._convert_claude3_prompt_messages(prompt_messages)
+
+        response = client.messages.create(
+            model=model,
+            messages=prompt_message_dicts,
+            stop_sequences=stop if stop else [],
+            system=system,
+            stream=stream,
+            **model_parameters,
+        )
+
+        if stream is False:
+            return self._handle_claude3_response(model, credentials, response, prompt_messages)
+        else:
+            return self._handle_claude3_stream_response(model, credentials, response, prompt_messages)
+
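+    # the two handlers below consume the objects returned by the Anthropic Messages API:
+    # a Message when stream=False, or a Stream[MessageStreamEvent] when stream=True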
+    def _handle_claude3_response(self, model: str, credentials: dict, response: Message,
+                                 prompt_messages: list[PromptMessage]) -> LLMResult:
+        """
+        Handle llm chat response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: full response
+        """
+
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=response.content[0].text
+        )
+
+        # calculate num tokens
+        if response.usage:
+            # transform usage
+            prompt_tokens = response.usage.input_tokens
+            completion_tokens = response.usage.output_tokens
+        else:
+            # calculate num tokens
+            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # transform response
+        response = LLMResult(
+            model=response.model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage
+        )
+
+        return response
+
+    def _handle_claude3_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
+                                        prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm chat stream response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: stream response chunk generator result
+        """
+
+        try:
+            full_assistant_content = ''
+            return_model = None
+            input_tokens = 0
+            output_tokens = 0
+            finish_reason = None
+            index = 0
+
+            for chunk in response:
+                if isinstance(chunk, MessageStartEvent):
+                    return_model = chunk.message.model
+                    input_tokens = chunk.message.usage.input_tokens
+                elif isinstance(chunk, MessageDeltaEvent):
+                    output_tokens = chunk.usage.output_tokens
+                    finish_reason = chunk.delta.stop_reason
+                elif isinstance(chunk, MessageStopEvent):
+                    usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
+                    yield LLMResultChunk(
+                        model=return_model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index + 1,
+                            message=AssistantPromptMessage(
+                                content=''
+                            ),
+                            finish_reason=finish_reason,
+                            usage=usage
+                        )
+                    )
+                elif isinstance(chunk, ContentBlockDeltaEvent):
+                    chunk_text = chunk.delta.text if chunk.delta.text else ''
+                    full_assistant_content += chunk_text
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=chunk_text if chunk_text else '',
+                    )
+                    index = chunk.index
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                        )
+                    )
+        except Exception as ex:
+            raise InvokeError(str(ex))
+
+    def _calc_claude3_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_tokens: prompt tokens
+        :param completion_tokens: completion tokens
+        :return: usage
+        """
+        # get prompt price info
+        prompt_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=prompt_tokens,
+        )
+
+        # get completion price info
+        completion_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.OUTPUT,
+            tokens=completion_tokens
+        )
+
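+        # the unit_price / unit / total_amount fields used below come from get_price(),
+        # which reads the pricing section of the model's YAML definition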
+        # transform usage
+        usage = LLMUsage(
+            prompt_tokens=prompt_tokens,
+            prompt_unit_price=prompt_price_info.unit_price,
+            prompt_price_unit=prompt_price_info.unit,
+            prompt_price=prompt_price_info.total_amount,
+            completion_tokens=completion_tokens,
+            completion_unit_price=completion_price_info.unit_price,
+            completion_price_unit=completion_price_info.unit,
+            completion_price=completion_price_info.total_amount,
+            total_tokens=prompt_tokens + completion_tokens,
+            total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
+            currency=prompt_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    def _convert_claude3_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+        """
+        Convert prompt messages to dict list and system
+        """
+        system = ""
+        prompt_message_dicts = []
+
+        for message in prompt_messages:
+            if isinstance(message, SystemPromptMessage):
+                system += message.content + ("\n" if not system else "")
+            else:
+                prompt_message_dicts.append(self._convert_claude3_prompt_message_to_dict(message))
+
+        return system, prompt_message_dicts
+
+    def _convert_claude3_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        if not message_content.data.startswith("data:"):
+                            # fetch image data from url
+                            try:
+                                image_content = requests.get(message_content.data).content
+                                mime_type, _ = mimetypes.guess_type(message_content.data)
+                                base64_data = base64.b64encode(image_content).decode('utf-8')
+                            except Exception as ex:
+                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                        else:
+                            data_split = message_content.data.split(";base64,")
+                            mime_type = data_split[0].replace("data:", "")
+                            base64_data = data_split[1]
+
+                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                            raise ValueError(f"Unsupported image type {mime_type}; "
+                                             f"only image/jpeg, image/png, image/gif and image/webp are supported")
+
+                        sub_message_dict = {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": mime_type,
+                                "data": base64_data
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
+
     def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
@@ -101,7 +377,19 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         :param credentials: model credentials
         :return:
         """
-
+
+        if "anthropic.claude-3" in model:
+            try:
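+                # send a minimal ping request through the Claude 3 code path;
+                # a small max_tokens value is enough to confirm the credentials work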
+                self._invoke_claude3(model=model,
+                                     credentials=credentials,
+                                     prompt_messages=[UserPromptMessage(content="ping")],
+                                     model_parameters={"max_tokens": 32},
+                                     stop=None,
+                                     stream=False)
+
+            except Exception as ex:
+                raise CredentialsValidateFailedError(str(ex))
+
+            # Claude 3 credentials are validated above; skip the legacy text-completion ping
+            return
+
         try:
             ping_message = UserPromptMessage(content="ping")
             self._generate(model=model,
diff --git a/api/requirements.txt b/api/requirements.txt
index 7edd95a8933de..b8714291a96a2 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -36,7 +36,7 @@ python-docx~=1.1.0
 pypdfium2==4.16.0
 resend~=0.7.0
 pyjwt~=2.8.0
-anthropic~=0.17.0
+anthropic~=0.20.0
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0
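For reviewers who want to exercise the new code path in isolation, the snippet below is a minimal sketch of the AnthropicBedrock calls that _invoke_claude3 wraps. It is not part of the patch: the credentials, region, and image payload are placeholder assumptions, and error handling is omitted.

# minimal sketch, assuming placeholder AWS credentials and a region with Claude 3 access
from anthropic import AnthropicBedrock

client = AnthropicBedrock(
    aws_access_key="<aws_access_key_id>",       # placeholders; same fields the patch reads from credentials
    aws_secret_key="<aws_secret_access_key>",
    aws_region="us-west-2",                     # assumed region
)

message = client.messages.create(
    model="anthropic.claude-3-haiku-20240307-v1:0",
    max_tokens=256,
    system="You are a concise assistant.",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in one sentence."},
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": "<base64-encoded image bytes>",  # placeholder
                    },
                },
            ],
        }
    ],
)

print(message.content[0].text)
print(message.usage.input_tokens, message.usage.output_tokens)

Passing stream=True instead yields MessageStreamEvent objects (MessageStartEvent, ContentBlockDeltaEvent, MessageDeltaEvent, MessageStopEvent), which is what _handle_claude3_stream_response iterates over.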