Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
acharneski committed Nov 25, 2023
1 parent 3ec61dc commit 0a30f7c
Show file tree
Hide file tree
Showing 24 changed files with 219 additions and 244 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.simiacryptus.skyenet.core
interface Interpreter {

fun getLanguage(): String
fun symbols() : Map<String, Any>
fun run(code: String): Any?
fun validate(code: String): Throwable?

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ import com.simiacryptus.skyenet.core.util.JsonFunctionRecorder
import java.io.File

open class ActorSystem<T:Enum<*>>(
private val actors: Map<T, BaseActor<*>>,
private val actors: Map<T, BaseActor<*,*>>,
val dataStorage: DataStorage,
val user: User?,
val session: Session
) {
private val sessionDir = dataStorage.getSessionDir(user, session)
fun getActor(actor: T): BaseActor<*> {
fun getActor(actor: T): BaseActor<*,*> {
val wrapper = getWrapper(actor.name)
return when (val baseActor = actors[actor]) {
null -> throw RuntimeException("No actor for $actor")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,41 +1,28 @@
package com.simiacryptus.skyenet.core.actors

import com.fasterxml.jackson.annotation.JsonIgnore
import com.simiacryptus.jopenai.API
import com.simiacryptus.jopenai.ApiModel
import com.simiacryptus.jopenai.ClientUtil.toContentList
import com.simiacryptus.jopenai.OpenAIClient
import com.simiacryptus.jopenai.models.ChatModels
import com.simiacryptus.jopenai.models.OpenAIModel
import com.simiacryptus.jopenai.models.OpenAITextModel

abstract class BaseActor<T>(
abstract class BaseActor<I,R>(
open val prompt: String,
val name: String? = null,
val model: ChatModels = ChatModels.GPT35Turbo,
val temperature: Double = 0.3,
) {
abstract fun answer(vararg messages: ApiModel.ChatMessage, api: API): T
open fun response(vararg messages: ApiModel.ChatMessage, model: OpenAIModel = this.model, api: API) = (api as OpenAIClient).chat(
abstract fun answer(vararg messages: ApiModel.ChatMessage, input: I, api: API): R
open fun response(vararg input: ApiModel.ChatMessage, model: OpenAIModel = this.model, api: API) = (api as OpenAIClient).chat(
ApiModel.ChatRequest(
messages = ArrayList(messages.toList()),
messages = ArrayList(input.toList()),
temperature = temperature,
model = this.model.modelName,
),
model = this.model
)
open fun answer(vararg questions: String, api: API): T = answer(*chatMessages(*questions), api = api)
open fun answer(input: I, api: API): R = answer(*chatMessages(input), input=input, api = api)

fun chatMessages(vararg questions: String) = arrayOf(
ApiModel.ChatMessage(
role = com.simiacryptus.jopenai.ApiModel.Role.system,
content = prompt.toContentList()
),
) + questions.map {
ApiModel.ChatMessage(
role = com.simiacryptus.jopenai.ApiModel.Role.user,
content = it.toContentList()
)
}
abstract fun chatMessages(questions: I): Array<ApiModel.ChatMessage>

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,38 @@ open class CodingActor(
model: ChatModels = ChatModels.GPT35Turbo,
val fallbackModel: ChatModels = ChatModels.GPT4Turbo,
temperature: Double = 0.1,
val autoEvaluate: Boolean = false,
private val fixIterations: Int = 3,
private val fixRetries: Int = 2,
val runtimeSymbols: Map<String, Any> = mapOf()
) : BaseActor<CodingActor.CodeResult>(
) : BaseActor<CodingActor.CodeRequest, CodingActor.CodeResult>(
prompt = "",
name = name,
model = model,
temperature = temperature,
) {
val interpreter: Interpreter get() = interpreterClass.java.getConstructor(Map::class.java).newInstance(symbols + runtimeSymbols)

data class ExecutionResult(
val resultValue: String,
val resultOutput: String
data class CodeRequest(
val messages: List<String>,
val codePrefix: String = "",
val autoEvaluate: Boolean = false
)

interface CodeResult {
enum class Status {
Coding, Correcting, Success, Failure
}

fun getStatus(): Status
fun getCode(): String
fun run(): ExecutionResult
fun result(): ExecutionResult
}

data class ExecutionResult(
val resultValue: String,
val resultOutput: String
)

override val prompt: String
get() = if (symbols.isNotEmpty()) """
|You will translate natural language instructions into
Expand Down Expand Up @@ -81,50 +88,58 @@ open class CodingActor(
|""".trimMargin().trim()
}.joinToString("\n")

open val interpreter: Interpreter by lazy { interpreterClass.java.getConstructor(Map::class.java).newInstance(symbols + runtimeSymbols) }

val language: String by lazy { interpreter.getLanguage() }

override fun answer(vararg questions: String, api: API): CodeResult =
if (!autoEvaluate) answer(*chatMessages(*questions), api = api)
else answerWithAutoEval(*chatMessages(*questions), api = api).first
override fun chatMessages(questions: CodeRequest): Array<ChatMessage> {
//injectCodePrefix
var chatMessages = arrayOf(
ChatMessage(
role = Role.system,
content = prompt.toContentList()
),
) + questions.messages.map {
ChatMessage(
role = Role.user,
content = it.toContentList()
)
}
if (questions.codePrefix.isNotBlank()) {
chatMessages = injectCodePrefix(chatMessages, questions.codePrefix)
}
return chatMessages

override fun answer(vararg messages: ChatMessage, api: API): CodeResult =
if (!autoEvaluate) CodeResultImpl(*messages, api = (api as OpenAIClient))
else answerWithAutoEval(*messages, api = api).first
}

open fun answerWithPrefix(
fun answerWithPrefix(
codePrefix: String,
vararg messages: ChatMessage,
api: API
): CodeResult =
if (!autoEvaluate) CodeResultImpl(*injectCodePrefix(messages, codePrefix), api = (api as OpenAIClient))
else answerWithAutoEval(*injectCodePrefix(messages, codePrefix), api = api, codePrefix=codePrefix).first

open fun answerWithAutoEval(
vararg messages: String,
api: API,
codePrefix: String = ""
) = answerWithAutoEval(*injectCodePrefix(chatMessages(*messages), codePrefix), api = api)
): CodeResult = answer(CodeRequest(
messages = messages.map { it.content?.first()?.text!! }.toList(),
codePrefix = codePrefix
), api = api)

open fun answerWithAutoEval(
override fun answer(
vararg messages: ChatMessage,
input: CodeRequest,
api: API,
codePrefix: String = ""
): Pair<CodeResult, ExecutionResult> {
var result = CodeResultImpl(*messages, api = (api as OpenAIClient), codePrefix = codePrefix)
): CodeResult {
var result = CodeResultImpl(*messages, api = (api as OpenAIClient), codePrefix = input.codePrefix)
if(!input.autoEvaluate) return result
var lastError: Throwable? = null
for (i in 0..fixIterations) try {
return result to result.run()
result.result()
return result
} catch (ex: Throwable) {
lastError = ex
result = run {
val respondWithCode = fixCommand(api, result.getCode(), ex, *messages, model = model)
val renderedResponse = getRenderedResponse(respondWithCode.second)
val codedInstruction = getCode(language, respondWithCode.second)
log.info("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin())
log.info("Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin())
CodeResultImpl(*messages, codePrefix = codedInstruction, api = api)
log.info("New Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin())
CodeResultImpl(*messages, codePrefix = input.codePrefix, api = api, givenCode = codedInstruction)
}
}
throw RuntimeException(
Expand All @@ -138,32 +153,12 @@ open class CodingActor(
|```
|${lastError?.message}
|```
|""".trimMargin().trim()
|""".trimMargin().trim(), lastError
)
}

open fun implement(
self: CodeResult,
api: OpenAIClient,
messages: Array<out ChatMessage>,
codePrefix: String,
model: ChatModels
): String {
var request = ChatRequest()
request = request.copy(messages = ArrayList(messages.toList()))
val response = chat(api, request, this.model)
val codeBlocks = extractCodeBlocks(response)
for (codingAttempt in 0..fixRetries) {
val renderedResponse = getRenderedResponse(codeBlocks)
val codedInstruction = getCode(language, codeBlocks)
log.info("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin())
log.info("Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin())
return validateAndFix(self, codedInstruction, codePrefix, api, messages, model) ?: continue
}
return ""
}

open fun validateAndFix(
private fun validateAndFix(
self: CodeResult,
initialCode: String,
codePrefix: String,
Expand All @@ -186,7 +181,7 @@ open class CodingActor(
val response = getRenderedResponse(respondWithCode.second)
workingCode = getCode(language, respondWithCode.second)
log.info("Response: \n\t${response.replace("\n", "\n\t", false)}".trimMargin())
log.info("Code: \n\t${workingCode.replace("\n", "\n\t", false)}".trimMargin())
log.info("New Code: \n\t${workingCode.replace("\n", "\n\t", false)}".trimMargin())
}
}
return null
Expand Down Expand Up @@ -218,13 +213,14 @@ open class CodingActor(
vararg messages: ChatMessage,
val codePrefix: String = "",
api: OpenAIClient,
val givenCode: String? = null,
) : CodeResult {
var _status = CodeResult.Status.Coding
override fun getStatus(): CodeResult.Status {
return _status
}

private val impl by lazy {
override fun getStatus() = _status

private val _code by lazy {
if(null != givenCode) return@lazy givenCode
var codedInstruction = implement(
this, api, messages, codePrefix = codePrefix, model
)
Expand All @@ -239,11 +235,31 @@ open class CodingActor(
}
codedInstruction
}
private fun implement(
self: CodeResult,
api: OpenAIClient,
messages: Array<out ChatMessage>,
codePrefix: String,
model: ChatModels
): String {
val request = ChatRequest(messages = ArrayList(messages.toList()))
val response = chat(api, request, model)
val codeBlocks = extractCodeBlocks(response)
for (codingAttempt in 0..fixRetries) {
val renderedResponse = getRenderedResponse(codeBlocks)
val codedInstruction = getCode(language, codeBlocks)
log.info("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin())
log.info("New Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin())
return validateAndFix(self, codedInstruction, codePrefix, api, messages, model) ?: continue
}
return ""
}

@JsonIgnore
override fun getCode(): String = impl
override fun getCode(): String = _code

override fun run() = execute(codePrefix + "\n" + getCode().sortCode())
private val executionResult by lazy { execute((codePrefix + "\n" + getCode()).sortCode()) }
override fun result() = executionResult
}

private fun injectCodePrefix(
Expand Down Expand Up @@ -283,7 +299,8 @@ open class CodingActor(
|Correct the code and try again.
|""".trimMargin().trim().toContentList()
)
))
)
)
)
val response = chat(api, request, model)
val codeBlocks = extractCodeBlocks(response)
Expand Down Expand Up @@ -426,4 +443,3 @@ open class CodingActor(
}

}

Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package com.simiacryptus.skyenet.core.actors

import com.simiacryptus.jopenai.API
import com.simiacryptus.jopenai.ApiModel
import com.simiacryptus.jopenai.ApiModel.ChatMessage
import com.simiacryptus.jopenai.ApiModel.ImageGenerationRequest
import com.simiacryptus.jopenai.ClientUtil.toContentList
import com.simiacryptus.jopenai.OpenAIClient
import com.simiacryptus.jopenai.models.ChatModels
import com.simiacryptus.jopenai.models.ImageModels
Expand All @@ -16,12 +18,24 @@ open class ImageActor(
temperature: Double = 0.3,
val width: Int = 1024,
val height: Int = 1024,
) : BaseActor<ImageResponse>(
) : BaseActor<List<String>, ImageResponse>(
prompt = prompt,
name = name,
model = textModel,
temperature = temperature,
) {
override fun chatMessages(questions: List<String>) = arrayOf(
ChatMessage(
role = ApiModel.Role.system,
content = prompt.toContentList()
),
) + questions.map {
ChatMessage(
role = ApiModel.Role.user,
content = it.toContentList()
)
}

private inner class ImageResponseImpl(vararg messages: ChatMessage, val api: API) : ImageResponse {

private val _text: String by lazy { response(*messages, api = api).choices.first().message?.content ?: throw RuntimeException("No response") }
Expand All @@ -38,7 +52,7 @@ open class ImageActor(
}
}

override fun answer(vararg messages: ChatMessage, api: API): ImageResponse {
override fun answer(vararg messages: ChatMessage, input: List<String>, api: API): ImageResponse {
return ImageResponseImpl(*messages, api = api)
}
}
Expand Down
Loading

0 comments on commit 0a30f7c

Please sign in to comment.