Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make OpenAI client more configurable #780

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import com.xebia.functional.xef.conversation.Conversation
data class AIConfig(
val tools: List<Tool<*>> = emptyList(),
val model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
val config: Config = Config(),
val config: Config = Config.Default,
val openAI: OpenAI = OpenAI(config, logRequests = false),
val api: Chat = openAI.chat,
val conversation: Conversation = Conversation()
Expand Down
174 changes: 145 additions & 29 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,149 @@ import io.ktor.client.plugins.logging.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import kotlin.math.pow
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json

sealed interface HttpClientRetryPolicy {
data object NoRetry : HttpClientRetryPolicy

data class ExponentialBackoff(
val backoffFactor: Double,
val interval: Duration,
val maxDelay: Duration,
val maxRetries: Int
) : HttpClientRetryPolicy

data class Incremental(val interval: Duration, val maxDelay: Duration, val maxRetries: Int) :
HttpClientRetryPolicy

private fun configureHttpRequestRetryPlugin(
maxNumberOfRetries: Int,
delayBlock: HttpRequestRetry.DelayContext.(Int) -> Long
): HttpRequestRetry.Configuration.() -> Unit = {
maxRetries = maxNumberOfRetries
retryIf { _, response -> !response.status.isSuccess() }
retryOnExceptionIf { _, _ -> true }
delayMillis(block = delayBlock)
}

fun applyConfiguration(): (HttpClientConfig<*>) -> Unit = { httpClientConfig ->
when (val policy = this) {
is ExponentialBackoff ->
httpClientConfig.install(
HttpRequestRetry,
configure =
configureHttpRequestRetryPlugin(policy.maxRetries) { retry ->
minOf(
policy.backoffFactor.pow(retry).toLong() * policy.interval.inWholeMilliseconds,
policy.maxDelay.inWholeMilliseconds
)
}
)
is Incremental ->
httpClientConfig.install(
HttpRequestRetry,
configure =
configureHttpRequestRetryPlugin(policy.maxRetries) { retry ->
minOf(
retry * policy.interval.inWholeMilliseconds,
policy.maxDelay.inWholeMilliseconds
)
}
)
NoRetry -> Unit
}
}
}

data class HttpClientTimeoutPolicy(
val connectTimeout: Duration,
val requestTimeout: Duration,
val socketTimeout: Duration
) {
val applyConfiguration: (HttpClientConfig<*>) -> Unit = { httpClientConfig ->
httpClientConfig.install(HttpTimeout) {
requestTimeoutMillis = requestTimeout.inWholeMilliseconds
connectTimeoutMillis = connectTimeout.inWholeMilliseconds
socketTimeoutMillis = socketTimeout.inWholeMilliseconds
}
}
}

class ConfigBuilder internal constructor(config: Config) {
var baseUrl: String = config.baseUrl
var httpClientRetryPolicy: HttpClientRetryPolicy = config.httpClientRetryPolicy
var httpClientTimeoutPolicy: HttpClientTimeoutPolicy = config.httpClientTimeoutPolicy
var apiToken: String? = config.apiToken
var organization: String? = config.organization
var json: Json = config.json
var streamingPrefix: String = config.streamingPrefix
var streamingDelimiter: String = config.streamingDelimiter

fun build(): Config =
Config(
baseUrl = baseUrl,
httpClientRetryPolicy = httpClientRetryPolicy,
httpClientTimeoutPolicy = httpClientTimeoutPolicy,
apiToken = apiToken,
organization = organization,
json = json,
streamingPrefix = streamingPrefix,
streamingDelimiter = streamingDelimiter
)
}

fun Config(from: Config = Config.Default, builderAction: ConfigBuilder.() -> Unit): Config {
val builder = ConfigBuilder(from)
builder.builderAction()
return builder.build()
}

data class Config(
val baseUrl: String = getenv(HOST_ENV_VAR) ?: "https://api.openai.com/v1/",
val token: String? = null,
val org: String? = getenv(ORG_ENV_VAR),
val json: Json = Json {
ignoreUnknownKeys = true
prettyPrint = false
isLenient = true
explicitNulls = false
classDiscriminator = TYPE_DISCRIMINATOR
},
val streamingPrefix: String = "data:",
val streamingDelimiter: String = "data: [DONE]"
val baseUrl: String,
val httpClientRetryPolicy: HttpClientRetryPolicy,
val httpClientTimeoutPolicy: HttpClientTimeoutPolicy,
val apiToken: String?,
val organization: String?,
val json: Json,
val streamingPrefix: String,
val streamingDelimiter: String
) {
companion object {
val DEFAULT = Config()
@OptIn(ExperimentalSerializationApi::class)
val Default =
Config(
baseUrl = getenv(HOST_ENV_VAR) ?: "https://api.openai.com/v1/",
httpClientRetryPolicy =
HttpClientRetryPolicy.Incremental(
interval = 250.milliseconds,
maxDelay = 5.seconds,
maxRetries = 5
),
httpClientTimeoutPolicy =
HttpClientTimeoutPolicy(
connectTimeout = 45.seconds,
requestTimeout = 45.seconds,
socketTimeout = 45.seconds
),
json =
Json {
ignoreUnknownKeys = true
prettyPrint = false
isLenient = true
explicitNulls = false
classDiscriminator = TYPE_DISCRIMINATOR
},
organization = getenv(ORG_ENV_VAR),
streamingDelimiter = "data: [DONE]",
streamingPrefix = "data:",
apiToken = null
)

const val TYPE_DISCRIMINATOR = "_type_"
}
}
Expand All @@ -43,33 +168,24 @@ private const val KEY_ENV_VAR = "OPENAI_TOKEN"
* Just simple fun on top of generated API.
*/
fun OpenAI(
config: Config = Config(),
config: Config = Config.Default,
httpClientEngine: HttpClientEngine? = null,
httpClientConfig: ((HttpClientConfig<*>) -> Unit)? = null,
logRequests: Boolean = false
): OpenAI {
val token =
config.token
config.apiToken
?: getenv(KEY_ENV_VAR)
?: throw AIError.Env.OpenAI(nonEmptyListOf("missing $KEY_ENV_VAR env var"))
val clientConfig: HttpClientConfig<*>.() -> Unit = {
install(ContentNegotiation) { json(config.json) }
install(HttpTimeout) {
requestTimeoutMillis = 45 * 1000
connectTimeoutMillis = 45 * 1000
socketTimeoutMillis = 45 * 1000
}
install(HttpRequestRetry) {
maxRetries = 5
retryIf { _, response -> !response.status.isSuccess() }
retryOnExceptionIf { _, _ -> true }
delayMillis { retry -> retry * 1000L }
}
install(Logging) { level = if (logRequests) LogLevel.ALL else LogLevel.NONE }
install(Logging) { level = if (logRequests) LogLevel.ALL else LogLevel.INFO }
config.httpClientRetryPolicy.applyConfiguration().invoke(this)
config.httpClientTimeoutPolicy.applyConfiguration.invoke(this)
httpClientConfig?.invoke(this)
defaultRequest {
url(config.baseUrl)
config.org?.let { headers.append("OpenAI-Organization", it) }
config.organization?.let { headers.append("OpenAI-Organization", it) }
bearerAuth(token)
}
}
Expand All @@ -79,7 +195,7 @@ fun OpenAI(
OpenAIConfig(
baseUrl = config.baseUrl,
token = token,
org = config.org,
org = config.organization,
json = config.json,
streamingPrefix = config.streamingPrefix,
streamingDelimiter = config.streamingDelimiter
Expand Down
12 changes: 6 additions & 6 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ sealed class Tool<out A>(
val typeSerializer = targetClass.serializer()
val functionObject = chatFunction(typeSerializer.descriptor)
return Callable(functionObject) {
Config.DEFAULT.json.decodeFromString(typeSerializer, it.arguments)
Config.Default.json.decodeFromString(typeSerializer, it.arguments)
}
}

Expand All @@ -137,7 +137,7 @@ sealed class Tool<out A>(
val functionSerializer = Value.serializer(targetClass.serializer())
val functionObject = chatFunction(functionSerializer.descriptor)
return Primitive(functionObject) {
Config.DEFAULT.json.decodeFromString(functionSerializer, it.arguments).value
Config.Default.json.decodeFromString(functionSerializer, it.arguments).value
}
}

Expand All @@ -161,7 +161,7 @@ sealed class Tool<out A>(
}
val functionObject = chatFunction(functionSerializer.descriptor)
return Callable(functionObject) {
Config.DEFAULT.json.decodeFromString(functionSerializer, it.arguments).value as A
Config.Default.json.decodeFromString(functionSerializer, it.arguments).value as A
}
}

Expand Down Expand Up @@ -205,7 +205,7 @@ sealed class Tool<out A>(
descriptor: SerialDescriptor
): Enumeration<A> {
val enumSerializer = { value: String ->
Config.DEFAULT.json.decodeFromString(targetClass.serializer(), value) as A
Config.Default.json.decodeFromString(targetClass.serializer(), value) as A
}
val functionObject = chatFunction(descriptor)
val cases =
Expand Down Expand Up @@ -251,7 +251,7 @@ sealed class Tool<out A>(
sealedClassSerializer: SealedClassSerializer<out Any>
): A {
val newJson = descriptorChoice(it, functionObjectMap)
return Config.DEFAULT.json.decodeFromString(
return Config.Default.json.decodeFromString(
sealedClassSerializer,
Json.encodeToString(newJson)
) as A
Expand All @@ -263,7 +263,7 @@ sealed class Tool<out A>(
): JsonObject {
// adds a `type` field with the call.functionName serial name equivalent to the call arguments
val jsonWithDiscriminator =
Config.DEFAULT.json.decodeFromString(JsonElement.serializer(), call.arguments)
Config.Default.json.decodeFromString(JsonElement.serializer(), call.arguments)
val descriptor =
descriptors.values.firstOrNull { it.name.endsWith(call.functionName) }
?: error("No descriptor found for ${call.functionName}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fun chatFunction(descriptor: SerialDescriptor): FunctionObject {
@OptIn(ExperimentalSerializationApi::class)
fun functionSchema(descriptor: SerialDescriptor): JsonObject =
descriptor.annotations.filterIsInstance<Schema>().firstOrNull()?.value?.let {
Config.DEFAULT.json.decodeFromString(JsonObject.serializer(), it)
Config.Default.json.decodeFromString(JsonObject.serializer(), it)
} ?: buildJsonSchema(descriptor)

@OptIn(ExperimentalSerializationApi::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ import net.mamoe.yamlkt.toYamlElement
class Assistant(
val assistantId: String,
val toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
val config: Config = Config(),
val config: Config = Config.Default,
private val assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
) {

constructor(
assistantObject: AssistantObject,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
config: Config = Config(),
config: Config = Config.Default,
assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
) : this(assistantObject.id, toolsConfig, config, assistantsApi)

Expand Down Expand Up @@ -99,7 +99,7 @@ class Assistant(
toolResources: CreateAssistantRequestToolResources? = null,
metadata: JsonObject? = null,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
config: Config = Config(),
config: Config = Config.Default,
assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
): Assistant =
Assistant(
Expand All @@ -120,7 +120,7 @@ class Assistant(
suspend operator fun invoke(
request: CreateAssistantRequest,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
config: Config = Config(),
config: Config = Config.Default,
assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
): Assistant {
val response = assistantsApi.createAssistant(request, configure = ::defaultConfig)
Expand All @@ -130,7 +130,7 @@ class Assistant(
suspend fun fromConfig(
request: String,
toolsConfig: List<Tool.Companion.ToolConfig<*, *>> = emptyList(),
config: Config = Config(),
config: Config = Config.Default,
assistantsApi: Assistants = OpenAI(config, logRequests = false).assistants,
): Assistant {
val parsed = Yaml.Default.decodeYamlMapFromString(request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import kotlinx.serialization.json.JsonPrimitive
class AssistantThread(
val threadId: String,
val metric: Metric = Metric.EMPTY,
private val config: Config = Config(),
private val config: Config = Config.Default,
private val api: Assistants = OpenAI(config).assistants
) {

Expand Down Expand Up @@ -271,7 +271,7 @@ class AssistantThread(
messages: List<MessageWithFiles>,
metadata: JsonObject? = null,
metric: Metric = Metric.EMPTY,
config: Config = Config(),
config: Config = Config.Default,
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
Expand Down Expand Up @@ -303,7 +303,7 @@ class AssistantThread(
messages: List<String>,
metadata: JsonObject? = null,
metric: Metric = Metric.EMPTY,
config: Config = Config(),
config: Config = Config.Default,
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
Expand Down Expand Up @@ -333,7 +333,7 @@ class AssistantThread(
messages: List<CreateMessageRequest> = emptyList(),
metadata: JsonObject? = null,
metric: Metric = Metric.EMPTY,
config: Config = Config(),
config: Config = Config.Default,
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
Expand All @@ -351,7 +351,7 @@ class AssistantThread(
suspend operator fun invoke(
request: CreateThreadRequest,
metric: Metric = Metric.EMPTY,
config: Config = Config(),
config: Config = Config.Default,
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
Expand All @@ -364,7 +364,7 @@ class AssistantThread(
suspend operator fun invoke(
request: CreateThreadAndRunRequest,
metric: Metric = Metric.EMPTY,
config: Config = Config(),
config: Config = Config.Default,
api: Assistants = OpenAI(config).assistants
): AssistantThread =
AssistantThread(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ sealed interface RunDelta {
RunDeltaEvent.values().find {
type.replace(".", "").replace("_", "").equals(it.name, ignoreCase = true)
}
val json = Config.DEFAULT.json
val json = Config.Default.json
return when (event) {
RunDeltaEvent.ThreadCreated ->
ThreadCreated(json.decodeFromJsonElement(ThreadObject.serializer(), data))
Expand Down
Loading
Loading