diff --git a/README.md b/README.md
index 23568ed2..a453a8b1 100644
--- a/README.md
+++ b/README.md
@@ -76,18 +76,18 @@ Maven:
com.simiacryptus
skyenet-webui
- 1.0.26
+ 1.0.27
```
Gradle:
```groovy
-implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.26'
+implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.27'
```
```kotlin
-implementation("com.simiacryptus:skyenet:1.0.26")
+implementation("com.simiacryptus:skyenet:1.0.27")
```
### 🌟 To Use
diff --git a/core/build.gradle.kts b/core/build.gradle.kts
index d3b1c28f..f81bdaca 100644
--- a/core/build.gradle.kts
+++ b/core/build.gradle.kts
@@ -26,17 +26,15 @@ kotlin {
jvmToolchain(11)
}
-val kotlin_version = "1.9.20"
-val junit_version = "5.9.2"
-val jetty_version = "11.0.17"
+val junit_version = "5.10.1"
val logback_version = "1.2.12"
dependencies {
- implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.27")
+ implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.28")
implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
- implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
+ implementation(group = "commons-io", name = "commons-io", version = "2.15.0")
compileOnlyApi(kotlin("stdlib"))
implementation(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.7.3")
@@ -48,15 +46,13 @@ dependencies {
compileOnlyApi(group = "org.junit.jupiter", name = "junit-jupiter-api", version = junit_version)
compileOnlyApi(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = junit_version)
-// compileOnlyApi(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version)
compileOnlyApi(group = "com.google.cloud", name = "google-cloud-texttospeech", version = "2.28.0")
- compileOnlyApi(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.454")
+ compileOnlyApi(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.587")
compileOnlyApi(group = "ch.qos.logback", name = "logback-classic", version = logback_version)
compileOnlyApi(group = "ch.qos.logback", name = "logback-core", version = logback_version)
-// testImplementation(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version)
testImplementation(group = "com.google.cloud", name = "google-cloud-texttospeech", version = "2.28.0")
- testImplementation(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.454")
+ testImplementation(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.587")
testImplementation(group = "ch.qos.logback", name = "logback-classic", version = logback_version)
testImplementation(group = "ch.qos.logback", name = "logback-core", version = logback_version)
diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/ActorOptimization.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/ActorOptimization.kt
new file mode 100644
index 00000000..09dc2be9
--- /dev/null
+++ b/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/ActorOptimization.kt
@@ -0,0 +1,188 @@
+package com.simiacryptus.skyenet.actors.opt
+
+import com.simiacryptus.openai.OpenAIClient
+import com.simiacryptus.skyenet.actors.opt.ActorOptimization.GeneticApi.Prompt
+import com.simiacryptus.openai.proxy.ChatProxy
+import com.simiacryptus.skyenet.actors.BaseActor
+import com.simiacryptus.util.describe.Description
+import org.slf4j.LoggerFactory
+import kotlin.math.pow
+
+open class ActorOptimization(
+ val api: OpenAIClient,
+ val model: OpenAIClient.Models = OpenAIClient.Models.GPT35Turbo,
+ private val mutationRate: Double = 0.5,
+ private val mutatonTypes: Map = mapOf(
+ "Rephrase" to 1.0,
+ "Randomize" to 1.0,
+ "Summarize" to 1.0,
+ "Expand" to 1.0,
+ "Reorder" to 1.0,
+ "Remove Duplicate" to 1.0,
+ )
+) {
+
+ data class TestCase(
+ val userMessages: List,
+ val expectations: List,
+ val retries: Int = 3
+ )
+
+ open fun runGeneticGenerations(
+ prompts: List,
+ testCases: List,
+ actorFactory: (String) -> BaseActor,
+ resultMapper: (T) -> String,
+ selectionSize: Int = defaultSelectionSize(prompts),
+ populationSize: Int = defaultPositionSize(selectionSize, prompts),
+ generations: Int = 3
+ ): List {
+ var topPrompts = regenerate(prompts, populationSize)
+ for (generation in 0..generations) {
+ val scores = topPrompts.map { prompt ->
+ prompt to testCases.map { testCase ->
+ val answer = actorFactory(prompt).answer(*testCase.userMessages.toTypedArray(), api = api)
+ testCase.expectations.map { it.score(api, resultMapper(answer)) }.average()
+ }.average()
+ }
+ scores.sortedByDescending { it.second }.forEach {
+ log.info("""Scored ${it.second}: ${it.first.replace("\n", "\\n")}""")
+ }
+ if (generation == generations) {
+ log.info("Final generation: ${topPrompts.first()}")
+ break
+ } else {
+ val survivors = scores.sortedByDescending { it.second }.take(selectionSize).map { it.first }
+ topPrompts = regenerate(survivors, populationSize)
+ log.info("Generation $generation: ${topPrompts.first()}")
+ }
+ }
+ return topPrompts
+ }
+
+ private fun defaultPositionSize(selectionSize: Int, systemPrompts: List) =
+ Math.max(Math.max(selectionSize, 5), systemPrompts.size)
+
+ private fun defaultSelectionSize(systemPrompts: List) =
+ Math.max(Math.ceil(Math.log((systemPrompts.size + 1).toDouble()) / Math.log(2.0)), 3.0)
+ .toInt()
+
+ open fun regenerate(progenetors: List, desiredCount: Int): List {
+ val result = listOf().toMutableList()
+ result += progenetors
+ while (result.size < desiredCount) {
+ if (progenetors.size == 1) {
+ val selected = progenetors.first()
+ val mutated = mutate(selected)
+ result += mutated
+ } else if (progenetors.size == 0) {
+ throw RuntimeException("No survivors")
+ } else {
+ val a = progenetors.random()
+ var b: String
+ do {
+ b = progenetors.random()
+ } while (a == b)
+ val child = recombine(a, b)
+ result += child
+ }
+ }
+ return result
+ }
+
+ open fun recombine(a: String, b: String): String {
+ val temperature = 0.3
+ for (retry in 0..3) {
+ try {
+ val child = geneticApi(temperature.pow(1.0 / (retry + 1))).recombine(Prompt(a), Prompt(b)).prompt
+ if (child.contentEquals(a) || child.contentEquals(b)) {
+ log.info("Recombine failure: retry $retry")
+ continue
+ }
+ log.info(
+ "Recombined Prompts\n\t${
+ a.replace("\n", "\n\t")
+ }\n\t-- + --\n\t${
+ b.replace("\n", "\n\t")
+ }\n\t-- -> --\n\t${child.replace("\n", "\n\t")}"
+ )
+ if (Math.random() < mutationRate) {
+ return mutate(child)
+ } else {
+ return child
+ }
+ } catch (e: Exception) {
+ log.warn("Failed to recombine $a + $b", e)
+ }
+ }
+ return mutate(a)
+ }
+
+ open fun mutate(selected: String): String {
+ val temperature = 0.3
+ for (retry in 0..10) {
+ try {
+ val directive = getMutationDirective()
+ val mutated = geneticApi(temperature.pow(1.0 / (retry + 1))).mutate(Prompt(selected), directive).prompt
+ if (mutated.contentEquals(selected)) {
+ log.info("Mutate failure $retry ($directive): ${selected.replace("\n", "\\n")}")
+ continue
+ }
+ log.info(
+ "Mutated Prompt\n\t${selected.replace("\n", "\n\t")}\n\t-- -> --\n\t${
+ mutated.replace(
+ "\n",
+ "\n\t"
+ )
+ }"
+ )
+ return mutated
+ } catch (e: Exception) {
+ log.warn("Failed to mutate $selected", e)
+ }
+ }
+ throw RuntimeException("Failed to mutate $selected")
+ }
+
+ open fun getMutationDirective(): String {
+ val fate = mutatonTypes.values.sum() * Math.random()
+ var cumulative = 0.0
+ for ((key, value) in mutatonTypes) {
+ cumulative += value
+ if (fate < cumulative) {
+ return key
+ }
+ }
+ return mutatonTypes.keys.random()
+ }
+
+ protected interface GeneticApi {
+ @Description("Mutate the given prompt; rephrase, make random edits, etc.")
+ fun mutate(
+ systemPrompt: Prompt,
+ directive: String = "Rephrase"
+ ): Prompt
+
+ @Description("Recombine the given prompts to produce a third with about the same length; swap phrases, reword, etc.")
+ fun recombine(
+ systemPromptA: Prompt,
+ systemPromptB: Prompt
+ ): Prompt
+
+ data class Prompt(
+ val prompt: String
+ )
+ }
+
+ protected open fun geneticApi(temperature: Double = 0.3) = ChatProxy(
+ clazz = GeneticApi::class.java,
+ api = api,
+ model = model,
+ temperature = temperature
+ ).create()
+
+ companion object {
+ val log = LoggerFactory.getLogger(ActorOptimization::class.java)
+ }
+
+}
diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/Expectation.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/Expectation.kt
new file mode 100644
index 00000000..a58b0145
--- /dev/null
+++ b/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/Expectation.kt
@@ -0,0 +1,68 @@
+package com.simiacryptus.skyenet.actors.opt
+
+import com.simiacryptus.openai.OpenAIClient
+import com.simiacryptus.openai.opt.DistanceType
+import org.slf4j.LoggerFactory
+
+abstract class Expectation {
+ companion object {
+ val log = LoggerFactory.getLogger(Expectation::class.java)
+ }
+
+ open class VectorMatch(val example: String, private val metric: DistanceType = DistanceType.Cosine) : Expectation() {
+ override fun matches(api: OpenAIClient, response: String): Boolean {
+ return true
+ }
+
+ override fun score(api: OpenAIClient, response: String): Double {
+ val contentEmbedding = createEmbedding(api, example)
+ val promptEmbedding = createEmbedding(api, response)
+ val distance = metric.distance(contentEmbedding, promptEmbedding)
+ log.info(
+ """Distance = $distance
+ | from "${example.replace("\n", "\\n")}"
+ | to "${response.replace("\n", "\\n")}"
+ """.trimMargin().trim()
+ )
+ return -distance
+ }
+
+ private fun createEmbedding(api: OpenAIClient, str: String) = api.createEmbedding(
+ OpenAIClient.EmbeddingRequest(
+ model = OpenAIClient.Models.AdaEmbedding.modelName, input = str
+ )
+ ).data.first().embedding!!
+ }
+
+ open class ContainsMatch(
+ val pattern: Regex,
+ val critical: Boolean = true
+ ) : Expectation() {
+ override fun matches(api: OpenAIClient, response: String): Boolean {
+ if (!critical) return true
+ return _matches(response)
+ }
+ override fun score(api: OpenAIClient, response: String): Double {
+ return if (_matches(response)) 1.0 else 0.0
+ }
+
+ private fun _matches(response: String?): Boolean {
+ if (pattern.containsMatchIn(response ?: "")) return true
+ log.info(
+ """Failed to match ${
+ pattern.pattern.replace("\n", "\\n")
+ } in ${
+ response?.replace("\n", "\\n") ?: ""
+ }"""
+ )
+ return false
+ }
+
+ }
+
+ abstract fun matches(api: OpenAIClient, response: String): Boolean
+
+ abstract fun score(api: OpenAIClient, response: String): Double
+
+
+}
\ No newline at end of file
diff --git a/gradle.properties b/gradle.properties
index 24a56afd..d4ede578 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -1,6 +1,6 @@
# Gradle Releases -> https://github.com/gradle/gradle/releases
libraryGroup = com.simiacryptus.skyenet
-libraryVersion = 1.0.26
+libraryVersion = 1.0.27
gradleVersion = 7.6.1
# Opt-out flag for bundling Kotlin standard library -> https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library
diff --git a/groovy/build.gradle.kts b/groovy/build.gradle.kts
index b9da3af8..8dcca396 100644
--- a/groovy/build.gradle.kts
+++ b/groovy/build.gradle.kts
@@ -37,10 +37,10 @@ dependencies {
implementation(group = "org.jetbrains.kotlin", name = "kotlin-stdlib-jdk8", version = kotlin_version)
implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
- implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
+ implementation(group = "commons-io", name = "commons-io", version = "2.15.0")
- testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2")
- testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2")
+ testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1")
+ testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1")
}
diff --git a/java/build.gradle.kts b/java/build.gradle.kts
index 3680e1e6..973244a3 100644
--- a/java/build.gradle.kts
+++ b/java/build.gradle.kts
@@ -36,10 +36,10 @@ dependencies {
implementation(group = "org.jetbrains.kotlin", name = "kotlin-stdlib", version = kotlin_version)
implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
- implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
+ implementation(group = "commons-io", name = "commons-io", version = "2.15.0")
- testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2")
- testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2")
+ testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1")
+ testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1")
testImplementation(group = "org.jetbrains.kotlin", name = "kotlin-script-runtime", version = kotlin_version)
}
diff --git a/kotlin/build.gradle.kts b/kotlin/build.gradle.kts
index cf18c74f..ac30dcf8 100644
--- a/kotlin/build.gradle.kts
+++ b/kotlin/build.gradle.kts
@@ -40,10 +40,10 @@ dependencies {
implementation(kotlin("compiler-embeddable"))
implementation(kotlin("scripting-compiler-embeddable"))
- implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
+ implementation(group = "commons-io", name = "commons-io", version = "2.15.0")
- testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2")
- testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2")
+ testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1")
+ testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1")
implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
testImplementation(group = "ch.qos.logback", name = "logback-classic", version = "1.4.11")
diff --git a/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt b/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt
index 7d16cd6a..27c48e82 100644
--- a/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt
+++ b/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt
@@ -1,3 +1,5 @@
+@file:Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
+
package com.simiacryptus.skyenet.heart
import com.simiacryptus.skyenet.Heart
@@ -77,11 +79,6 @@ open class KotlinInterpreter(
val configuration = CompilerConfiguration().apply {
put(CLIConfigurationKeys.MESSAGE_COLLECTOR_KEY, messageCollector)
- //addKotlinSourceRoot(tempFile.absolutePath)
-// configureJdkClasspathRoots()
-// System.getProperty("java.class.path").split(File.pathSeparator).forEach {
-// addJvmClasspathRoot(File(it))
-// }
val k2JVMCompilerArguments = K2JVMCompilerArguments()
k2JVMCompilerArguments.fragmentSources = arrayOf(tempFile.absolutePath)
k2JVMCompilerArguments.classpath = System.getProperty("java.class.path")
@@ -90,11 +87,9 @@ open class KotlinInterpreter(
this.setupJvmSpecificArguments(k2JVMCompilerArguments)
this.setupCommonArguments(k2JVMCompilerArguments)
this.setupLanguageVersionSettings(k2JVMCompilerArguments)
-// put(org.jetbrains.kotlin.config.JVMConfigurationKeys.COMPILE_JAVA, true)
put(org.jetbrains.kotlin.config.CommonConfigurationKeys.MODULE_NAME, k2JVMCompilerArguments.moduleName!!)
}
- // Create the compiler environment
val environment = KotlinCoreEnvironment.createForProduction(
parentDisposable = {},
configuration = configuration,
@@ -110,8 +105,8 @@ open class KotlinInterpreter(
if (errors.isEmpty()) null
else RuntimeException(
"""
- |${errors.joinToString("\n") { "Error: " + it }}
- |${warnings.joinToString("\n") { "Warning: " + it }}
+ |${errors.joinToString("\n") { "Error: $it" }}
+ |${warnings.joinToString("\n") { "Warning: $it" }}
""".trimMargin())
}
} catch (e: CompilationException) {
@@ -184,7 +179,7 @@ open class KotlinInterpreter(
defs.entrySet().forEach { (key, value) ->
val uuid = storageMap.getOrPut(value) { UUID.randomUUID() }
retrievalIndex.put(uuid, WeakReference(value))
- val fqClassName = KotlinInterpreter.javaClass.name.replace("$", ".")
+ val fqClassName = KotlinInterpreter::class.java.name.replace("$", ".")
val typeStr = typeOf(value)
out.add("val $key : $typeStr = $fqClassName.retrievalIndex.get(java.util.UUID.fromString(\"$uuid\"))?.get()!! as $typeStr\n")
}
@@ -192,7 +187,7 @@ open class KotlinInterpreter(
return out.joinToString("\n")
}
- open fun typeOf(value: Object?): String {
+ open fun typeOf(value: Any?): String {
if (value is java.lang.reflect.Proxy) {
return value.javaClass.interfaces[0].name.replace("$", ".") + "?"
}
diff --git a/webui/build.gradle.kts b/webui/build.gradle.kts
index 69b5b317..468cdbdd 100644
--- a/webui/build.gradle.kts
+++ b/webui/build.gradle.kts
@@ -28,11 +28,11 @@ kotlin {
}
val kotlin_version = "1.9.20"
-val jetty_version = "11.0.17"
-val jackson_version = "2.15.2"
+val jetty_version = "11.0.18"
+val jackson_version = "2.15.3"
dependencies {
- implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.27")
+ implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.28")
implementation(project(":core"))
testImplementation(project(":groovy"))
@@ -61,15 +61,15 @@ dependencies {
implementation(group = "com.google.api-client", name = "google-api-client", version = "1.35.2")
implementation(group = "com.google.oauth-client", name = "google-oauth-client-jetty", version = "1.34.1")
implementation(group = "com.google.apis", name = "google-api-services-oauth2", version = "v2-rev157-1.25.0")
- implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
+ implementation(group = "commons-io", name = "commons-io", version = "2.15.0")
implementation(group = "commons-codec", name = "commons-codec", version = "1.16.0")
implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
testImplementation(group = "org.slf4j", name = "slf4j-simple", version = "2.0.9")
testImplementation(kotlin("script-runtime"))
- testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2")
- testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2")
+ testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1")
+ testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1")
}
diff --git a/webui/src/test/kotlin/com/simiacryptus/skyenet/ActorOptTest.kt b/webui/src/test/kotlin/com/simiacryptus/skyenet/ActorOptTest.kt
new file mode 100644
index 00000000..fcf4f557
--- /dev/null
+++ b/webui/src/test/kotlin/com/simiacryptus/skyenet/ActorOptTest.kt
@@ -0,0 +1,63 @@
+package com.simiacryptus.skyenet
+
+import com.simiacryptus.openai.OpenAIClient
+import com.simiacryptus.skyenet.actors.SimpleActor
+import com.simiacryptus.skyenet.actors.opt.ActorOptimization
+import com.simiacryptus.skyenet.actors.opt.Expectation
+import org.slf4j.LoggerFactory
+import org.slf4j.event.Level
+
+object ActorOptTest {
+
+ private val log = LoggerFactory.getLogger(ActorOptTest::class.java)
+
+ @JvmStatic
+ fun main(args: Array) {
+ try {
+ ActorOptimization(
+ OpenAIClient(
+ logLevel = Level.DEBUG
+ )
+ ).runGeneticGenerations(
+ populationSize = 7,
+ generations = 5,
+ actorFactory = { SimpleActor(prompt = it) },
+ resultMapper = { it },
+ prompts = listOf(
+ """
+ |As the intermediary between the user and the search engine, your main task is to generate search queries based on user requests.
+ |Please respond to each user request by providing one or more calls to the "`search('query text')`" function.
+ |""".trimMargin(),
+ """
+ |You act as a bridge between the user and the search engine by creating search queries.
+ |Output one or more calls to "`search('query text')`" in response to each user request.
+ |""".trimMargin().trim(),
+ """
+ |You play the role of a search assistant.
+ |Provide one or more "`search('query text')`" calls as a response to each user request.
+ |Make sure to use single quotes around the query text.
+ |Surround the search function call with backticks.
+ |""".trimMargin().trim(),
+ ),
+ testCases = listOf(
+ ActorOptimization.TestCase(
+ userMessages = listOf(
+ "I want to buy a book.",
+ "A history book about Napoleon.",
+ ),
+ expectations = listOf(
+ Expectation.ContainsMatch("""`search\('.*?'\)`""".toRegex(), critical = false),
+ Expectation.ContainsMatch("""search\(.*?\)""".toRegex(), critical = false),
+ Expectation.VectorMatch("Great, what kind of book are you looking for?")
+ )
+ )
+ ),
+ )
+ } catch (e: Throwable) {
+ log.error("Error", e)
+ } finally {
+ System.exit(0)
+ }
+ }
+
+}
\ No newline at end of file