Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1.0.27 - Added Actor optimizers #31

Merged
merged 2 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,18 @@ Maven:
<dependency>
<groupId>com.simiacryptus</groupId>
<artifactId>skyenet-webui</artifactId>
<version>1.0.26</version>
<version>1.0.27</version>
</dependency>
```

Gradle:

```groovy
implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.26'
implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.27'
```

```kotlin
implementation("com.simiacryptus:skyenet:1.0.26")
implementation("com.simiacryptus:skyenet:1.0.27")
```

### 🌟 To Use
Expand Down
14 changes: 5 additions & 9 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,15 @@ kotlin {
jvmToolchain(11)
}

val kotlin_version = "1.9.20"
val junit_version = "5.9.2"
val jetty_version = "11.0.17"
val junit_version = "5.10.1"
val logback_version = "1.2.12"

dependencies {

implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.27")
implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.28")

implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
implementation(group = "commons-io", name = "commons-io", version = "2.15.0")

compileOnlyApi(kotlin("stdlib"))
implementation(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.7.3")
Expand All @@ -48,15 +46,13 @@ dependencies {
compileOnlyApi(group = "org.junit.jupiter", name = "junit-jupiter-api", version = junit_version)
compileOnlyApi(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = junit_version)

// compileOnlyApi(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version)
compileOnlyApi(group = "com.google.cloud", name = "google-cloud-texttospeech", version = "2.28.0")
compileOnlyApi(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.454")
compileOnlyApi(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.587")
compileOnlyApi(group = "ch.qos.logback", name = "logback-classic", version = logback_version)
compileOnlyApi(group = "ch.qos.logback", name = "logback-core", version = logback_version)

// testImplementation(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version)
testImplementation(group = "com.google.cloud", name = "google-cloud-texttospeech", version = "2.28.0")
testImplementation(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.454")
testImplementation(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.587")
testImplementation(group = "ch.qos.logback", name = "logback-classic", version = logback_version)
testImplementation(group = "ch.qos.logback", name = "logback-core", version = logback_version)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package com.simiacryptus.skyenet.actors.opt

import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.skyenet.actors.opt.ActorOptimization.GeneticApi.Prompt
import com.simiacryptus.openai.proxy.ChatProxy
import com.simiacryptus.skyenet.actors.BaseActor
import com.simiacryptus.util.describe.Description
import org.slf4j.LoggerFactory
import kotlin.math.pow

open class ActorOptimization(
val api: OpenAIClient,
val model: OpenAIClient.Models = OpenAIClient.Models.GPT35Turbo,
private val mutationRate: Double = 0.5,
private val mutatonTypes: Map<String, Double> = mapOf(
"Rephrase" to 1.0,
"Randomize" to 1.0,
"Summarize" to 1.0,
"Expand" to 1.0,
"Reorder" to 1.0,
"Remove Duplicate" to 1.0,
)
) {

data class TestCase(
val userMessages: List<String>,
val expectations: List<Expectation>,
val retries: Int = 3
)

open fun <T:Any> runGeneticGenerations(
prompts: List<String>,
testCases: List<TestCase>,
actorFactory: (String) -> BaseActor<T>,
resultMapper: (T) -> String,
selectionSize: Int = defaultSelectionSize(prompts),
populationSize: Int = defaultPositionSize(selectionSize, prompts),
generations: Int = 3
): List<String> {
var topPrompts = regenerate(prompts, populationSize)
for (generation in 0..generations) {
val scores = topPrompts.map { prompt ->
prompt to testCases.map { testCase ->
val answer = actorFactory(prompt).answer(*testCase.userMessages.toTypedArray<String>(), api = api)
testCase.expectations.map { it.score(api, resultMapper(answer)) }.average()
}.average()
}
scores.sortedByDescending { it.second }.forEach {
log.info("""Scored ${it.second}: ${it.first.replace("\n", "\\n")}""")
}
if (generation == generations) {
log.info("Final generation: ${topPrompts.first()}")
break
} else {
val survivors = scores.sortedByDescending { it.second }.take(selectionSize).map { it.first }
topPrompts = regenerate(survivors, populationSize)
log.info("Generation $generation: ${topPrompts.first()}")
}
}
return topPrompts
}

private fun defaultPositionSize(selectionSize: Int, systemPrompts: List<String>) =
Math.max(Math.max(selectionSize, 5), systemPrompts.size)

private fun defaultSelectionSize(systemPrompts: List<String>) =
Math.max(Math.ceil(Math.log((systemPrompts.size + 1).toDouble()) / Math.log(2.0)), 3.0)
.toInt()

open fun regenerate(progenetors: List<String>, desiredCount: Int): List<String> {
val result = listOf<String>().toMutableList()
result += progenetors
while (result.size < desiredCount) {
if (progenetors.size == 1) {
val selected = progenetors.first()
val mutated = mutate(selected)
result += mutated
} else if (progenetors.size == 0) {
throw RuntimeException("No survivors")
} else {
val a = progenetors.random()
var b: String
do {
b = progenetors.random()
} while (a == b)
val child = recombine(a, b)
result += child
}
}
return result
}

open fun recombine(a: String, b: String): String {
val temperature = 0.3
for (retry in 0..3) {
try {
val child = geneticApi(temperature.pow(1.0 / (retry + 1))).recombine(Prompt(a), Prompt(b)).prompt
if (child.contentEquals(a) || child.contentEquals(b)) {
log.info("Recombine failure: retry $retry")
continue
}
log.info(
"Recombined Prompts\n\t${
a.replace("\n", "\n\t")
}\n\t-- + --\n\t${
b.replace("\n", "\n\t")
}\n\t-- -> --\n\t${child.replace("\n", "\n\t")}"
)
if (Math.random() < mutationRate) {
return mutate(child)
} else {
return child
}
} catch (e: Exception) {
log.warn("Failed to recombine $a + $b", e)
}
}
return mutate(a)
}

open fun mutate(selected: String): String {
val temperature = 0.3
for (retry in 0..10) {
try {
val directive = getMutationDirective()
val mutated = geneticApi(temperature.pow(1.0 / (retry + 1))).mutate(Prompt(selected), directive).prompt
if (mutated.contentEquals(selected)) {
log.info("Mutate failure $retry ($directive): ${selected.replace("\n", "\\n")}")
continue
}
log.info(
"Mutated Prompt\n\t${selected.replace("\n", "\n\t")}\n\t-- -> --\n\t${
mutated.replace(
"\n",
"\n\t"
)
}"
)
return mutated
} catch (e: Exception) {
log.warn("Failed to mutate $selected", e)
}
}
throw RuntimeException("Failed to mutate $selected")
}

open fun getMutationDirective(): String {
val fate = mutatonTypes.values.sum() * Math.random()
var cumulative = 0.0
for ((key, value) in mutatonTypes) {
cumulative += value
if (fate < cumulative) {
return key
}
}
return mutatonTypes.keys.random()
}

protected interface GeneticApi {
@Description("Mutate the given prompt; rephrase, make random edits, etc.")
fun mutate(
systemPrompt: Prompt,
directive: String = "Rephrase"
): Prompt

@Description("Recombine the given prompts to produce a third with about the same length; swap phrases, reword, etc.")
fun recombine(
systemPromptA: Prompt,
systemPromptB: Prompt
): Prompt

data class Prompt(
val prompt: String
)
}

protected open fun geneticApi(temperature: Double = 0.3) = ChatProxy(
clazz = GeneticApi::class.java,
api = api,
model = model,
temperature = temperature
).create()

companion object {
val log = LoggerFactory.getLogger(ActorOptimization::class.java)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package com.simiacryptus.skyenet.actors.opt

import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.openai.opt.DistanceType
import org.slf4j.LoggerFactory

abstract class Expectation {
companion object {
val log = LoggerFactory.getLogger(Expectation::class.java)
}

open class VectorMatch(val example: String, private val metric: DistanceType = DistanceType.Cosine) : Expectation() {
override fun matches(api: OpenAIClient, response: String): Boolean {
return true
}

override fun score(api: OpenAIClient, response: String): Double {
val contentEmbedding = createEmbedding(api, example)
val promptEmbedding = createEmbedding(api, response)
val distance = metric.distance(contentEmbedding, promptEmbedding)
log.info(
"""Distance = $distance
| from "${example.replace("\n", "\\n")}"
| to "${response.replace("\n", "\\n")}"
""".trimMargin().trim()
)
return -distance
}

private fun createEmbedding(api: OpenAIClient, str: String) = api.createEmbedding(
OpenAIClient.EmbeddingRequest(
model = OpenAIClient.Models.AdaEmbedding.modelName, input = str
)
).data.first().embedding!!
}

open class ContainsMatch(
val pattern: Regex,
val critical: Boolean = true
) : Expectation() {
override fun matches(api: OpenAIClient, response: String): Boolean {
if (!critical) return true
return _matches(response)
}
override fun score(api: OpenAIClient, response: String): Double {
return if (_matches(response)) 1.0 else 0.0
}

private fun _matches(response: String?): Boolean {
if (pattern.containsMatchIn(response ?: "")) return true
log.info(
"""Failed to match ${
pattern.pattern.replace("\n", "\\n")
} in ${
response?.replace("\n", "\\n") ?: ""
}"""
)
return false
}

}

abstract fun matches(api: OpenAIClient, response: String): Boolean

abstract fun score(api: OpenAIClient, response: String): Double


}
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Gradle Releases -> https://github.com/gradle/gradle/releases
libraryGroup = com.simiacryptus.skyenet
libraryVersion = 1.0.26
libraryVersion = 1.0.27
gradleVersion = 7.6.1

# Opt-out flag for bundling Kotlin standard library -> https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library
Expand Down
6 changes: 3 additions & 3 deletions groovy/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ dependencies {
implementation(group = "org.jetbrains.kotlin", name = "kotlin-stdlib-jdk8", version = kotlin_version)

implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
implementation(group = "commons-io", name = "commons-io", version = "2.15.0")

testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2")
testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1")

}

Expand Down
6 changes: 3 additions & 3 deletions java/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ dependencies {

implementation(group = "org.jetbrains.kotlin", name = "kotlin-stdlib", version = kotlin_version)
implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
implementation(group = "commons-io", name = "commons-io", version = "2.15.0")

testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2")
testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1")
testImplementation(group = "org.jetbrains.kotlin", name = "kotlin-script-runtime", version = kotlin_version)

}
Expand Down
6 changes: 3 additions & 3 deletions kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ dependencies {
implementation(kotlin("compiler-embeddable"))
implementation(kotlin("scripting-compiler-embeddable"))

implementation(group = "commons-io", name = "commons-io", version = "2.11.0")
implementation(group = "commons-io", name = "commons-io", version = "2.15.0")

testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.9.2")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.9.2")
testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1")
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1")

implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
testImplementation(group = "ch.qos.logback", name = "logback-classic", version = "1.4.11")
Expand Down
Loading
Loading