Skip to content

Commit

Permalink
[DERCBOT-1251] Add model to AzureOpenAI Settings (#1781)
Browse files Browse the repository at this point in the history
  • Loading branch information
assouktim authored Nov 19, 2024
1 parent d390019 commit 1cc3cc9
Show file tree
Hide file tree
Showing 18 changed files with 42 additions and 14 deletions.
1 change: 1 addition & 0 deletions bot/admin/server/src/test/kotlin/service/RAGServiceTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class RAGServiceTest : AbstractTest() {
apiKey = "apiKey",
apiVersion = "apiVersion",
deploymentName = "deployment",
model = "model",
apiBase = "url"
),
noAnswerSentence = "No answer sentence"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class RAGValidationServiceTest {
apiBase = "http://my-api-base-endpoint-url.com",
apiVersion = "2023-08-01-preview",
deploymentName = "deploymentName",
model = "model",
)

private val ragConfiguration = BotRAGConfigurationDTO(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ export const EngineConfigurations: EnginesConfiguration[] = [
{ key: 'apiKey', label: 'Api key', type: 'obfuscated', confirmExport: true },
{ key: 'apiVersion', label: 'Api version', type: 'openlist', source: AzureOpenAiApiVersionsList },
{ key: 'deploymentName', label: 'Deployment name', type: 'text' },
{ key: 'model', label: 'Model name', type: 'openlist', source: OpenAIModelsList },
{ key: 'apiBase', label: 'Base url', type: 'obfuscated' },
{ key: 'temperature', label: 'Temperature', type: 'number', inputScale: 'fullwidth' },
{ key: 'prompt', label: 'Prompt', type: 'prompt', inputScale: 'fullwidth', defaultValue: DefaultPrompt }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const EnginesConfigurations_Llm: EnginesConfiguration[] = [
{ key: 'apiKey', label: 'Api key', type: 'obfuscated', confirmExport: true },
{ key: 'apiVersion', label: 'Api version', type: 'openlist', source: AzureOpenAiApiVersionsList },
{ key: 'deploymentName', label: 'Deployment name', type: 'text' },
{ key: 'model', label: 'Model name', type: 'openlist', source: OpenAIModelsList },
{ key: 'apiBase', label: 'Base url', type: 'obfuscated' },
{ key: 'temperature', label: 'Temperature', type: 'number', inputScale: 'fullwidth' },
{ key: 'prompt', label: 'Prompt', type: 'prompt', inputScale: 'fullwidth', defaultValue: DefaultPrompt }
Expand Down Expand Up @@ -91,6 +92,7 @@ const EnginesConfigurations_Embedding: EnginesConfiguration[] = [
{ key: 'apiKey', label: 'Api key', type: 'obfuscated', confirmExport: true },
{ key: 'apiVersion', label: 'Api version', type: 'openlist', source: AzureOpenAiApiVersionsList },
{ key: 'deploymentName', label: 'Deployment name', type: 'text' },
{ key: 'model', label: 'Model name', type: 'openlist', source: OpenAIEmbeddingModel },
{ key: 'apiBase', label: 'Base url', type: 'obfuscated' }
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ const settings = {
params: {
modelName: 'gpt-4-32k',
deploymentName: 'azure deployment name',
model: 'model name',
privateEndpointBaseUrl: 'azure endpoint url',
apiVersion: '2023-03-15-preview',
embeddingDeploymentName: 'Embedding deployment name',
Expand Down
4 changes: 3 additions & 1 deletion docs/_en/dev/gen_ai_orchestrator/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@
class AzureOpenAILLMSetting(BaseLLMSetting):
provider: Literal[LLMProvider.AZURE_OPEN_AI_SERVICE]
deployment_name: str
model: Optional[str]
api_base: str
api_version: str

Expand All @@ -209,7 +210,8 @@
class AzureOpenAIEMSetting(BaseEMSetting):
provider: Literal[LLMProvider.AZURE_OPEN_AI_SERVICE]
deployment_name: str
api_vase: str
model: Optional[str]
api_base: str
api_version: str

EMSetting = Annotated[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ AzureOpenAIService <br />
"prompt": "Customized prompt for the use case",
"api_base": "https://custom-api-name.azure-api.net",
"deployment_name": "custom-deployment-name",
"model": "gpt-4o",
"api_version": "2024-03-01-preview"
}
</pre>
Expand All @@ -78,6 +79,7 @@ AzureOpenAIService <br />
},
"api_base": "https://custom-api-name.azure-api.net",
"deployment_name": "custom-deployment-name",
"model": "text-embedding-ada-002",
"api_version": "2024-03-01-preview"
}
</pre>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ object EMSettingMapper {
apiKey = SecurityUtils.fetchSecretKeyValue(apiKey),
apiBase = apiBase,
deploymentName = deploymentName,
apiVersion = apiVersion
apiVersion = apiVersion,
model = model
)
is OllamaEMSetting ->
OllamaEMSetting(model = model, baseUrl = baseUrl)
Expand Down Expand Up @@ -75,7 +76,8 @@ object EMSettingMapper {
SecurityUtils.createSecretKey(namespace, botId, feature, apiKey),
apiBase = apiBase,
deploymentName = deploymentName,
apiVersion = apiVersion
apiVersion = apiVersion,
model = model
)
is OllamaEMSetting ->
OllamaEMSetting(model = model, baseUrl = baseUrl)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ object LLMSettingMapper {
prompt = prompt,
apiBase = apiBase,
deploymentName = deploymentName,
model = model,
apiVersion = apiVersion
)
is OllamaLLMSetting ->
Expand Down Expand Up @@ -84,11 +85,12 @@ object LLMSettingMapper {
is AzureOpenAILLMSetting ->
AzureOpenAILLMSetting(
SecurityUtils.createSecretKey(namespace, botId, feature, apiKey),
temperature,
prompt,
apiBase,
deploymentName,
apiVersion
temperature = temperature,
prompt = prompt,
apiBase = apiBase,
deploymentName = deploymentName,
apiVersion = apiVersion,
model = model
)
is OllamaLLMSetting ->
OllamaLLMSetting(temperature, prompt, model, baseUrl)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ data class AzureOpenAIEMSetting<T>(
val apiBase: String,
val deploymentName: String,
val apiVersion: String,
val model: String? = null,
) : EMSettingBase<T>(EMProvider.AzureOpenAIService, apiKey)

typealias AzureOpenAIEMSettingDTO = AzureOpenAIEMSetting<String>
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ data class AzureOpenAILLMSetting<T>(
val apiBase: String,
val deploymentName: String,
val apiVersion: String,
val model: String? = null,
) : LLMSettingBase<T>(LLMProvider.AzureOpenAIService, apiKey, temperature, prompt) {
override fun copyWithTemperature(temperature: String): LLMSettingBase<T> {
return this.copy(temperature=temperature)
}
}
typealias AzureOpenAILLMSettingDTO = AzureOpenAILLMSetting<String>
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
"""Model for creating AzureOpenAIEMSetting."""

from typing import Literal
from typing import Literal, Optional

from pydantic import Field, HttpUrl

Expand Down Expand Up @@ -42,6 +42,9 @@ class AzureOpenAIEMSetting(BaseEMSetting):
description='The deployment name you chose when you deployed the model.',
examples=['my-deployment-name'],
)
model: Optional[str] = Field(
description='The model id', examples=['text-embedding-ada-002'], default=None
)
api_base: HttpUrl = Field(
description='The API base url / Azure endpoint',
examples=['https://doc.tock.ai/tock'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
"""Model for creating AzureOpenAILLMSetting."""

from typing import Literal
from typing import Literal, Optional

from pydantic import Field, HttpUrl

Expand Down Expand Up @@ -42,6 +42,9 @@ class AzureOpenAILLMSetting(BaseLLMSetting):
description='The deployment name you chose when you deployed the model.',
examples=['my-deployment-name'],
)
model: Optional[str] = Field(
description='The model id', examples=['gpt-3.5-turbo'], default=None
)
api_base: HttpUrl = Field(
description='The API base url / Azure endpoint',
examples=['https://doc.tock.ai/tock'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ async def get_em_provider_setting_by_id(
provider=EMProvider.AZURE_OPEN_AI_SERVICE,
api_key=RawSecretKey(value='ab7***************************A1IV4B'),
deployment_name='my-deployment-name',
model='text-embedding-ada-002',
api_base='https://doc.tock.ai/tock',
api_version='2023-05-15',
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ async def get_llm_provider_setting_by_id(
provider=LLMProvider.AZURE_OPEN_AI_SERVICE,
api_key=RawSecretKey(value='ab7***************************A1IV4B'),
deployment_name='my-deployment-name',
model='gpt-4o',
api_base='https://doc.tock.ai/tock',
api_version='2023-05-15',
temperature=0.7,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import List

from langchain.embeddings.base import Embeddings
from langchain_openai import AzureOpenAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings

from gen_ai_orchestrator.configurations.environment.settings import (
application_settings,
Expand Down Expand Up @@ -46,6 +46,8 @@ def get_embedding_model(self) -> Embeddings:
openai_api_version=self.setting.api_version,
azure_endpoint=str(self.setting.api_base),
azure_deployment=self.setting.deployment_name,
# the model is not Nullable, it has a default value
model=self.setting.model or OpenAIEmbeddings.__fields__["model"].default,
timeout=application_settings.em_provider_timeout,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_language_model(self) -> BaseLanguageModel:
openai_api_version=self.setting.api_version,
azure_endpoint=str(self.setting.api_base),
azure_deployment=self.setting.deployment_name,
model=self.setting.model,
temperature=self.setting.temperature,
request_timeout=application_settings.llm_provider_timeout,
max_retries=application_settings.llm_provider_max_retries,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def test_get_azure_open_ai_llm_factory():
'type': 'Raw',
'value': 'ab7***************************A1IV4B',
},
'deployment_name': 'model',
'deployment_name': 'deployment_name',
'model': 'gpt-4o',
'api_base': 'https://doc.tock.ai/tock',
'api_version': 'version',
'temperature': '0',
Expand Down Expand Up @@ -209,7 +210,8 @@ def test_get_azure_open_ai_em_factory():
'type': 'Raw',
'value': 'ab7***************************A1IV4B',
},
'deployment_name': 'model',
'deployment_name': 'deployment_name',
'model': 'text-embedding-ada-002',
'api_base': 'https://doc.tock.ai/tock',
'api_version': 'version',
'prompt': 'List 3 ice cream flavors.',
Expand Down

0 comments on commit 1cc3cc9

Please sign in to comment.