Skip to content

Commit

Permalink
1.0.23
Browse files Browse the repository at this point in the history
  • Loading branch information
acharneski committed Nov 10, 2023
1 parent c8e4beb commit abb6713
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,17 @@ class CodingActor(
override fun answer(vararg messages: OpenAIClient.ChatMessage): CodeResult {
return CodeResultImpl(*messages)
}
fun answerWithPrefix(codePrefix: String, vararg messages: OpenAIClient.ChatMessage): CodeResult {
val prevList = messages.toList()
val newList = prevList.dropLast(1) + listOf(
OpenAIClient.ChatMessage(OpenAIClient.ChatMessage.Role.assistant, codePrefix)
) + prevList.last()
return CodeResultImpl(*newList.toTypedArray())
}

private inner class CodeResultImpl(
vararg messages: OpenAIClient.ChatMessage
vararg messages: OpenAIClient.ChatMessage,
codePrefix: String = "",
) : CodeResult {
var _status = CodeResult.Status.Coding
override fun getStatus(): CodeResult.Status {
Expand All @@ -112,7 +120,7 @@ class CodingActor(
describer = describer,
model = model,
temperature = temperature,
), *messages
), *messages, codePrefix = codePrefix
)
if (_status != CodeResult.Status.Success) {
codedInstruction = implement(
Expand All @@ -123,7 +131,7 @@ class CodingActor(
describer = describer,
model = fallbackModel,
temperature = temperature,
), *messages
), *messages, codePrefix = codePrefix
)
}
if (_status != CodeResult.Status.Success) {
Expand All @@ -135,7 +143,8 @@ class CodingActor(

private fun implement(
brain: Brain,
vararg messages: OpenAIClient.ChatMessage
vararg messages: OpenAIClient.ChatMessage,
codePrefix: String = "",
): String {
val response = brain.implement(*messages)
val codeBlocks = Brain.extractCodeBlocks(response)
Expand All @@ -147,7 +156,7 @@ class CodingActor(
log.info("Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin())
for (fixAttempt in 0..fixIterations) {
try {
val validate = interpreter.validate(codedInstruction)
val validate = interpreter.validate((codePrefix + "\n" + codedInstruction).trim())
if (validate != null) throw validate
log.info("Validation succeeded")
_status = CodeResult.Status.Success
Expand Down
2 changes: 2 additions & 0 deletions kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ dependencies {

implementation(kotlin("stdlib"))
implementation(kotlin("scripting-jsr223"))
implementation(kotlin("scripting-jvm"))
implementation(kotlin("scripting-jvm-host"))
implementation(kotlin("script-runtime"))
implementation(kotlin("compiler-embeddable"))
implementation(kotlin("scripting-compiler-embeddable"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,41 @@ package com.simiacryptus.skyenet.heart

import com.simiacryptus.skyenet.Heart
import org.jetbrains.kotlin.cli.common.CLIConfigurationKeys
import org.jetbrains.kotlin.cli.common.config.addKotlinSourceRoot
import org.jetbrains.kotlin.cli.common.arguments.K2JVMCompilerArguments
import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity
import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSourceLocation
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.cli.common.repl.KotlinJsr223JvmScriptEngineFactoryBase
import org.jetbrains.kotlin.cli.common.repl.ScriptArgsWithTypes
import org.jetbrains.kotlin.cli.common.setupCommonArguments
import org.jetbrains.kotlin.cli.common.setupLanguageVersionSettings
import org.jetbrains.kotlin.cli.jvm.compiler.EnvironmentConfigFiles
import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCoreEnvironment
import org.jetbrains.kotlin.cli.jvm.compiler.KotlinToJVMBytecodeCompiler
import org.jetbrains.kotlin.cli.jvm.config.addJvmClasspathRoot
import org.jetbrains.kotlin.cli.jvm.config.configureJdkClasspathRoots
import org.jetbrains.kotlin.cli.jvm.setupJvmSpecificArguments
import org.jetbrains.kotlin.codegen.CompilationException
import org.jetbrains.kotlin.codegen.state.GenerationState
import org.jetbrains.kotlin.com.intellij.openapi.Disposable
import org.jetbrains.kotlin.config.CompilerConfiguration
import java.io.File
import java.lang.ref.WeakReference
import java.util.*
import java.util.Map
import javax.script.Bindings
import javax.script.ScriptContext
import javax.script.ScriptEngine
import kotlin.script.experimental.api.with
import kotlin.script.experimental.jsr223.KOTLIN_JSR223_RESOLVE_FROM_CLASSLOADER_PROPERTY
import kotlin.script.experimental.jsr223.KotlinJsr223DefaultScript
import kotlin.script.experimental.jsr223.KotlinJsr223DefaultScriptEngineFactory

import kotlin.script.experimental.jvm.JvmDependencyFromClassLoader
import kotlin.script.experimental.jvm.JvmScriptCompilationConfigurationBuilder
import kotlin.script.experimental.jvm.updateClasspath
import kotlin.script.experimental.jvm.util.scriptCompilationClasspathFromContext
import kotlin.script.experimental.jvmhost.createJvmScriptDefinitionFromTemplate
import kotlin.script.experimental.jvmhost.jsr223.KotlinJsr223ScriptEngineImpl
import kotlin.script.experimental.jvm.jvm
import kotlin.script.experimental.api.dependencies
import kotlin.script.experimental.host.ScriptDefinition

open class KotlinInterpreter(
private val defs: Map<String, Object> = HashMap<String, Object>() as Map<String, Object>
Expand All @@ -38,19 +54,16 @@ open class KotlinInterpreter(
val warnings = ArrayList<String>()
val messageCollector = object : MessageCollector {
override fun clear() {}
override fun hasErrors() = false
override fun hasErrors() = errors.isNotEmpty()
override fun report(
severity: CompilerMessageSeverity,
message: String,
location: CompilerMessageSourceLocation?
) {
// Log the message with severity and location
val line = location!!.line
val column = location.column
val lineText = tempFile.readLines()[line - 1]
val carotText = " ".repeat(column) + "^"
val lineText = tempFile.readLines()[location!!.line - 1]
val carotText = " ".repeat(location.column - 1) + "^"
val msg = """
|$severity: $message at line ${line} column ${column}
|$message at line ${location.line} column ${location.column}
| $lineText
| $carotText
""".trimMargin().trim()
Expand All @@ -64,39 +77,42 @@ open class KotlinInterpreter(

val configuration = CompilerConfiguration().apply {
put(CLIConfigurationKeys.MESSAGE_COLLECTOR_KEY, messageCollector)
addKotlinSourceRoot(tempFile.absolutePath)
configureJdkClasspathRoots()
System.getProperty("java.class.path").split(File.pathSeparator).forEach {
addJvmClasspathRoot(File(it))
}

put(org.jetbrains.kotlin.config.JVMConfigurationKeys.COMPILE_JAVA, true)
put(org.jetbrains.kotlin.config.CommonConfigurationKeys.MODULE_NAME, "ModuleName")
//addKotlinSourceRoot(tempFile.absolutePath)
// configureJdkClasspathRoots()
// System.getProperty("java.class.path").split(File.pathSeparator).forEach {
// addJvmClasspathRoot(File(it))
// }
val k2JVMCompilerArguments = K2JVMCompilerArguments()
k2JVMCompilerArguments.fragmentSources = arrayOf(tempFile.absolutePath)
k2JVMCompilerArguments.classpath = System.getProperty("java.class.path")
k2JVMCompilerArguments.moduleName = "ModuleName"
k2JVMCompilerArguments.script = true
this.setupJvmSpecificArguments(k2JVMCompilerArguments)
this.setupCommonArguments(k2JVMCompilerArguments)
this.setupLanguageVersionSettings(k2JVMCompilerArguments)
// put(org.jetbrains.kotlin.config.JVMConfigurationKeys.COMPILE_JAVA, true)
put(org.jetbrains.kotlin.config.CommonConfigurationKeys.MODULE_NAME, k2JVMCompilerArguments.moduleName!!)
}

// Create the compiler environment
val environment = KotlinCoreEnvironment.createForProduction(
parentDisposable = Disposable {},
parentDisposable = {},
configuration = configuration,
configFiles = EnvironmentConfigFiles.JVM_CONFIG_FILES
)

val bindingContext = KotlinToJVMBytecodeCompiler.analyze(environment)
?: throw IllegalStateException("Binding context could not be initialized")
return try {
if (bindingContext.isError()) bindingContext.error else {
val compileBunchOfSources: GenerationState? =
KotlinToJVMBytecodeCompiler.analyzeAndGenerate(environment)
if (null == compileBunchOfSources) {
Exception("Compilation failed")
} else {
if (errors.isEmpty()) null
else RuntimeException(
"""
|${errors.joinToString("\n") { "Error: " + it }}
|${warnings.joinToString("\n") { "Warning: " + it }}
""".trimMargin())
}
val compileBunchOfSources: GenerationState? =
KotlinToJVMBytecodeCompiler.analyzeAndGenerate(environment)
if (null == compileBunchOfSources) {
Exception("Compilation failed")
} else {
if (errors.isEmpty()) null
else RuntimeException(
"""
|${errors.joinToString("\n") { "Error: " + it }}
|${warnings.joinToString("\n") { "Warning: " + it }}
""".trimMargin())
}
} catch (e: CompilationException) {
RuntimeException(
Expand All @@ -110,21 +126,61 @@ open class KotlinInterpreter(
}
}

private val scriptDefinition: ScriptDefinition = createJvmScriptDefinitionFromTemplate<KotlinJsr223DefaultScript>()
private var lastClassLoader: ClassLoader? = null
private var lastClassPath: List<File>? = null

@Synchronized
private fun JvmScriptCompilationConfigurationBuilder.dependenciesFromCurrentContext() {
val currentClassLoader = Thread.currentThread().contextClassLoader
val classPath = if (lastClassLoader == null || lastClassLoader != currentClassLoader) {
scriptCompilationClasspathFromContext(
classLoader = currentClassLoader,
wholeClasspath = true,
unpackJarCollections = true
).also {
lastClassLoader = currentClassLoader
lastClassPath = it
}
} else lastClassPath!!
updateClasspath(classPath)
}

val compilationConfiguration = scriptDefinition.compilationConfiguration.with {
jvm {
if (System.getProperty(KOTLIN_JSR223_RESOLVE_FROM_CLASSLOADER_PROPERTY) == "true") {
dependencies(JvmDependencyFromClassLoader { Thread.currentThread().contextClassLoader })
} else {
dependenciesFromCurrentContext()
}
}
}

val scriptEngineFactory = object : KotlinJsr223JvmScriptEngineFactoryBase() {
override fun getScriptEngine(): ScriptEngine {
return KotlinJsr223ScriptEngineImpl(
this,
compilationConfiguration,
scriptDefinition.evaluationConfiguration
) { ScriptArgsWithTypes(arrayOf(it.getBindings(ScriptContext.ENGINE_SCOPE).orEmpty()), arrayOf(Bindings::class)) }
}
}
override fun run(code: String): Any? {
val wrappedCode = wrapCode(code)
log.info(
"""Running:
| ${wrappedCode.trimIndent().replace("\n", "\n\t")}
|""".trimMargin()
)
val scriptEngineFactory = KotlinJsr223DefaultScriptEngineFactory()

return scriptEngineFactory.scriptEngine.eval(wrappedCode)
}

override fun wrapCode(code: String): String {
val out = ArrayList<String>()
val (imports, otherCode) = code.split("\n").partition { it.trim().startsWith("import ") }
out.addAll(imports)
//out.add("import kotlin")
defs.entrySet().forEach { (key, value) ->
val uuid = storageMap.getOrPut(value) { UUID.randomUUID() }
retrievalIndex.put(uuid, WeakReference(value))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.simiacryptus.skyenet.heart

import com.simiacryptus.skyenet.HeartTestBase
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import java.util.Map

class KotlinInterpreterTest : HeartTestBase() {

Expand All @@ -12,4 +14,19 @@ class KotlinInterpreterTest : HeartTestBase() {
// TODO: This test is failing due to a bug with supplied primitives (e.g. Integer)
}


@Test
fun `test run with kotlin println`() {
val interpreter = newInterpreter(mapOf<String,Object>() as Map<String, Object>)
val result = interpreter.run("""println("Hello World")""")
Assertions.assertEquals(null, result)
}

@Test
fun `test validate with kotlin println`() {
val interpreter = newInterpreter(mapOf<String,Object>() as Map<String, Object>)
val result = interpreter.validate("""println("Hello World")""")
Assertions.assertEquals(null, result)
}

}

0 comments on commit abb6713

Please sign in to comment.