Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
acharneski committed Nov 24, 2023
1 parent 55394db commit 3ec61dc
Show file tree
Hide file tree
Showing 17 changed files with 118 additions and 62 deletions.
2 changes: 1 addition & 1 deletion core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ val logback_version = "1.4.11"

dependencies {

implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.0.35")
implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.0.36")

implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")
implementation(group = "commons-io", name = "commons-io", version = "2.15.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.simiacryptus.jopenai.describe.TypeDescriber
import com.simiacryptus.jopenai.models.ChatModels
import com.simiacryptus.skyenet.core.Interpreter
import com.simiacryptus.skyenet.core.OutputInterceptor
import java.util.*
import javax.script.ScriptException
import kotlin.reflect.KClass

Expand All @@ -28,6 +29,7 @@ open class CodingActor(
val autoEvaluate: Boolean = false,
private val fixIterations: Int = 3,
private val fixRetries: Int = 2,
val runtimeSymbols: Map<String, Any> = mapOf()
) : BaseActor<CodingActor.CodeResult>(
prompt = "",
name = name,
Expand All @@ -52,19 +54,21 @@ open class CodingActor(
override val prompt: String
get() = if (symbols.isNotEmpty()) """
|You will translate natural language instructions into
|an implementation using ${interpreter.getLanguage()} and the script context.
|Use ``` code blocks labeled with ${interpreter.getLanguage()} where appropriate.
|Defined symbols include ${symbols.keys.joinToString(", ")}.
|The runtime context is described below:
|an implementation using ${language} and the script context.
|Use ``` code blocks labeled with ${language} where appropriate.
|
|Defined symbols include {${symbols.keys.joinToString(", ")}} described below:
|
|```${this.describer.markupLanguage}
|${this.apiDescription}
|```
|
|${details ?: ""}
|""".trimMargin().trim()
else """
|You will translate natural language instructions into
|an implementation using ${interpreter.getLanguage()} and the script context.
|Use ``` code blocks labeled with ${interpreter.getLanguage()} where appropriate.
|an implementation using ${language} and the script context.
|Use ``` code blocks labeled with ${language} where appropriate.
|
|${details ?: ""}
|""".trimMargin().trim()
Expand All @@ -77,9 +81,9 @@ open class CodingActor(
|""".trimMargin().trim()
}.joinToString("\n")

open val interpreter: Interpreter by lazy { interpreterClass.java.getConstructor(Map::class.java).newInstance(symbols) }
open val interpreter: Interpreter by lazy { interpreterClass.java.getConstructor(Map::class.java).newInstance(symbols + runtimeSymbols) }

protected val language: String by lazy { interpreter.getLanguage() }
val language: String by lazy { interpreter.getLanguage() }

override fun answer(vararg questions: String, api: API): CodeResult =
if (!autoEvaluate) answer(*chatMessages(*questions), api = api)
Expand All @@ -95,7 +99,7 @@ open class CodingActor(
api: API
): CodeResult =
if (!autoEvaluate) CodeResultImpl(*injectCodePrefix(messages, codePrefix), api = (api as OpenAIClient))
else answerWithAutoEval(*injectCodePrefix(messages, codePrefix), api = api).first
else answerWithAutoEval(*injectCodePrefix(messages, codePrefix), api = api, codePrefix=codePrefix).first

open fun answerWithAutoEval(
vararg messages: String,
Expand All @@ -105,9 +109,10 @@ open class CodingActor(

open fun answerWithAutoEval(
vararg messages: ChatMessage,
api: API
api: API,
codePrefix: String = ""
): Pair<CodeResult, ExecutionResult> {
var result = CodeResultImpl(*messages, api = (api as OpenAIClient))
var result = CodeResultImpl(*messages, api = (api as OpenAIClient), codePrefix = codePrefix)
var lastError: Throwable? = null
for (i in 0..fixIterations) try {
return result to result.run()
Expand All @@ -116,7 +121,7 @@ open class CodingActor(
result = run {
val respondWithCode = fixCommand(api, result.getCode(), ex, *messages, model = model)
val renderedResponse = getRenderedResponse(respondWithCode.second)
val codedInstruction = getCode(interpreter.getLanguage(), respondWithCode.second)
val codedInstruction = getCode(language, respondWithCode.second)
log.info("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin())
log.info("Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin())
CodeResultImpl(*messages, codePrefix = codedInstruction, api = api)
Expand All @@ -125,7 +130,7 @@ open class CodingActor(
throw RuntimeException(
"""
|Failed to fix code. Last attempt:
|```${interpreter.getLanguage().lowercase()}
|```${language.lowercase()}
|${result.getCode()}
|```
|
Expand All @@ -150,7 +155,7 @@ open class CodingActor(
val codeBlocks = extractCodeBlocks(response)
for (codingAttempt in 0..fixRetries) {
val renderedResponse = getRenderedResponse(codeBlocks)
val codedInstruction = getCode(interpreter.getLanguage(), codeBlocks)
val codedInstruction = getCode(language, codeBlocks)
log.info("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin())
log.info("Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin())
return validateAndFix(self, codedInstruction, codePrefix, api, messages, model) ?: continue
Expand All @@ -169,7 +174,7 @@ open class CodingActor(
var workingCode = initialCode
for (fixAttempt in 0..fixIterations) {
try {
val validate = interpreter.validate((codePrefix + "\n" + workingCode).trim())
val validate = interpreter.validate((codePrefix + "\n" + workingCode).sortCode())
if (validate != null) throw validate
log.info("Validation succeeded")
(self as CodeResultImpl)._status = CodeResult.Status.Success
Expand All @@ -179,7 +184,7 @@ open class CodingActor(
(self as CodeResultImpl)._status = CodeResult.Status.Correcting
val respondWithCode = fixCommand(api, workingCode, ex, *messages, model = model)
val response = getRenderedResponse(respondWithCode.second)
workingCode = getCode(interpreter.getLanguage(), respondWithCode.second)
workingCode = getCode(language, respondWithCode.second)
log.info("Response: \n\t${response.replace("\n", "\n\t", false)}".trimMargin())
log.info("Code: \n\t${workingCode.replace("\n", "\n\t", false)}".trimMargin())
}
Expand Down Expand Up @@ -211,7 +216,7 @@ open class CodingActor(

private inner class CodeResultImpl(
vararg messages: ChatMessage,
codePrefix: String = "",
val codePrefix: String = "",
api: OpenAIClient,
) : CodeResult {
var _status = CodeResult.Status.Coding
Expand All @@ -238,7 +243,7 @@ open class CodingActor(
@JsonIgnore
override fun getCode(): String = impl

override fun run() = execute(getCode())
override fun run() = execute(codePrefix + "\n" + getCode().sortCode())
}

private fun injectCodePrefix(
Expand Down Expand Up @@ -363,6 +368,61 @@ open class CodingActor(
}
}

fun String.sortCode(bodyWrapper: (String) -> String = { it }): String {
val (imports, otherCode) = this.split("\n").partition { it.trim().startsWith("import ") }
return imports.distinct().sorted().joinToString("\n") + "\n\n" + bodyWrapper(otherCode.joinToString("\n"))
}

fun String.camelCase(locale: Locale = Locale.getDefault()): String {
val words = fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }
return words.first().lowercase(locale) + words.drop(1).joinToString("") {
it.replaceFirstChar { c ->
when {
c.isLowerCase() -> c.titlecase(locale)
else -> c.toString()
}
}
}
}

fun String.pascalCase(locale: Locale = Locale.getDefault()): String =
fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }.joinToString("") {
it.replaceFirstChar { c ->
when {
c.isLowerCase() -> c.titlecase(locale)
else -> c.toString()
}
}
}

// Detect changes in the case of the first letter and prepend a space
fun String.fromPascalCase(): String = buildString {
var lastChar = ' '
for (c in this@fromPascalCase) {
if (c.isUpperCase() && lastChar.isLowerCase()) append(' ')
append(c)
lastChar = c
}
}

fun String.upperSnakeCase(locale: Locale = Locale.getDefault()): String =
fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }.joinToString("_") {
it.replaceFirstChar { c ->
when {
c.isLowerCase() -> c.titlecase(locale)
else -> c.toString()
}
}
}.uppercase(locale)

fun String.imports(): List<String> {
return this.split("\n").filter { it.trim().startsWith("import ") }.distinct().sorted()
}

fun String.stripImports(): String {
return this.split("\n").filter { !it.trim().startsWith("import ") }.joinToString("\n")
}

}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@ package com.simiacryptus.skyenet.core.actors
import com.simiacryptus.jopenai.API
import com.simiacryptus.jopenai.OpenAIClient
import com.simiacryptus.jopenai.models.ChatModels
import com.simiacryptus.jopenai.models.OpenAITextModel
import com.simiacryptus.jopenai.proxy.ChatProxy
import java.util.function.Function

open class ParsedActor<T:Any>(
val parserClass: Class<out Function<String, T>>,
prompt: String,
val action: String? = null,
name: String? = parserClass.simpleName,
model: ChatModels = ChatModels.GPT35Turbo,
temperature: Double = 0.3,
) : BaseActor<ParsedResponse<T>>(
prompt = prompt,
name = action,
name = name,
model = model,
temperature = temperature,
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ class CodingActorInterceptor(

override fun answerWithAutoEval(
vararg messages: ChatMessage,
api: API
api: API,
codePrefix: String,
) = functionInterceptor.wrap(messages.toList().toTypedArray()) {
inner.answerWithAutoEval(*messages, api = api)
inner.answerWithAutoEval(*messages, api = api, codePrefix = codePrefix)
}

override fun implement(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package com.simiacryptus.skyenet.core.actors.record
import com.simiacryptus.jopenai.API
import com.simiacryptus.jopenai.models.OpenAIModel
import com.simiacryptus.skyenet.core.actors.ParsedActor
import com.simiacryptus.skyenet.core.actors.ParsedResponse
import com.simiacryptus.skyenet.core.util.FunctionWrapper

class ParsedActorInterceptor<T:Any>(
Expand All @@ -12,7 +11,7 @@ class ParsedActorInterceptor<T:Any>(
) : ParsedActor<T>(
parserClass = inner.parserClass,
prompt = inner.prompt,
action = inner.action,
name = inner.name,
model = inner.model,
temperature = inner.temperature,
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,19 @@ abstract class ActorTestBase<R : Any> {

open fun testRun() {
testCases.forEach { testCase ->
val answer = actor.answer(messages = arrayOf(
val messages = arrayOf(
ApiModel.ChatMessage(
role = com.simiacryptus.jopenai.ApiModel.Role.system,
content = actor.prompt.toContentList()
),
) + testCase.userMessages.toTypedArray(), api)
) + testCase.userMessages.toTypedArray()
val answer = answer(messages)
log.info("Answer: ${resultMapper(answer)}")
}
}

open fun answer(messages: Array<ApiModel.ChatMessage>): R = actor.answer(messages = messages, api)

companion object {
private val log = LoggerFactory.getLogger(ActorTestBase::class.java)
}
Expand Down
2 changes: 1 addition & 1 deletion webui/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ val jetty_version = "11.0.18"
val jackson_version = "2.15.3"
dependencies {

implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.0.35")
implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.0.36")

implementation(project(":core"))
testImplementation(project(":groovy"))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
package com.simiacryptus.skyenet.webui.application

import com.simiacryptus.skyenet.webui.session.SessionMessage
import com.simiacryptus.skyenet.webui.session.SocketManagerBase
import com.simiacryptus.jopenai.describe.Description
import com.simiacryptus.skyenet.webui.session.SessionTask
import java.util.function.Consumer

class ApplicationInterface(private val inner: ApplicationSocketManager) {
fun send(html: String) = inner.send(html)
@Description("Returns html for a link that will trigger the given handler when clicked.")
fun hrefLink(linkText: String, classname: String = """href-link""", handler: Consumer<Unit>) =
inner.hrefLink(linkText, classname, handler)

@Description("Returns html for a text input form that will trigger the given handler when submitted.")
fun textInput(handler: Consumer<String>): String =
inner.textInput(handler)

fun newMessage(
operationID: String = SocketManagerBase.randomID(),
spinner: String = SessionMessage.spinner,
cancelable: Boolean = false
): SessionMessage = inner.newMessage(operationID, spinner, cancelable)
fun newTask(
//cancelable: Boolean = false
): SessionTask = inner.newTask(false)

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import com.simiacryptus.skyenet.core.platform.Session
import com.simiacryptus.skyenet.core.platform.User
import com.simiacryptus.skyenet.webui.chat.ChatServer
import com.simiacryptus.skyenet.webui.servlet.*
import com.simiacryptus.skyenet.webui.session.SessionMessage
import com.simiacryptus.skyenet.webui.session.SocketManager
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import com.simiacryptus.skyenet.core.platform.DataStorage
import com.simiacryptus.skyenet.core.platform.Session
import com.simiacryptus.skyenet.core.platform.User
import com.simiacryptus.skyenet.webui.chat.ChatSocket
import com.simiacryptus.skyenet.webui.session.SessionMessage
import com.simiacryptus.skyenet.webui.session.SessionTask
import com.simiacryptus.skyenet.webui.session.SocketManagerBase
import java.util.function.Consumer

Expand Down Expand Up @@ -80,7 +80,7 @@ abstract class ApplicationSocketManager(
)

companion object {
val spinner: String get() = """<div>${SessionMessage.spinner}</div>"""
val spinner: String get() = """<div>${SessionTask.spinner}</div>"""
// val playButton: String get() = """<button class="play-button" data-id="$operationID">▶</button>"""
// val cancelButton: String get() = """<button class="cancel-button" data-id="$operationID">&times;</button>"""
// val regenButton: String get() = """<button class="regen-button" data-id="$operationID">♲</button>"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.simiacryptus.jopenai.models.ChatModels
import com.simiacryptus.jopenai.models.OpenAITextModel
import com.simiacryptus.skyenet.core.platform.Session
import com.simiacryptus.skyenet.webui.application.ApplicationServer
import com.simiacryptus.skyenet.webui.session.SessionMessage
import com.simiacryptus.skyenet.webui.session.SessionTask
import com.simiacryptus.skyenet.webui.session.SocketManagerBase
import com.simiacryptus.skyenet.webui.util.MarkdownUtil

Expand Down Expand Up @@ -42,7 +42,7 @@ open class ChatSocketManager(
override fun onRun(userMessage: String, socket: ChatSocket) {
var responseContents = divInitializer(cancelable = false)
responseContents += """<div class="user-message">${renderResponse(userMessage)}</div>"""
send("""$responseContents<div class="chat-response">${SessionMessage.spinner}</div>""")
send("""$responseContents<div class="chat-response">${SessionTask.spinner}</div>""")
messages += ApiModel.ChatMessage(ApiModel.Role.user, userMessage.toContentList())
try {
val response = api.chat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import org.slf4j.LoggerFactory
import java.awt.image.BufferedImage
import java.util.*

abstract class SessionMessage(
abstract class SessionTask(
private var buffer: MutableList<StringBuilder> = mutableListOf(),
private val spinner: String = SessionMessage.spinner
private val spinner: String = SessionTask.spinner
) {
val currentText: String
get() = buffer.filter { it.isNotBlank() }.joinToString("")
Expand Down Expand Up @@ -62,7 +62,7 @@ abstract class SessionMessage(
add("""<img src="${save("${UUID.randomUUID()}.png", image.toPng())}" />""")

companion object {
val log = LoggerFactory.getLogger(SessionMessage::class.java)
val log = LoggerFactory.getLogger(SessionTask::class.java)

const val spinner =
"""<div class="spinner-border" role="status"><span class="sr-only">Loading...</span></div>"""
Expand Down
Loading

0 comments on commit 3ec61dc

Please sign in to comment.