Skip to content

Commit

Permalink
Merge branch 'main' into enum-without-logitBias
Browse files Browse the repository at this point in the history
  • Loading branch information
raulraja authored Jun 5, 2024
2 parents 6ea9778 + ed02034 commit ffc6a9a
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 1 deletion.
84 changes: 83 additions & 1 deletion core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.xebia.functional.openai.generated.api.Images
import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.configuration.PromptConfiguration
import kotlin.coroutines.cancellation.CancellationException
Expand All @@ -14,15 +15,74 @@ import kotlin.reflect.typeOf
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.serializer

sealed interface AI {

@Serializable
@Description("The selected items indexes")
data class SelectedItems(
@Description("The selected items indexes") val selectedItems: List<Int>,
)

data class Classification(
val name: String,
val description: String,
)

interface PromptClassifier {
fun template(input: String, output: String, context: String): String
}

interface PromptMultipleClassifier {
fun getItems(): List<Classification>

fun template(input: String): String {
val items = getItems()

return """
|Based on the <input>, identify whether the user is asking about one or more of the following items
|
|${
items.joinToString("\n") { item -> "<${item.name}>${item.description}</${item.name}>" }
}
|
|<items>
|${
items.mapIndexed { index, item -> "\t<item index=\"$index\">${item.name}</item>" }
.joinToString("\n")
}
|</items>
|<input>
|$input
|</input>
"""
}

@OptIn(ExperimentalSerializationApi::class)
fun KType.enumValuesName(
serializer: KSerializer<Any?> = serializer(this)
): List<Classification> {
return if (serializer.descriptor.kind != SerialKind.ENUM) {
emptyList()
} else {
(0 until serializer.descriptor.elementsCount).map { index ->
val name =
serializer.descriptor
.getElementName(index)
.removePrefix(serializer.descriptor.serialName)
val description =
(serializer.descriptor.getElementAnnotations(index).first { it is Description }
as Description)
.value
Classification(name, description)
}
}
}
}

companion object {

fun <A : Any> chat(
Expand Down Expand Up @@ -87,7 +147,7 @@ sealed interface AI {
input: String,
output: String,
context: String,
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
target: KType = typeOf<E>(),
config: Config = Config(),
api: Chat = OpenAI(config).chat,
Expand All @@ -104,6 +164,28 @@ sealed interface AI {
)
}

@AiDsl
@Throws(IllegalArgumentException::class, CancellationException::class)
suspend inline fun <reified E> multipleClassify(
input: String,
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
config: Config = Config(),
api: Chat = OpenAI(config).chat,
conversation: Conversation = Conversation()
): List<E> where E : PromptMultipleClassifier, E : Enum<E> {
val values = enumValues<E>()
val value = values.firstOrNull() ?: error("No enum values found")
val selected: SelectedItems =
invoke(
prompt = value.template(input),
model = model,
config = config,
api = api,
conversation = conversation
)
return selected.selectedItems.mapNotNull { values.elementAtOrNull(it) }
}

@AiDsl
suspend inline operator fun <reified A : Any> invoke(
prompt: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.xebia.functional.xef.dsl.classify

import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.AI
import com.xebia.functional.xef.conversation.Description
import kotlin.reflect.typeOf
import kotlinx.serialization.Serializable

@Serializable
enum class Sports : AI.PromptMultipleClassifier {
@Description(
"Football is a team sport that is played on a rectangular field with goalposts at each end. The objective of the game is to score points by moving the ball into the opposing team's goal. The team with the most points at the end of the game wins."
)
FOOTBALL,
@Description(
"The game of basketball is played with a ball and a hoop. The objective is to score points by shooting the ball through the hoop. The game is played on a rectangular court with a hoop at each end. The team with the most points at the end of the game wins."
)
BASKETBALL,
@Description(
"The game of tennis is played with a racket and a ball. The objective is to hit the ball over the net and into the opponent's court. The game is played on a rectangular court with a net at the center. The player with the most points at the end of the game wins."
)
VOLLEYBALL,
@Description(
"The game of cricket is played with a bat and a ball. The objective is to score runs by hitting the ball and running between the wickets. The game is played on a circular field with a wicket at each end. The team with the most runs at the end of the game wins."
)
CRICKET,
@Description(
"The game of chess is played on a square board with 64 squares arranged in an 8x8 grid. The objective is to checkmate the opponent's king by placing it under threat of capture. The player who checkmates the opponent's king wins the game."
)
CHESS;

override fun getItems(): List<AI.Classification> = typeOf<Sports>().enumValuesName()
}

/**
* This is a simple example of how to use the `AI.multipleClassify` function to classify a prompt
*/
suspend fun main() {

println(AI.multipleClassify<Sports>("Sport played with a racket"))
println(
AI.multipleClassify<Sports>(
input = "The game is played with a ball",
model = CreateChatCompletionRequestModel.gpt_4o
)
)
}

0 comments on commit ffc6a9a

Please sign in to comment.