diff --git a/README.md b/README.md index 23568ed2..a453a8b1 100644 --- a/README.md +++ b/README.md @@ -76,18 +76,18 @@ Maven: com.simiacryptus skyenet-webui - 1.0.26 + 1.0.27 ``` Gradle: ```groovy -implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.26' +implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.27' ``` ```kotlin -implementation("com.simiacryptus:skyenet:1.0.26") +implementation("com.simiacryptus:skyenet:1.0.27") ``` ### 🌟 To Use diff --git a/core/build.gradle.kts b/core/build.gradle.kts index d3b1c28f..f81bdaca 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -26,17 +26,15 @@ kotlin { jvmToolchain(11) } -val kotlin_version = "1.9.20" -val junit_version = "5.9.2" -val jetty_version = "11.0.17" +val junit_version = "5.10.1" val logback_version = "1.2.12" dependencies { - implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.27") + implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.28") implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9") - implementation(group = "commons-io", name = "commons-io", version = "2.11.0") + implementation(group = "commons-io", name = "commons-io", version = "2.15.0") compileOnlyApi(kotlin("stdlib")) implementation(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.7.3") @@ -48,15 +46,13 @@ dependencies { compileOnlyApi(group = "org.junit.jupiter", name = "junit-jupiter-api", version = junit_version) compileOnlyApi(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = junit_version) -// compileOnlyApi(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version) compileOnlyApi(group = "com.google.cloud", name = "google-cloud-texttospeech", version = "2.28.0") - compileOnlyApi(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.454") + compileOnlyApi(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.587") compileOnlyApi(group = "ch.qos.logback", name = "logback-classic", version = logback_version) compileOnlyApi(group = "ch.qos.logback", name = "logback-core", version = logback_version) -// testImplementation(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version) testImplementation(group = "com.google.cloud", name = "google-cloud-texttospeech", version = "2.28.0") - testImplementation(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.454") + testImplementation(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.587") testImplementation(group = "ch.qos.logback", name = "logback-classic", version = logback_version) testImplementation(group = "ch.qos.logback", name = "logback-core", version = logback_version) diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/ActorOptimization.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/ActorOptimization.kt new file mode 100644 index 00000000..09dc2be9 --- /dev/null +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/ActorOptimization.kt @@ -0,0 +1,188 @@ +package com.simiacryptus.skyenet.actors.opt + +import com.simiacryptus.openai.OpenAIClient +import com.simiacryptus.skyenet.actors.opt.ActorOptimization.GeneticApi.Prompt +import com.simiacryptus.openai.proxy.ChatProxy +import com.simiacryptus.skyenet.actors.BaseActor +import com.simiacryptus.util.describe.Description +import org.slf4j.LoggerFactory +import kotlin.math.pow + +open class ActorOptimization( + val api: OpenAIClient, + val model: OpenAIClient.Models = OpenAIClient.Models.GPT35Turbo, + private val mutationRate: Double = 0.5, + private val mutatonTypes: Map = mapOf( + "Rephrase" to 1.0, + "Randomize" to 1.0, + "Summarize" to 1.0, + "Expand" to 1.0, + "Reorder" to 1.0, + "Remove Duplicate" to 1.0, + ) +) { + + data class TestCase( + val userMessages: List, + val expectations: List, + val retries: Int = 3 + ) + + open fun runGeneticGenerations( + prompts: List, + testCases: List, + actorFactory: (String) -> BaseActor, + resultMapper: (T) -> String, + selectionSize: Int = defaultSelectionSize(prompts), + populationSize: Int = defaultPositionSize(selectionSize, prompts), + generations: Int = 3 + ): List { + var topPrompts = regenerate(prompts, populationSize) + for (generation in 0..generations) { + val scores = topPrompts.map { prompt -> + prompt to testCases.map { testCase -> + val answer = actorFactory(prompt).answer(*testCase.userMessages.toTypedArray(), api = api) + testCase.expectations.map { it.score(api, resultMapper(answer)) }.average() + }.average() + } + scores.sortedByDescending { it.second }.forEach { + log.info("""Scored ${it.second}: ${it.first.replace("\n", "\\n")}""") + } + if (generation == generations) { + log.info("Final generation: ${topPrompts.first()}") + break + } else { + val survivors = scores.sortedByDescending { it.second }.take(selectionSize).map { it.first } + topPrompts = regenerate(survivors, populationSize) + log.info("Generation $generation: ${topPrompts.first()}") + } + } + return topPrompts + } + + private fun defaultPositionSize(selectionSize: Int, systemPrompts: List) = + Math.max(Math.max(selectionSize, 5), systemPrompts.size) + + private fun defaultSelectionSize(systemPrompts: List) = + Math.max(Math.ceil(Math.log((systemPrompts.size + 1).toDouble()) / Math.log(2.0)), 3.0) + .toInt() + + open fun regenerate(progenetors: List, desiredCount: Int): List { + val result = listOf().toMutableList() + result += progenetors + while (result.size < desiredCount) { + if (progenetors.size == 1) { + val selected = progenetors.first() + val mutated = mutate(selected) + result += mutated + } else if (progenetors.size == 0) { + throw RuntimeException("No survivors") + } else { + val a = progenetors.random() + var b: String + do { + b = progenetors.random() + } while (a == b) + val child = recombine(a, b) + result += child + } + } + return result + } + + open fun recombine(a: String, b: String): String { + val temperature = 0.3 + for (retry in 0..3) { + try { + val child = geneticApi(temperature.pow(1.0 / (retry + 1))).recombine(Prompt(a), Prompt(b)).prompt + if (child.contentEquals(a) || child.contentEquals(b)) { + log.info("Recombine failure: retry $retry") + continue + } + log.info( + "Recombined Prompts\n\t${ + a.replace("\n", "\n\t") + }\n\t-- + --\n\t${ + b.replace("\n", "\n\t") + }\n\t-- -> --\n\t${child.replace("\n", "\n\t")}" + ) + if (Math.random() < mutationRate) { + return mutate(child) + } else { + return child + } + } catch (e: Exception) { + log.warn("Failed to recombine $a + $b", e) + } + } + return mutate(a) + } + + open fun mutate(selected: String): String { + val temperature = 0.3 + for (retry in 0..10) { + try { + val directive = getMutationDirective() + val mutated = geneticApi(temperature.pow(1.0 / (retry + 1))).mutate(Prompt(selected), directive).prompt + if (mutated.contentEquals(selected)) { + log.info("Mutate failure $retry ($directive): ${selected.replace("\n", "\\n")}") + continue + } + log.info( + "Mutated Prompt\n\t${selected.replace("\n", "\n\t")}\n\t-- -> --\n\t${ + mutated.replace( + "\n", + "\n\t" + ) + }" + ) + return mutated + } catch (e: Exception) { + log.warn("Failed to mutate $selected", e) + } + } + throw RuntimeException("Failed to mutate $selected") + } + + open fun getMutationDirective(): String { + val fate = mutatonTypes.values.sum() * Math.random() + var cumulative = 0.0 + for ((key, value) in mutatonTypes) { + cumulative += value + if (fate < cumulative) { + return key + } + } + return mutatonTypes.keys.random() + } + + protected interface GeneticApi { + @Description("Mutate the given prompt; rephrase, make random edits, etc.") + fun mutate( + systemPrompt: Prompt, + directive: String = "Rephrase" + ): Prompt + + @Description("Recombine the given prompts to produce a third with about the same length; swap phrases, reword, etc.") + fun recombine( + systemPromptA: Prompt, + systemPromptB: Prompt + ): Prompt + + data class Prompt( + val prompt: String + ) + } + + protected open fun geneticApi(temperature: Double = 0.3) = ChatProxy( + clazz = GeneticApi::class.java, + api = api, + model = model, + temperature = temperature + ).create() + + companion object { + val log = LoggerFactory.getLogger(ActorOptimization::class.java) + } + +} diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/Expectation.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/Expectation.kt new file mode 100644 index 00000000..a58b0145 --- /dev/null +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/Expectation.kt @@ -0,0 +1,68 @@ +package com.simiacryptus.skyenet.actors.opt + +import com.simiacryptus.openai.OpenAIClient +import com.simiacryptus.openai.opt.DistanceType +import org.slf4j.LoggerFactory + +abstract class Expectation { + companion object { + val log = LoggerFactory.getLogger(Expectation::class.java) + } + + open class VectorMatch(val example: String, private val metric: DistanceType = DistanceType.Cosine) : Expectation() { + override fun matches(api: OpenAIClient, response: String): Boolean { + return true + } + + override fun score(api: OpenAIClient, response: String): Double { + val contentEmbedding = createEmbedding(api, example) + val promptEmbedding = createEmbedding(api, response) + val distance = metric.distance(contentEmbedding, promptEmbedding) + log.info( + """Distance = $distance + | from "${example.replace("\n", "\\n")}" + | to "${response.replace("\n", "\\n")}" + """.trimMargin().trim() + ) + return -distance + } + + private fun createEmbedding(api: OpenAIClient, str: String) = api.createEmbedding( + OpenAIClient.EmbeddingRequest( + model = OpenAIClient.Models.AdaEmbedding.modelName, input = str + ) + ).data.first().embedding!! + } + + open class ContainsMatch( + val pattern: Regex, + val critical: Boolean = true + ) : Expectation() { + override fun matches(api: OpenAIClient, response: String): Boolean { + if (!critical) return true + return _matches(response) + } + override fun score(api: OpenAIClient, response: String): Double { + return if (_matches(response)) 1.0 else 0.0 + } + + private fun _matches(response: String?): Boolean { + if (pattern.containsMatchIn(response ?: "")) return true + log.info( + """Failed to match ${ + pattern.pattern.replace("\n", "\\n") + } in ${ + response?.replace("\n", "\\n") ?: "" + }""" + ) + return false + } + + } + + abstract fun matches(api: OpenAIClient, response: String): Boolean + + abstract fun score(api: OpenAIClient, response: String): Double + + +} \ No newline at end of file diff --git a/gradle.properties b/gradle.properties index 24a56afd..d4ede578 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,6 +1,6 @@ # Gradle Releases -> https://github.com/gradle/gradle/releases libraryGroup = com.simiacryptus.skyenet -libraryVersion = 1.0.26 +libraryVersion = 1.0.27 gradleVersion = 7.6.1 # Opt-out flag for bundling Kotlin standard library -> https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library diff --git a/groovy/build.gradle.kts b/groovy/build.gradle.kts index b9da3af8..8dcca396 100644 --- a/groovy/build.gradle.kts +++ b/groovy/build.gradle.kts @@ -37,10 +37,10 @@ dependencies { implementation(group = "org.jetbrains.kotlin", name = "kotlin-stdlib-jdk8", version = kotlin_version) implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9") - implementation(group = "commons-io", name = "commons-io", version = "2.11.0") + implementation(group = "commons-io", name = "commons-io", version = "2.15.0") - testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2") - testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2") + testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") + testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1") } diff --git a/java/build.gradle.kts b/java/build.gradle.kts index 3680e1e6..973244a3 100644 --- a/java/build.gradle.kts +++ b/java/build.gradle.kts @@ -36,10 +36,10 @@ dependencies { implementation(group = "org.jetbrains.kotlin", name = "kotlin-stdlib", version = kotlin_version) implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9") - implementation(group = "commons-io", name = "commons-io", version = "2.11.0") + implementation(group = "commons-io", name = "commons-io", version = "2.15.0") - testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2") - testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2") + testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") + testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1") testImplementation(group = "org.jetbrains.kotlin", name = "kotlin-script-runtime", version = kotlin_version) } diff --git a/kotlin/build.gradle.kts b/kotlin/build.gradle.kts index cf18c74f..ac30dcf8 100644 --- a/kotlin/build.gradle.kts +++ b/kotlin/build.gradle.kts @@ -40,10 +40,10 @@ dependencies { implementation(kotlin("compiler-embeddable")) implementation(kotlin("scripting-compiler-embeddable")) - implementation(group = "commons-io", name = "commons-io", version = "2.11.0") + implementation(group = "commons-io", name = "commons-io", version = "2.15.0") - testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2") - testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2") + testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") + testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1") implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9") testImplementation(group = "ch.qos.logback", name = "logback-classic", version = "1.4.11") diff --git a/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt b/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt index 7d16cd6a..27c48e82 100644 --- a/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt +++ b/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt @@ -1,3 +1,5 @@ +@file:Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") + package com.simiacryptus.skyenet.heart import com.simiacryptus.skyenet.Heart @@ -77,11 +79,6 @@ open class KotlinInterpreter( val configuration = CompilerConfiguration().apply { put(CLIConfigurationKeys.MESSAGE_COLLECTOR_KEY, messageCollector) - //addKotlinSourceRoot(tempFile.absolutePath) -// configureJdkClasspathRoots() -// System.getProperty("java.class.path").split(File.pathSeparator).forEach { -// addJvmClasspathRoot(File(it)) -// } val k2JVMCompilerArguments = K2JVMCompilerArguments() k2JVMCompilerArguments.fragmentSources = arrayOf(tempFile.absolutePath) k2JVMCompilerArguments.classpath = System.getProperty("java.class.path") @@ -90,11 +87,9 @@ open class KotlinInterpreter( this.setupJvmSpecificArguments(k2JVMCompilerArguments) this.setupCommonArguments(k2JVMCompilerArguments) this.setupLanguageVersionSettings(k2JVMCompilerArguments) -// put(org.jetbrains.kotlin.config.JVMConfigurationKeys.COMPILE_JAVA, true) put(org.jetbrains.kotlin.config.CommonConfigurationKeys.MODULE_NAME, k2JVMCompilerArguments.moduleName!!) } - // Create the compiler environment val environment = KotlinCoreEnvironment.createForProduction( parentDisposable = {}, configuration = configuration, @@ -110,8 +105,8 @@ open class KotlinInterpreter( if (errors.isEmpty()) null else RuntimeException( """ - |${errors.joinToString("\n") { "Error: " + it }} - |${warnings.joinToString("\n") { "Warning: " + it }} + |${errors.joinToString("\n") { "Error: $it" }} + |${warnings.joinToString("\n") { "Warning: $it" }} """.trimMargin()) } } catch (e: CompilationException) { @@ -184,7 +179,7 @@ open class KotlinInterpreter( defs.entrySet().forEach { (key, value) -> val uuid = storageMap.getOrPut(value) { UUID.randomUUID() } retrievalIndex.put(uuid, WeakReference(value)) - val fqClassName = KotlinInterpreter.javaClass.name.replace("$", ".") + val fqClassName = KotlinInterpreter::class.java.name.replace("$", ".") val typeStr = typeOf(value) out.add("val $key : $typeStr = $fqClassName.retrievalIndex.get(java.util.UUID.fromString(\"$uuid\"))?.get()!! as $typeStr\n") } @@ -192,7 +187,7 @@ open class KotlinInterpreter( return out.joinToString("\n") } - open fun typeOf(value: Object?): String { + open fun typeOf(value: Any?): String { if (value is java.lang.reflect.Proxy) { return value.javaClass.interfaces[0].name.replace("$", ".") + "?" } diff --git a/webui/build.gradle.kts b/webui/build.gradle.kts index 69b5b317..468cdbdd 100644 --- a/webui/build.gradle.kts +++ b/webui/build.gradle.kts @@ -28,11 +28,11 @@ kotlin { } val kotlin_version = "1.9.20" -val jetty_version = "11.0.17" -val jackson_version = "2.15.2" +val jetty_version = "11.0.18" +val jackson_version = "2.15.3" dependencies { - implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.27") + implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.28") implementation(project(":core")) testImplementation(project(":groovy")) @@ -61,15 +61,15 @@ dependencies { implementation(group = "com.google.api-client", name = "google-api-client", version = "1.35.2") implementation(group = "com.google.oauth-client", name = "google-oauth-client-jetty", version = "1.34.1") implementation(group = "com.google.apis", name = "google-api-services-oauth2", version = "v2-rev157-1.25.0") - implementation(group = "commons-io", name = "commons-io", version = "2.11.0") + implementation(group = "commons-io", name = "commons-io", version = "2.15.0") implementation(group = "commons-codec", name = "commons-codec", version = "1.16.0") implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9") testImplementation(group = "org.slf4j", name = "slf4j-simple", version = "2.0.9") testImplementation(kotlin("script-runtime")) - testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2") - testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2") + testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") + testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1") } diff --git a/webui/src/test/kotlin/com/simiacryptus/skyenet/ActorOptTest.kt b/webui/src/test/kotlin/com/simiacryptus/skyenet/ActorOptTest.kt new file mode 100644 index 00000000..fcf4f557 --- /dev/null +++ b/webui/src/test/kotlin/com/simiacryptus/skyenet/ActorOptTest.kt @@ -0,0 +1,63 @@ +package com.simiacryptus.skyenet + +import com.simiacryptus.openai.OpenAIClient +import com.simiacryptus.skyenet.actors.SimpleActor +import com.simiacryptus.skyenet.actors.opt.ActorOptimization +import com.simiacryptus.skyenet.actors.opt.Expectation +import org.slf4j.LoggerFactory +import org.slf4j.event.Level + +object ActorOptTest { + + private val log = LoggerFactory.getLogger(ActorOptTest::class.java) + + @JvmStatic + fun main(args: Array) { + try { + ActorOptimization( + OpenAIClient( + logLevel = Level.DEBUG + ) + ).runGeneticGenerations( + populationSize = 7, + generations = 5, + actorFactory = { SimpleActor(prompt = it) }, + resultMapper = { it }, + prompts = listOf( + """ + |As the intermediary between the user and the search engine, your main task is to generate search queries based on user requests. + |Please respond to each user request by providing one or more calls to the "`search('query text')`" function. + |""".trimMargin(), + """ + |You act as a bridge between the user and the search engine by creating search queries. + |Output one or more calls to "`search('query text')`" in response to each user request. + |""".trimMargin().trim(), + """ + |You play the role of a search assistant. + |Provide one or more "`search('query text')`" calls as a response to each user request. + |Make sure to use single quotes around the query text. + |Surround the search function call with backticks. + |""".trimMargin().trim(), + ), + testCases = listOf( + ActorOptimization.TestCase( + userMessages = listOf( + "I want to buy a book.", + "A history book about Napoleon.", + ), + expectations = listOf( + Expectation.ContainsMatch("""`search\('.*?'\)`""".toRegex(), critical = false), + Expectation.ContainsMatch("""search\(.*?\)""".toRegex(), critical = false), + Expectation.VectorMatch("Great, what kind of book are you looking for?") + ) + ) + ), + ) + } catch (e: Throwable) { + log.error("Error", e) + } finally { + System.exit(0) + } + } + +} \ No newline at end of file