diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt
index a9d0f945f..4f5aa943e 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt
@@ -8,7 +8,7 @@ import com.xebia.functional.xef.conversation.Conversation
 data class AIConfig(
   val tools: List> = emptyList(),
   val model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
-  val config: Config = Config(),
+  val config: Config = Config.Default,
   val openAI: OpenAI = OpenAI(config, logRequests = false),
   val api: Chat = openAI.chat,
   val conversation: Conversation = Conversation()
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
index 2f1cb2932..f610e5465 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
@@ -12,24 +12,149 @@ import io.ktor.client.plugins.logging.*
 import io.ktor.client.request.*
 import io.ktor.http.*
 import io.ktor.serialization.kotlinx.json.*
+import kotlin.math.pow
+import kotlin.time.Duration
+import kotlin.time.Duration.Companion.milliseconds
+import kotlin.time.Duration.Companion.seconds
+import kotlinx.serialization.ExperimentalSerializationApi
 import kotlinx.serialization.json.Json

+sealed interface HttpClientRetryPolicy {
+  data object NoRetry : HttpClientRetryPolicy
+
+  data class ExponentialBackoff(
+    val backoffFactor: Double,
+    val interval: Duration,
+    val maxDelay: Duration,
+    val maxRetries: Int
+  ) : HttpClientRetryPolicy
+
+  data class Incremental(val interval: Duration, val maxDelay: Duration, val maxRetries: Int) :
+    HttpClientRetryPolicy
+
+  private fun configureHttpRequestRetryPlugin(
+    maxNumberOfRetries: Int,
+    delayBlock: HttpRequestRetry.DelayContext.(Int) -> Long
+  ): HttpRequestRetry.Configuration.() -> Unit = {
+    maxRetries = maxNumberOfRetries
+    retryIf { _, response -> !response.status.isSuccess() }
+    retryOnExceptionIf { _, _ -> true }
+    delayMillis(block = delayBlock)
+  }
+
+  fun applyConfiguration(): (HttpClientConfig<*>) -> Unit = { httpClientConfig ->
+    when (val policy = this) {
+      is ExponentialBackoff ->
+        httpClientConfig.install(
+          HttpRequestRetry,
+          configure =
+            configureHttpRequestRetryPlugin(policy.maxRetries) { retry ->
+              minOf(
+                policy.backoffFactor.pow(retry).toLong() * policy.interval.inWholeMilliseconds,
+                policy.maxDelay.inWholeMilliseconds
+              )
+            }
+        )
+      is Incremental ->
+        httpClientConfig.install(
+          HttpRequestRetry,
+          configure =
+            configureHttpRequestRetryPlugin(policy.maxRetries) { retry ->
+              minOf(
+                retry * policy.interval.inWholeMilliseconds,
+                policy.maxDelay.inWholeMilliseconds
+              )
+            }
+        )
+      NoRetry -> Unit
+    }
+  }
+}
+
+data class HttpClientTimeoutPolicy(
+  val connectTimeout: Duration,
+  val requestTimeout: Duration,
+  val socketTimeout: Duration
+) {
+  val applyConfiguration: (HttpClientConfig<*>) -> Unit = { httpClientConfig ->
+    httpClientConfig.install(HttpTimeout) {
+      requestTimeoutMillis = requestTimeout.inWholeMilliseconds
+      connectTimeoutMillis = connectTimeout.inWholeMilliseconds
+      socketTimeoutMillis = socketTimeout.inWholeMilliseconds
+    }
+  }
+}
+
+class ConfigBuilder internal constructor(config: Config) {
+  var baseUrl: String = config.baseUrl
+  var httpClientRetryPolicy: HttpClientRetryPolicy = config.httpClientRetryPolicy
+  var httpClientTimeoutPolicy: HttpClientTimeoutPolicy = config.httpClientTimeoutPolicy
+  var apiToken: String? = config.apiToken
+  var organization: String? = config.organization
+  var json: Json = config.json
+  var streamingPrefix: String = config.streamingPrefix
+  var streamingDelimiter: String = config.streamingDelimiter
+
+  fun build(): Config =
+    Config(
+      baseUrl = baseUrl,
+      httpClientRetryPolicy = httpClientRetryPolicy,
+      httpClientTimeoutPolicy = httpClientTimeoutPolicy,
+      apiToken = apiToken,
+      organization = organization,
+      json = json,
+      streamingPrefix = streamingPrefix,
+      streamingDelimiter = streamingDelimiter
+    )
+}
+
+fun Config(from: Config = Config.Default, builderAction: ConfigBuilder.() -> Unit): Config {
+  val builder = ConfigBuilder(from)
+  builder.builderAction()
+  return builder.build()
+}
+
 data class Config(
-  val baseUrl: String = getenv(HOST_ENV_VAR) ?: "https://api.openai.com/v1/",
-  val token: String? = null,
-  val org: String? = getenv(ORG_ENV_VAR),
-  val json: Json = Json {
-    ignoreUnknownKeys = true
-    prettyPrint = false
-    isLenient = true
-    explicitNulls = false
-    classDiscriminator = TYPE_DISCRIMINATOR
-  },
-  val streamingPrefix: String = "data:",
-  val streamingDelimiter: String = "data: [DONE]"
+  val baseUrl: String,
+  val httpClientRetryPolicy: HttpClientRetryPolicy,
+  val httpClientTimeoutPolicy: HttpClientTimeoutPolicy,
+  val apiToken: String?,
+  val organization: String?,
+  val json: Json,
+  val streamingPrefix: String,
+  val streamingDelimiter: String
 ) {
   companion object {
-    val DEFAULT = Config()
+    @OptIn(ExperimentalSerializationApi::class)
+    val Default =
+      Config(
+        baseUrl = getenv(HOST_ENV_VAR) ?: "https://api.openai.com/v1/",
+        httpClientRetryPolicy =
+          HttpClientRetryPolicy.Incremental(
+            interval = 250.milliseconds,
+            maxDelay = 5.seconds,
+            maxRetries = 5
+          ),
+        httpClientTimeoutPolicy =
+          HttpClientTimeoutPolicy(
+            connectTimeout = 45.seconds,
+            requestTimeout = 45.seconds,
+            socketTimeout = 45.seconds
+          ),
+        json =
+          Json {
+            ignoreUnknownKeys = true
+            prettyPrint = false
+            isLenient = true
+            explicitNulls = false
+            classDiscriminator = TYPE_DISCRIMINATOR
+          },
+        organization = getenv(ORG_ENV_VAR),
+        streamingDelimiter = "data: [DONE]",
+        streamingPrefix = "data:",
+        apiToken = null
+      )
+
     const val TYPE_DISCRIMINATOR = "_type_"
   }
 }
@@ -43,33 +168,24 @@ private const val KEY_ENV_VAR = "OPENAI_TOKEN"
  * Just simple fun on top of generated API.
  */
 fun OpenAI(
-  config: Config = Config(),
+  config: Config = Config.Default,
   httpClientEngine: HttpClientEngine? = null,
   httpClientConfig: ((HttpClientConfig<*>) -> Unit)? = null,
   logRequests: Boolean = false
 ): OpenAI {
   val token =
-    config.token
+    config.apiToken
       ?: getenv(KEY_ENV_VAR)
       ?: throw AIError.Env.OpenAI(nonEmptyListOf("missing $KEY_ENV_VAR env var"))
   val clientConfig: HttpClientConfig<*>.() -> Unit = {
     install(ContentNegotiation) { json(config.json) }
-    install(HttpTimeout) {
-      requestTimeoutMillis = 45 * 1000
-      connectTimeoutMillis = 45 * 1000
-      socketTimeoutMillis = 45 * 1000
-    }
-    install(HttpRequestRetry) {
-      maxRetries = 5
-      retryIf { _, response -> !response.status.isSuccess() }
-      retryOnExceptionIf { _, _ -> true }
-      delayMillis { retry -> retry * 1000L }
-    }
-    install(Logging) { level = if (logRequests) LogLevel.ALL else LogLevel.NONE }
+    install(Logging) { level = if (logRequests) LogLevel.ALL else LogLevel.INFO }
+    config.httpClientRetryPolicy.applyConfiguration().invoke(this)
+    config.httpClientTimeoutPolicy.applyConfiguration.invoke(this)
     httpClientConfig?.invoke(this)
     defaultRequest {
       url(config.baseUrl)
-      config.org?.let { headers.append("OpenAI-Organization", it) }
+      config.organization?.let { headers.append("OpenAI-Organization", it) }
       bearerAuth(token)
     }
   }
@@ -79,7 +195,7 @@ fun OpenAI(
     OpenAIConfig(
       baseUrl = config.baseUrl,
       token = token,
-      org = config.org,
+      org = config.organization,
       json = config.json,
       streamingPrefix = config.streamingPrefix,
       streamingDelimiter = config.streamingDelimiter
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt
index a0b20a62e..4317b346e 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt
@@ -128,7 +128,7 @@ sealed class Tool(
     val typeSerializer = targetClass.serializer()
     val functionObject = chatFunction(typeSerializer.descriptor)
     return Callable(functionObject) {
-      Config.DEFAULT.json.decodeFromString(typeSerializer, it.arguments)
+      Config.Default.json.decodeFromString(typeSerializer, it.arguments)
     }
   }

@@ -137,7 +137,7 @@ sealed class Tool(
     val functionSerializer = Value.serializer(targetClass.serializer())
     val functionObject = chatFunction(functionSerializer.descriptor)
     return Primitive(functionObject) {
-      Config.DEFAULT.json.decodeFromString(functionSerializer, it.arguments).value
+      Config.Default.json.decodeFromString(functionSerializer, it.arguments).value
     }
   }

@@ -161,7 +161,7 @@ sealed class Tool(
     }
     val functionObject = chatFunction(functionSerializer.descriptor)
     return Callable(functionObject) {
-      Config.DEFAULT.json.decodeFromString(functionSerializer, it.arguments).value as A
+      Config.Default.json.decodeFromString(functionSerializer, it.arguments).value as A
     }
   }

@@ -205,7 +205,7 @@ sealed class Tool(
     descriptor: SerialDescriptor
   ): Enumeration {
     val enumSerializer = { value: String ->
-      Config.DEFAULT.json.decodeFromString(targetClass.serializer(), value) as A
+      Config.Default.json.decodeFromString(targetClass.serializer(), value) as A
     }
     val functionObject = chatFunction(descriptor)
     val cases =
@@ -251,7 +251,7 @@ sealed class Tool(
     sealedClassSerializer: SealedClassSerializer
   ): A {
     val newJson = descriptorChoice(it, functionObjectMap)
-    return Config.DEFAULT.json.decodeFromString(
+    return Config.Default.json.decodeFromString(
       sealedClassSerializer,
       Json.encodeToString(newJson)
     ) as A
@@ -263,7 +263,7 @@ sealed class Tool(
   ): JsonObject {
     // adds a `type` field with the call.functionName serial name equivalent to the call arguments
     val jsonWithDiscriminator =
-      Config.DEFAULT.json.decodeFromString(JsonElement.serializer(), call.arguments)
+      Config.Default.json.decodeFromString(JsonElement.serializer(), call.arguments)
     val descriptor =
       descriptors.values.firstOrNull { it.name.endsWith(call.functionName) }
         ?: error("No descriptor found for ${call.functionName}")
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
index bbbce296a..7b212c689 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
@@ -36,7 +36,7 @@ fun chatFunction(descriptor: SerialDescriptor): FunctionObject {
 @OptIn(ExperimentalSerializationApi::class)
 fun functionSchema(descriptor: SerialDescriptor): JsonObject =
   descriptor.annotations.filterIsInstance().firstOrNull()?.value?.let {
-    Config.DEFAULT.json.decodeFromString(JsonObject.serializer(), it)
+    Config.Default.json.decodeFromString(JsonObject.serializer(), it)
   } ?: buildJsonSchema(descriptor)

 @OptIn(ExperimentalSerializationApi::class)
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt
index 484e1492b..245184b62 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/Assistant.kt
@@ -34,14 +34,14 @@ import net.mamoe.yamlkt.toYamlElement
 class Assistant(
   val assistantId: String,
   val toolsConfig: List> = emptyList(),
-  val config: Config = Config(),
+  val config: Config = Config.Default,
   private val assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
 ) {

   constructor(
     assistantObject: AssistantObject,
     toolsConfig: List> = emptyList(),
-    config: Config = Config(),
+    config: Config = Config.Default,
     assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
   ) : this(assistantObject.id, toolsConfig, config, assistantsApi)

@@ -99,7 +99,7 @@ class Assistant(
       toolResources: CreateAssistantRequestToolResources? = null,
       metadata: JsonObject? = null,
       toolsConfig: List> = emptyList(),
-      config: Config = Config(),
+      config: Config = Config.Default,
       assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
     ): Assistant =
       Assistant(
@@ -120,7 +120,7 @@ class Assistant(
     suspend operator fun invoke(
       request: CreateAssistantRequest,
       toolsConfig: List> = emptyList(),
-      config: Config = Config(),
+      config: Config = Config.Default,
       assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
     ): Assistant {
       val response = assistantsApi.createAssistant(request, configure = ::defaultConfig)
@@ -130,7 +130,7 @@ class Assistant(
     suspend fun fromConfig(
       request: String,
       toolsConfig: List> = emptyList(),
-      config: Config = Config(),
+      config: Config = Config.Default,
       assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
     ): Assistant {
       val parsed = Yaml.Default.decodeYamlMapFromString(request)
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt
index f0c27edc4..64c6cb67e 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/AssistantThread.kt
@@ -20,7 +20,7 @@ import kotlinx.serialization.json.JsonPrimitive
 class AssistantThread(
   val threadId: String,
   val metric: Metric = Metric.EMPTY,
-  private val config: Config = Config(),
+  private val config: Config = Config.Default,
   private val api: Assistants = OpenAI(config).assistants
 ) {

@@ -271,7 +271,7 @@ class AssistantThread(
       messages: List,
       metadata: JsonObject? = null,
       metric: Metric = Metric.EMPTY,
-      config: Config = Config(),
+      config: Config = Config.Default,
       api: Assistants = OpenAI(config).assistants
     ): AssistantThread =
       AssistantThread(
@@ -303,7 +303,7 @@ class AssistantThread(
       messages: List,
       metadata: JsonObject? = null,
       metric: Metric = Metric.EMPTY,
-      config: Config = Config(),
+      config: Config = Config.Default,
       api: Assistants = OpenAI(config).assistants
     ): AssistantThread =
       AssistantThread(
@@ -333,7 +333,7 @@ class AssistantThread(
       messages: List = emptyList(),
       metadata: JsonObject? = null,
       metric: Metric = Metric.EMPTY,
-      config: Config = Config(),
+      config: Config = Config.Default,
       api: Assistants = OpenAI(config).assistants
     ): AssistantThread =
       AssistantThread(
@@ -351,7 +351,7 @@ class AssistantThread(
     suspend operator fun invoke(
       request: CreateThreadRequest,
       metric: Metric = Metric.EMPTY,
-      config: Config = Config(),
+      config: Config = Config.Default,
       api: Assistants = OpenAI(config).assistants
     ): AssistantThread =
       AssistantThread(
@@ -364,7 +364,7 @@ class AssistantThread(
     suspend operator fun invoke(
       request: CreateThreadAndRunRequest,
       metric: Metric = Metric.EMPTY,
-      config: Config = Config(),
+      config: Config = Config.Default,
       api: Assistants = OpenAI(config).assistants
     ): AssistantThread =
       AssistantThread(
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt
index fec396e02..61537ab26 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/RunDelta.kt
@@ -144,7 +144,7 @@ sealed interface RunDelta {
       RunDeltaEvent.values().find {
         type.replace(".", "").replace("_", "").equals(it.name, ignoreCase = true)
       }
-    val json = Config.DEFAULT.json
+    val json = Config.Default.json
     return when (event) {
       RunDeltaEvent.ThreadCreated ->
         ThreadCreated(json.decodeFromJsonElement(ThreadObject.serializer(), data))
diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/ConfigTests.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/ConfigTests.kt
new file mode 100644
index 000000000..5c9955c50
--- /dev/null
+++ b/core/src/commonTest/kotlin/com/xebia/functional/xef/ConfigTests.kt
@@ -0,0 +1,46 @@
+package com.xebia.functional.xef
+
+import io.kotest.core.spec.style.StringSpec
+import io.kotest.matchers.shouldBe
+import kotlin.time.Duration.Companion.milliseconds
+import kotlinx.serialization.json.Json
+
+class ConfigTests :
+  StringSpec({
+    val newApiToken = "new-openai-token"
+    val newBaseUrl = "new-openai-url"
+    val newHttpClientRetryPolicy = HttpClientRetryPolicy.NoRetry
+    val newHttpClientTimeoutPolicy =
+      HttpClientTimeoutPolicy(5.milliseconds, 10.milliseconds, 15.milliseconds)
+    val newOrganization = "new-openai-organization"
+    val newStreamingDelimiter = "new-streaming-delimiter"
+    val newStreamingPrefix = "new-streaming-prefix"
+
+    "Config Builder returns the default config if no values are provided" {
+      val config = Config {}
+
+      config shouldBe Config.Default
+    }
+
+    "Config Builder changes the default values with the provided by " {
+      val config = Config {
+        apiToken = newApiToken
+        baseUrl = newBaseUrl
+        httpClientRetryPolicy = newHttpClientRetryPolicy
+        httpClientTimeoutPolicy = newHttpClientTimeoutPolicy
+        json = Json.Default
+        organization = newOrganization
+        streamingDelimiter = newStreamingDelimiter
+        streamingPrefix = newStreamingPrefix
+      }
+
+      config.apiToken shouldBe newApiToken
+      config.baseUrl shouldBe newBaseUrl
+      config.httpClientRetryPolicy shouldBe newHttpClientRetryPolicy
+      config.httpClientTimeoutPolicy shouldBe newHttpClientTimeoutPolicy
+      config.json shouldBe Json.Default
+      config.organization shouldBe newOrganization
+      config.streamingDelimiter shouldBe newStreamingDelimiter
+      config.streamingPrefix shouldBe newStreamingPrefix
+    }
+  })
diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/assistants/AssistantStreaming.kt b/examples/src/main/kotlin/com/xebia/functional/xef/assistants/AssistantStreaming.kt
index 8fecf412b..85885f939 100644
--- a/examples/src/main/kotlin/com/xebia/functional/xef/assistants/AssistantStreaming.kt
+++ b/examples/src/main/kotlin/com/xebia/functional/xef/assistants/AssistantStreaming.kt
@@ -18,7 +18,7 @@ suspend fun main() {
       assistantId = "asst_BwQvmWIbGUMDvCuXOtAFH8B6",
       toolsConfig = listOf(Tool.toolOf(SumTool()))
     )
-  val config = Config(org = null)
+  val config = Config { organization = null }
   val api = OpenAI(config = config, logRequests = false).assistants
   val thread = AssistantThread(api = api, metric = metric)
   println("Welcome to the Math tutor, ask me anything about math:")
diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/audio/SimpleSpeech.kt b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/audio/SimpleSpeech.kt
index 5a942d26b..fbe81e8ba 100644
--- a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/audio/SimpleSpeech.kt
+++ b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/audio/SimpleSpeech.kt
@@ -14,7 +14,7 @@ import java.io.File
 import javax.media.bean.playerbean.MediaPlayer

 suspend fun main() {
-  val config = Config()
+  val config = Config.Default
   val audio = OpenAI(config).audio
   println("ask me something!")
   while (true) {
diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt
index 0778dc5c4..1964d0f3f 100644
--- a/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt
+++ b/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt
@@ -9,10 +9,10 @@ class LocalVectorStoreService : VectorStoreService() {
   override fun getVectorStore(token: String?, org: String?): VectorStore =
     LocalVectorStore(
       OpenAI(
-          Config(
-            token = token,
-            org = org,
-          )
+          Config {
+            apiToken = token
+            organization = org
+          }
         )
         .embeddings
     )
diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt
index 6f764ea4e..25993aa51 100644
--- a/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt
+++ b/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt
@@ -40,10 +40,10 @@ class PostgresVectorStoreService(
   override fun getVectorStore(token: String?, org: String?): VectorStore {
     val embeddingsApi =
       OpenAI(
-          Config(
-            token = token,
-            org = org,
-          )
+          Config {
+            apiToken = token
+            organization = org
+          }
         )
         .embeddings
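
Usage sketch: the diff above routes the new retry and timeout policies through the OpenAI factory. The snippet below shows how a caller might combine them with the Config builder; it is illustrative only and not part of this change, and the endpoint, token, and tuning values are placeholders.

import com.xebia.functional.xef.Config
import com.xebia.functional.xef.HttpClientRetryPolicy
import com.xebia.functional.xef.HttpClientTimeoutPolicy
import com.xebia.functional.xef.OpenAI
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

fun main() {
  // Start from Config.Default and override only what differs, via the
  // ConfigBuilder properties introduced in Config.kt.
  val customConfig =
    Config {
      baseUrl = "https://my-proxy.example.com/v1/" // placeholder endpoint
      apiToken = "sk-placeholder" // placeholder token
      httpClientRetryPolicy =
        HttpClientRetryPolicy.ExponentialBackoff(
          backoffFactor = 2.0,
          interval = 250.milliseconds,
          maxDelay = 5.seconds,
          maxRetries = 3
        )
      httpClientTimeoutPolicy =
        HttpClientTimeoutPolicy(
          connectTimeout = 30.seconds,
          requestTimeout = 60.seconds,
          socketTimeout = 60.seconds
        )
    }

  // The factory installs both policies on the Ktor client it builds; the
  // resulting openAI.chat, openAI.assistants, etc. use that client.
  val openAI = OpenAI(config = customConfig, logRequests = false)
  println(openAI)
}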
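
The delay arithmetic that the two retrying policies hand to Ktor's delayMillis can also be read in isolation. The helper below is a sketch, not part of this change; delayForAttempt is a hypothetical name, and it assumes the retry attempt passed to the delay block is 1-based.

import com.xebia.functional.xef.HttpClientRetryPolicy
import kotlin.math.pow
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

// Mirrors the delayMillis blocks installed by HttpClientRetryPolicy.applyConfiguration.
fun delayForAttempt(policy: HttpClientRetryPolicy, retry: Int): Long =
  when (policy) {
    is HttpClientRetryPolicy.ExponentialBackoff ->
      minOf(
        policy.backoffFactor.pow(retry).toLong() * policy.interval.inWholeMilliseconds,
        policy.maxDelay.inWholeMilliseconds
      )
    is HttpClientRetryPolicy.Incremental ->
      minOf(retry * policy.interval.inWholeMilliseconds, policy.maxDelay.inWholeMilliseconds)
    // NoRetry installs no HttpRequestRetry plugin at all; 0 is only a stand-in for this sketch.
    HttpClientRetryPolicy.NoRetry -> 0L
  }

fun main() {
  // The default policy in Config.Default: Incremental(250.milliseconds, 5.seconds, maxRetries = 5).
  val incremental = HttpClientRetryPolicy.Incremental(250.milliseconds, 5.seconds, maxRetries = 5)
  // Prints 250, 500, 750, 1000, 1250 — linear growth, capped at maxDelay (5000 ms).
  (1..5).forEach { attempt -> println(delayForAttempt(incremental, attempt)) }
}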