-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1.0.27 - Added Actor optimizers (#31)
* 1.0.27 - Added Actor optimizers * 1.0.27
- Loading branch information
1 parent
c7499d2
commit e87ceba
Showing
11 changed files
with
349 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
188 changes: 188 additions & 0 deletions
188
core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/ActorOptimization.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
package com.simiacryptus.skyenet.actors.opt | ||
|
||
import com.simiacryptus.openai.OpenAIClient | ||
import com.simiacryptus.skyenet.actors.opt.ActorOptimization.GeneticApi.Prompt | ||
import com.simiacryptus.openai.proxy.ChatProxy | ||
import com.simiacryptus.skyenet.actors.BaseActor | ||
import com.simiacryptus.util.describe.Description | ||
import org.slf4j.LoggerFactory | ||
import kotlin.math.pow | ||
|
||
open class ActorOptimization( | ||
val api: OpenAIClient, | ||
val model: OpenAIClient.Models = OpenAIClient.Models.GPT35Turbo, | ||
private val mutationRate: Double = 0.5, | ||
private val mutatonTypes: Map<String, Double> = mapOf( | ||
"Rephrase" to 1.0, | ||
"Randomize" to 1.0, | ||
"Summarize" to 1.0, | ||
"Expand" to 1.0, | ||
"Reorder" to 1.0, | ||
"Remove Duplicate" to 1.0, | ||
) | ||
) { | ||
|
||
data class TestCase( | ||
val userMessages: List<String>, | ||
val expectations: List<Expectation>, | ||
val retries: Int = 3 | ||
) | ||
|
||
open fun <T:Any> runGeneticGenerations( | ||
prompts: List<String>, | ||
testCases: List<TestCase>, | ||
actorFactory: (String) -> BaseActor<T>, | ||
resultMapper: (T) -> String, | ||
selectionSize: Int = defaultSelectionSize(prompts), | ||
populationSize: Int = defaultPositionSize(selectionSize, prompts), | ||
generations: Int = 3 | ||
): List<String> { | ||
var topPrompts = regenerate(prompts, populationSize) | ||
for (generation in 0..generations) { | ||
val scores = topPrompts.map { prompt -> | ||
prompt to testCases.map { testCase -> | ||
val answer = actorFactory(prompt).answer(*testCase.userMessages.toTypedArray<String>(), api = api) | ||
testCase.expectations.map { it.score(api, resultMapper(answer)) }.average() | ||
}.average() | ||
} | ||
scores.sortedByDescending { it.second }.forEach { | ||
log.info("""Scored ${it.second}: ${it.first.replace("\n", "\\n")}""") | ||
} | ||
if (generation == generations) { | ||
log.info("Final generation: ${topPrompts.first()}") | ||
break | ||
} else { | ||
val survivors = scores.sortedByDescending { it.second }.take(selectionSize).map { it.first } | ||
topPrompts = regenerate(survivors, populationSize) | ||
log.info("Generation $generation: ${topPrompts.first()}") | ||
} | ||
} | ||
return topPrompts | ||
} | ||
|
||
private fun defaultPositionSize(selectionSize: Int, systemPrompts: List<String>) = | ||
Math.max(Math.max(selectionSize, 5), systemPrompts.size) | ||
|
||
private fun defaultSelectionSize(systemPrompts: List<String>) = | ||
Math.max(Math.ceil(Math.log((systemPrompts.size + 1).toDouble()) / Math.log(2.0)), 3.0) | ||
.toInt() | ||
|
||
open fun regenerate(progenetors: List<String>, desiredCount: Int): List<String> { | ||
val result = listOf<String>().toMutableList() | ||
result += progenetors | ||
while (result.size < desiredCount) { | ||
if (progenetors.size == 1) { | ||
val selected = progenetors.first() | ||
val mutated = mutate(selected) | ||
result += mutated | ||
} else if (progenetors.size == 0) { | ||
throw RuntimeException("No survivors") | ||
} else { | ||
val a = progenetors.random() | ||
var b: String | ||
do { | ||
b = progenetors.random() | ||
} while (a == b) | ||
val child = recombine(a, b) | ||
result += child | ||
} | ||
} | ||
return result | ||
} | ||
|
||
open fun recombine(a: String, b: String): String { | ||
val temperature = 0.3 | ||
for (retry in 0..3) { | ||
try { | ||
val child = geneticApi(temperature.pow(1.0 / (retry + 1))).recombine(Prompt(a), Prompt(b)).prompt | ||
if (child.contentEquals(a) || child.contentEquals(b)) { | ||
log.info("Recombine failure: retry $retry") | ||
continue | ||
} | ||
log.info( | ||
"Recombined Prompts\n\t${ | ||
a.replace("\n", "\n\t") | ||
}\n\t-- + --\n\t${ | ||
b.replace("\n", "\n\t") | ||
}\n\t-- -> --\n\t${child.replace("\n", "\n\t")}" | ||
) | ||
if (Math.random() < mutationRate) { | ||
return mutate(child) | ||
} else { | ||
return child | ||
} | ||
} catch (e: Exception) { | ||
log.warn("Failed to recombine $a + $b", e) | ||
} | ||
} | ||
return mutate(a) | ||
} | ||
|
||
open fun mutate(selected: String): String { | ||
val temperature = 0.3 | ||
for (retry in 0..10) { | ||
try { | ||
val directive = getMutationDirective() | ||
val mutated = geneticApi(temperature.pow(1.0 / (retry + 1))).mutate(Prompt(selected), directive).prompt | ||
if (mutated.contentEquals(selected)) { | ||
log.info("Mutate failure $retry ($directive): ${selected.replace("\n", "\\n")}") | ||
continue | ||
} | ||
log.info( | ||
"Mutated Prompt\n\t${selected.replace("\n", "\n\t")}\n\t-- -> --\n\t${ | ||
mutated.replace( | ||
"\n", | ||
"\n\t" | ||
) | ||
}" | ||
) | ||
return mutated | ||
} catch (e: Exception) { | ||
log.warn("Failed to mutate $selected", e) | ||
} | ||
} | ||
throw RuntimeException("Failed to mutate $selected") | ||
} | ||
|
||
open fun getMutationDirective(): String { | ||
val fate = mutatonTypes.values.sum() * Math.random() | ||
var cumulative = 0.0 | ||
for ((key, value) in mutatonTypes) { | ||
cumulative += value | ||
if (fate < cumulative) { | ||
return key | ||
} | ||
} | ||
return mutatonTypes.keys.random() | ||
} | ||
|
||
protected interface GeneticApi { | ||
@Description("Mutate the given prompt; rephrase, make random edits, etc.") | ||
fun mutate( | ||
systemPrompt: Prompt, | ||
directive: String = "Rephrase" | ||
): Prompt | ||
|
||
@Description("Recombine the given prompts to produce a third with about the same length; swap phrases, reword, etc.") | ||
fun recombine( | ||
systemPromptA: Prompt, | ||
systemPromptB: Prompt | ||
): Prompt | ||
|
||
data class Prompt( | ||
val prompt: String | ||
) | ||
} | ||
|
||
protected open fun geneticApi(temperature: Double = 0.3) = ChatProxy( | ||
clazz = GeneticApi::class.java, | ||
api = api, | ||
model = model, | ||
temperature = temperature | ||
).create() | ||
|
||
companion object { | ||
val log = LoggerFactory.getLogger(ActorOptimization::class.java) | ||
} | ||
|
||
} |
68 changes: 68 additions & 0 deletions
68
core/src/main/kotlin/com/simiacryptus/skyenet/actors/opt/Expectation.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
package com.simiacryptus.skyenet.actors.opt | ||
|
||
import com.simiacryptus.openai.OpenAIClient | ||
import com.simiacryptus.openai.opt.DistanceType | ||
import org.slf4j.LoggerFactory | ||
|
||
abstract class Expectation { | ||
companion object { | ||
val log = LoggerFactory.getLogger(Expectation::class.java) | ||
} | ||
|
||
open class VectorMatch(val example: String, private val metric: DistanceType = DistanceType.Cosine) : Expectation() { | ||
override fun matches(api: OpenAIClient, response: String): Boolean { | ||
return true | ||
} | ||
|
||
override fun score(api: OpenAIClient, response: String): Double { | ||
val contentEmbedding = createEmbedding(api, example) | ||
val promptEmbedding = createEmbedding(api, response) | ||
val distance = metric.distance(contentEmbedding, promptEmbedding) | ||
log.info( | ||
"""Distance = $distance | ||
| from "${example.replace("\n", "\\n")}" | ||
| to "${response.replace("\n", "\\n")}" | ||
""".trimMargin().trim() | ||
) | ||
return -distance | ||
} | ||
|
||
private fun createEmbedding(api: OpenAIClient, str: String) = api.createEmbedding( | ||
OpenAIClient.EmbeddingRequest( | ||
model = OpenAIClient.Models.AdaEmbedding.modelName, input = str | ||
) | ||
).data.first().embedding!! | ||
} | ||
|
||
open class ContainsMatch( | ||
val pattern: Regex, | ||
val critical: Boolean = true | ||
) : Expectation() { | ||
override fun matches(api: OpenAIClient, response: String): Boolean { | ||
if (!critical) return true | ||
return _matches(response) | ||
} | ||
override fun score(api: OpenAIClient, response: String): Double { | ||
return if (_matches(response)) 1.0 else 0.0 | ||
} | ||
|
||
private fun _matches(response: String?): Boolean { | ||
if (pattern.containsMatchIn(response ?: "")) return true | ||
log.info( | ||
"""Failed to match ${ | ||
pattern.pattern.replace("\n", "\\n") | ||
} in ${ | ||
response?.replace("\n", "\\n") ?: "" | ||
}""" | ||
) | ||
return false | ||
} | ||
|
||
} | ||
|
||
abstract fun matches(api: OpenAIClient, response: String): Boolean | ||
|
||
abstract fun score(api: OpenAIClient, response: String): Double | ||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.