diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 3c805682f3456f..161e65302ff8ad 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -1,6 +1,5 @@ from collections.abc import Generator from typing import cast -from urllib.parse import urljoin from httpx import Timeout from openai import ( @@ -19,6 +18,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion +from yarl import URL from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -181,7 +181,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: UserPromptMessage(content='ping') ], model_parameters={ 'max_tokens': 10, - }, stop=[]) + }, stop=[], stream=False) except Exception as ex: raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') @@ -227,6 +227,12 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode ) ] + model_properties = { + ModelPropertyKey.MODE: completion_model, + } if completion_model else {} + + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048')) + entity = AIModelEntity( model=model, label=I18nObject( @@ -234,7 +240,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode ), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {}, + model_properties=model_properties, parameter_rules=rules ) @@ -319,7 +325,7 @@ def _to_client_kwargs(self, credentials: dict) -> dict: client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": urljoin(credentials['server_url'], 'v1'), + "base_url": str(URL(credentials['server_url']) / 'v1'), } return client_kwargs diff --git a/api/core/model_runtime/model_providers/localai/localai.yaml b/api/core/model_runtime/model_providers/localai/localai.yaml index e4b625d171798c..a870914632a460 100644 --- a/api/core/model_runtime/model_providers/localai/localai.yaml +++ b/api/core/model_runtime/model_providers/localai/localai.yaml @@ -56,3 +56,12 @@ model_credential_schema: placeholder: zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080 en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080 + - variable: context_size + label: + zh_Hans: 上下文大小 + en_US: Context size + placeholder: + zh_Hans: 输入上下文大小 + en_US: Enter context size + required: false + type: text-input diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index c95007d271d207..954c9d10f2a67f 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -1,11 +1,12 @@ import time from json import JSONDecodeError, dumps -from os.path import join from typing import Optional from requests import post +from yarl import URL -from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -57,7 +58,7 @@ def _invoke(self, model: str, credentials: dict, } try: - response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10) + response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10) except Exception as e: raise InvokeConnectionError(str(e)) @@ -113,6 +114,27 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int # use GPT2Tokenizer to get num tokens num_tokens += self._get_num_tokens_by_gpt2(text) return num_tokens + + def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + Get customizable model schema + + :param model: model name + :param credentials: model credentials + :return: model schema + """ + return AIModelEntity( + model=model, + label=I18nObject(zh_Hans=model, en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + features=[], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')), + ModelPropertyKey.MAX_CHUNKS: 1, + }, + parameter_rules=[] + ) def validate_credentials(self, model: str, credentials: dict) -> None: """