Skip to content

Commit

Permalink
1.0.27 - Added Actor optimizers (#31)
Browse files Browse the repository at this point in the history
* 1.0.27 - Added Actor optimizers

* 1.0.27
  • Loading branch information
acharneski authored Nov 13, 2023
1 parent c7499d2 commit e87ceba
Show file tree
Hide file tree
Showing 11 changed files with 349 additions and 39 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,18 @@ Maven:
<dependency>
<groupId>com.simiacryptus</groupId>
<artifactId>skyenet-webui</artifactId>
<version>1.0.26</version>
<version>1.0.27</version>
</dependency>
```

Gradle:

```groovy
implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.26'
implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.27'
```

```kotlin
implementation("com.simiacryptus:skyenet:1.0.26")
implementation("com.simiacryptus:skyenet:1.0.27")
```

### 🌟 To Use
Expand Down
14 changes: 5 additions & 9 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,15 @@ kotlin {
jvmToolchain(11)
}

val kotlin_version = "1.9.20"
val junit_version = "5.9.2"
val jetty_version = "11.0.17"
val junit_version = "5.10.1"
val logback_version = "1.2.12"

dependencies {

implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.27")
implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.28")

implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
implementation(group = "commons-io", name = "commons-io", version = "2.15.0")

compileOnlyApi(kotlin("stdlib"))
implementation(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.7.3")
Expand All @@ -48,15 +46,13 @@ dependencies {
compileOnlyApi(group = "org.junit.jupiter", name = "junit-jupiter-api", version = junit_version)
compileOnlyApi(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = junit_version)

// compileOnlyApi(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version)
compileOnlyApi(group = "com.google.cloud", name = "google-cloud-texttospeech", version = "2.28.0")
compileOnlyApi(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.454")
compileOnlyApi(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.587")
compileOnlyApi(group = "ch.qos.logback", name = "logback-classic", version = logback_version)
compileOnlyApi(group = "ch.qos.logback", name = "logback-core", version = logback_version)

// testImplementation(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version)
testImplementation(group = "com.google.cloud", name = "google-cloud-texttospeech", version = "2.28.0")
testImplementation(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.454")
testImplementation(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.587")
testImplementation(group = "ch.qos.logback", name = "logback-classic", version = logback_version)
testImplementation(group = "ch.qos.logback", name = "logback-core", version = logback_version)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package com.simiacryptus.skyenet.actors.opt

import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.skyenet.actors.opt.ActorOptimization.GeneticApi.Prompt
import com.simiacryptus.openai.proxy.ChatProxy
import com.simiacryptus.skyenet.actors.BaseActor
import com.simiacryptus.util.describe.Description
import org.slf4j.LoggerFactory
import kotlin.math.pow

open class ActorOptimization(
val api: OpenAIClient,
val model: OpenAIClient.Models = OpenAIClient.Models.GPT35Turbo,
private val mutationRate: Double = 0.5,
private val mutatonTypes: Map<String, Double> = mapOf(
"Rephrase" to 1.0,
"Randomize" to 1.0,
"Summarize" to 1.0,
"Expand" to 1.0,
"Reorder" to 1.0,
"Remove Duplicate" to 1.0,
)
) {

data class TestCase(
val userMessages: List<String>,
val expectations: List<Expectation>,
val retries: Int = 3
)

open fun <T:Any> runGeneticGenerations(
prompts: List<String>,
testCases: List<TestCase>,
actorFactory: (String) -> BaseActor<T>,
resultMapper: (T) -> String,
selectionSize: Int = defaultSelectionSize(prompts),
populationSize: Int = defaultPositionSize(selectionSize, prompts),
generations: Int = 3
): List<String> {
var topPrompts = regenerate(prompts, populationSize)
for (generation in 0..generations) {
val scores = topPrompts.map { prompt ->
prompt to testCases.map { testCase ->
val answer = actorFactory(prompt).answer(*testCase.userMessages.toTypedArray<String>(), api = api)
testCase.expectations.map { it.score(api, resultMapper(answer)) }.average()
}.average()
}
scores.sortedByDescending { it.second }.forEach {
log.info("""Scored ${it.second}: ${it.first.replace("\n", "\\n")}""")
}
if (generation == generations) {
log.info("Final generation: ${topPrompts.first()}")
break
} else {
val survivors = scores.sortedByDescending { it.second }.take(selectionSize).map { it.first }
topPrompts = regenerate(survivors, populationSize)
log.info("Generation $generation: ${topPrompts.first()}")
}
}
return topPrompts
}

private fun defaultPositionSize(selectionSize: Int, systemPrompts: List<String>) =
Math.max(Math.max(selectionSize, 5), systemPrompts.size)

private fun defaultSelectionSize(systemPrompts: List<String>) =
Math.max(Math.ceil(Math.log((systemPrompts.size + 1).toDouble()) / Math.log(2.0)), 3.0)
.toInt()

open fun regenerate(progenetors: List<String>, desiredCount: Int): List<String> {
val result = listOf<String>().toMutableList()
result += progenetors
while (result.size < desiredCount) {
if (progenetors.size == 1) {
val selected = progenetors.first()
val mutated = mutate(selected)
result += mutated
} else if (progenetors.size == 0) {
throw RuntimeException("No survivors")
} else {
val a = progenetors.random()
var b: String
do {
b = progenetors.random()
} while (a == b)
val child = recombine(a, b)
result += child
}
}
return result
}

open fun recombine(a: String, b: String): String {
val temperature = 0.3
for (retry in 0..3) {
try {
val child = geneticApi(temperature.pow(1.0 / (retry + 1))).recombine(Prompt(a), Prompt(b)).prompt
if (child.contentEquals(a) || child.contentEquals(b)) {
log.info("Recombine failure: retry $retry")
continue
}
log.info(
"Recombined Prompts\n\t${
a.replace("\n", "\n\t")
}\n\t-- + --\n\t${
b.replace("\n", "\n\t")
}\n\t-- -> --\n\t${child.replace("\n", "\n\t")}"
)
if (Math.random() < mutationRate) {
return mutate(child)
} else {
return child
}
} catch (e: Exception) {
log.warn("Failed to recombine $a + $b", e)
}
}
return mutate(a)
}

open fun mutate(selected: String): String {
val temperature = 0.3
for (retry in 0..10) {
try {
val directive = getMutationDirective()
val mutated = geneticApi(temperature.pow(1.0 / (retry + 1))).mutate(Prompt(selected), directive).prompt
if (mutated.contentEquals(selected)) {
log.info("Mutate failure $retry ($directive): ${selected.replace("\n", "\\n")}")
continue
}
log.info(
"Mutated Prompt\n\t${selected.replace("\n", "\n\t")}\n\t-- -> --\n\t${
mutated.replace(
"\n",
"\n\t"
)
}"
)
return mutated
} catch (e: Exception) {
log.warn("Failed to mutate $selected", e)
}
}
throw RuntimeException("Failed to mutate $selected")
}

open fun getMutationDirective(): String {
val fate = mutatonTypes.values.sum() * Math.random()
var cumulative = 0.0
for ((key, value) in mutatonTypes) {
cumulative += value
if (fate < cumulative) {
return key
}
}
return mutatonTypes.keys.random()
}

protected interface GeneticApi {
@Description("Mutate the given prompt; rephrase, make random edits, etc.")
fun mutate(
systemPrompt: Prompt,
directive: String = "Rephrase"
): Prompt

@Description("Recombine the given prompts to produce a third with about the same length; swap phrases, reword, etc.")
fun recombine(
systemPromptA: Prompt,
systemPromptB: Prompt
): Prompt

data class Prompt(
val prompt: String
)
}

protected open fun geneticApi(temperature: Double = 0.3) = ChatProxy(
clazz = GeneticApi::class.java,
api = api,
model = model,
temperature = temperature
).create()

companion object {
val log = LoggerFactory.getLogger(ActorOptimization::class.java)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package com.simiacryptus.skyenet.actors.opt

import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.openai.opt.DistanceType
import org.slf4j.LoggerFactory

abstract class Expectation {
companion object {
val log = LoggerFactory.getLogger(Expectation::class.java)
}

open class VectorMatch(val example: String, private val metric: DistanceType = DistanceType.Cosine) : Expectation() {
override fun matches(api: OpenAIClient, response: String): Boolean {
return true
}

override fun score(api: OpenAIClient, response: String): Double {
val contentEmbedding = createEmbedding(api, example)
val promptEmbedding = createEmbedding(api, response)
val distance = metric.distance(contentEmbedding, promptEmbedding)
log.info(
"""Distance = $distance
| from "${example.replace("\n", "\\n")}"
| to "${response.replace("\n", "\\n")}"
""".trimMargin().trim()
)
return -distance
}

private fun createEmbedding(api: OpenAIClient, str: String) = api.createEmbedding(
OpenAIClient.EmbeddingRequest(
model = OpenAIClient.Models.AdaEmbedding.modelName, input = str
)
).data.first().embedding!!
}

open class ContainsMatch(
val pattern: Regex,
val critical: Boolean = true
) : Expectation() {
override fun matches(api: OpenAIClient, response: String): Boolean {
if (!critical) return true
return _matches(response)
}
override fun score(api: OpenAIClient, response: String): Double {
return if (_matches(response)) 1.0 else 0.0
}

private fun _matches(response: String?): Boolean {
if (pattern.containsMatchIn(response ?: "")) return true
log.info(
"""Failed to match ${
pattern.pattern.replace("\n", "\\n")
} in ${
response?.replace("\n", "\\n") ?: ""
}"""
)
return false
}

}

abstract fun matches(api: OpenAIClient, response: String): Boolean

abstract fun score(api: OpenAIClient, response: String): Double


}
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Gradle Releases -> https://github.com/gradle/gradle/releases
libraryGroup = com.simiacryptus.skyenet
libraryVersion = 1.0.26
libraryVersion = 1.0.27
gradleVersion = 7.6.1

# Opt-out flag for bundling Kotlin standard library -> https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library
Expand Down
6 changes: 3 additions & 3 deletions groovy/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ dependencies {
implementation(group = "org.jetbrains.kotlin", name = "kotlin-stdlib-jdk8", version = kotlin_version)

implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
implementation(group = "commons-io", name = "commons-io", version = "2.15.0")

testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2")
testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1")

}

Expand Down
6 changes: 3 additions & 3 deletions java/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ dependencies {

implementation(group = "org.jetbrains.kotlin", name = "kotlin-stdlib", version = kotlin_version)
implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
implementation(group = "commons-io", name = "commons-io", version = "2.15.0")

testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2")
testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1")
testImplementation(group = "org.jetbrains.kotlin", name = "kotlin-script-runtime", version = kotlin_version)

}
Expand Down
6 changes: 3 additions & 3 deletions kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ dependencies {
implementation(kotlin("compiler-embeddable"))
implementation(kotlin("scripting-compiler-embeddable"))

implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
implementation(group = "commons-io", name = "commons-io", version = "2.15.0")

testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2")
testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1")

implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
testImplementation(group = "ch.qos.logback", name = "logback-classic", version = "1.4.11")
Expand Down
Loading

0 comments on commit e87ceba

Please sign in to comment.