From abb6713d655ccfa574fb7490802dc6a2b7f9112f Mon Sep 17 00:00:00 2001 From: Andrew Charneski Date: Fri, 10 Nov 2023 07:32:25 -0500 Subject: [PATCH] 1.0.23 --- .../skyenet/actors/CodingActor.kt | 19 ++- kotlin/build.gradle.kts | 2 + .../skyenet/heart/KotlinInterpreter.kt | 130 +++++++++++++----- .../skyenet/KotlinInterpreterTest.kt | 17 +++ 4 files changed, 126 insertions(+), 42 deletions(-) diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/actors/CodingActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/actors/CodingActor.kt index e301acd5..cd8240c1 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/actors/CodingActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/actors/CodingActor.kt @@ -94,9 +94,17 @@ class CodingActor( override fun answer(vararg messages: OpenAIClient.ChatMessage): CodeResult { return CodeResultImpl(*messages) } + fun answerWithPrefix(codePrefix: String, vararg messages: OpenAIClient.ChatMessage): CodeResult { + val prevList = messages.toList() + val newList = prevList.dropLast(1) + listOf( + OpenAIClient.ChatMessage(OpenAIClient.ChatMessage.Role.assistant, codePrefix) + ) + prevList.last() + return CodeResultImpl(*newList.toTypedArray()) + } private inner class CodeResultImpl( - vararg messages: OpenAIClient.ChatMessage + vararg messages: OpenAIClient.ChatMessage, + codePrefix: String = "", ) : CodeResult { var _status = CodeResult.Status.Coding override fun getStatus(): CodeResult.Status { @@ -112,7 +120,7 @@ class CodingActor( describer = describer, model = model, temperature = temperature, - ), *messages + ), *messages, codePrefix = codePrefix ) if (_status != CodeResult.Status.Success) { codedInstruction = implement( @@ -123,7 +131,7 @@ class CodingActor( describer = describer, model = fallbackModel, temperature = temperature, - ), *messages + ), *messages, codePrefix = codePrefix ) } if (_status != CodeResult.Status.Success) { @@ -135,7 +143,8 @@ class CodingActor( private fun implement( brain: Brain, - vararg messages: OpenAIClient.ChatMessage + vararg messages: OpenAIClient.ChatMessage, + codePrefix: String = "", ): String { val response = brain.implement(*messages) val codeBlocks = Brain.extractCodeBlocks(response) @@ -147,7 +156,7 @@ class CodingActor( log.info("Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin()) for (fixAttempt in 0..fixIterations) { try { - val validate = interpreter.validate(codedInstruction) + val validate = interpreter.validate((codePrefix + "\n" + codedInstruction).trim()) if (validate != null) throw validate log.info("Validation succeeded") _status = CodeResult.Status.Success diff --git a/kotlin/build.gradle.kts b/kotlin/build.gradle.kts index 5df7d536..cf18c74f 100644 --- a/kotlin/build.gradle.kts +++ b/kotlin/build.gradle.kts @@ -34,6 +34,8 @@ dependencies { implementation(kotlin("stdlib")) implementation(kotlin("scripting-jsr223")) + implementation(kotlin("scripting-jvm")) + implementation(kotlin("scripting-jvm-host")) implementation(kotlin("script-runtime")) implementation(kotlin("compiler-embeddable")) implementation(kotlin("scripting-compiler-embeddable")) diff --git a/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt b/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt index 8ac7d33b..6cbf5c0b 100644 --- a/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt +++ b/kotlin/src/main/kotlin/com/simiacryptus/skyenet/heart/KotlinInterpreter.kt @@ -2,25 +2,41 @@ package com.simiacryptus.skyenet.heart import com.simiacryptus.skyenet.Heart import org.jetbrains.kotlin.cli.common.CLIConfigurationKeys -import org.jetbrains.kotlin.cli.common.config.addKotlinSourceRoot +import org.jetbrains.kotlin.cli.common.arguments.K2JVMCompilerArguments import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSourceLocation import org.jetbrains.kotlin.cli.common.messages.MessageCollector +import org.jetbrains.kotlin.cli.common.repl.KotlinJsr223JvmScriptEngineFactoryBase +import org.jetbrains.kotlin.cli.common.repl.ScriptArgsWithTypes +import org.jetbrains.kotlin.cli.common.setupCommonArguments +import org.jetbrains.kotlin.cli.common.setupLanguageVersionSettings import org.jetbrains.kotlin.cli.jvm.compiler.EnvironmentConfigFiles import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCoreEnvironment import org.jetbrains.kotlin.cli.jvm.compiler.KotlinToJVMBytecodeCompiler -import org.jetbrains.kotlin.cli.jvm.config.addJvmClasspathRoot -import org.jetbrains.kotlin.cli.jvm.config.configureJdkClasspathRoots +import org.jetbrains.kotlin.cli.jvm.setupJvmSpecificArguments import org.jetbrains.kotlin.codegen.CompilationException import org.jetbrains.kotlin.codegen.state.GenerationState -import org.jetbrains.kotlin.com.intellij.openapi.Disposable import org.jetbrains.kotlin.config.CompilerConfiguration import java.io.File import java.lang.ref.WeakReference import java.util.* import java.util.Map +import javax.script.Bindings +import javax.script.ScriptContext +import javax.script.ScriptEngine +import kotlin.script.experimental.api.with +import kotlin.script.experimental.jsr223.KOTLIN_JSR223_RESOLVE_FROM_CLASSLOADER_PROPERTY +import kotlin.script.experimental.jsr223.KotlinJsr223DefaultScript import kotlin.script.experimental.jsr223.KotlinJsr223DefaultScriptEngineFactory - +import kotlin.script.experimental.jvm.JvmDependencyFromClassLoader +import kotlin.script.experimental.jvm.JvmScriptCompilationConfigurationBuilder +import kotlin.script.experimental.jvm.updateClasspath +import kotlin.script.experimental.jvm.util.scriptCompilationClasspathFromContext +import kotlin.script.experimental.jvmhost.createJvmScriptDefinitionFromTemplate +import kotlin.script.experimental.jvmhost.jsr223.KotlinJsr223ScriptEngineImpl +import kotlin.script.experimental.jvm.jvm +import kotlin.script.experimental.api.dependencies +import kotlin.script.experimental.host.ScriptDefinition open class KotlinInterpreter( private val defs: Map = HashMap() as Map @@ -38,19 +54,16 @@ open class KotlinInterpreter( val warnings = ArrayList() val messageCollector = object : MessageCollector { override fun clear() {} - override fun hasErrors() = false + override fun hasErrors() = errors.isNotEmpty() override fun report( severity: CompilerMessageSeverity, message: String, location: CompilerMessageSourceLocation? ) { - // Log the message with severity and location - val line = location!!.line - val column = location.column - val lineText = tempFile.readLines()[line - 1] - val carotText = " ".repeat(column) + "^" + val lineText = tempFile.readLines()[location!!.line - 1] + val carotText = " ".repeat(location.column - 1) + "^" val msg = """ - |$severity: $message at line ${line} column ${column} + |$message at line ${location.line} column ${location.column} | $lineText | $carotText """.trimMargin().trim() @@ -64,39 +77,42 @@ open class KotlinInterpreter( val configuration = CompilerConfiguration().apply { put(CLIConfigurationKeys.MESSAGE_COLLECTOR_KEY, messageCollector) - addKotlinSourceRoot(tempFile.absolutePath) - configureJdkClasspathRoots() - System.getProperty("java.class.path").split(File.pathSeparator).forEach { - addJvmClasspathRoot(File(it)) - } - - put(org.jetbrains.kotlin.config.JVMConfigurationKeys.COMPILE_JAVA, true) - put(org.jetbrains.kotlin.config.CommonConfigurationKeys.MODULE_NAME, "ModuleName") + //addKotlinSourceRoot(tempFile.absolutePath) +// configureJdkClasspathRoots() +// System.getProperty("java.class.path").split(File.pathSeparator).forEach { +// addJvmClasspathRoot(File(it)) +// } + val k2JVMCompilerArguments = K2JVMCompilerArguments() + k2JVMCompilerArguments.fragmentSources = arrayOf(tempFile.absolutePath) + k2JVMCompilerArguments.classpath = System.getProperty("java.class.path") + k2JVMCompilerArguments.moduleName = "ModuleName" + k2JVMCompilerArguments.script = true + this.setupJvmSpecificArguments(k2JVMCompilerArguments) + this.setupCommonArguments(k2JVMCompilerArguments) + this.setupLanguageVersionSettings(k2JVMCompilerArguments) +// put(org.jetbrains.kotlin.config.JVMConfigurationKeys.COMPILE_JAVA, true) + put(org.jetbrains.kotlin.config.CommonConfigurationKeys.MODULE_NAME, k2JVMCompilerArguments.moduleName!!) } // Create the compiler environment val environment = KotlinCoreEnvironment.createForProduction( - parentDisposable = Disposable {}, + parentDisposable = {}, configuration = configuration, configFiles = EnvironmentConfigFiles.JVM_CONFIG_FILES ) - val bindingContext = KotlinToJVMBytecodeCompiler.analyze(environment) - ?: throw IllegalStateException("Binding context could not be initialized") return try { - if (bindingContext.isError()) bindingContext.error else { - val compileBunchOfSources: GenerationState? = - KotlinToJVMBytecodeCompiler.analyzeAndGenerate(environment) - if (null == compileBunchOfSources) { - Exception("Compilation failed") - } else { - if (errors.isEmpty()) null - else RuntimeException( - """ - |${errors.joinToString("\n") { "Error: " + it }} - |${warnings.joinToString("\n") { "Warning: " + it }} - """.trimMargin()) - } + val compileBunchOfSources: GenerationState? = + KotlinToJVMBytecodeCompiler.analyzeAndGenerate(environment) + if (null == compileBunchOfSources) { + Exception("Compilation failed") + } else { + if (errors.isEmpty()) null + else RuntimeException( + """ + |${errors.joinToString("\n") { "Error: " + it }} + |${warnings.joinToString("\n") { "Warning: " + it }} + """.trimMargin()) } } catch (e: CompilationException) { RuntimeException( @@ -110,6 +126,45 @@ open class KotlinInterpreter( } } + private val scriptDefinition: ScriptDefinition = createJvmScriptDefinitionFromTemplate() + private var lastClassLoader: ClassLoader? = null + private var lastClassPath: List? = null + + @Synchronized + private fun JvmScriptCompilationConfigurationBuilder.dependenciesFromCurrentContext() { + val currentClassLoader = Thread.currentThread().contextClassLoader + val classPath = if (lastClassLoader == null || lastClassLoader != currentClassLoader) { + scriptCompilationClasspathFromContext( + classLoader = currentClassLoader, + wholeClasspath = true, + unpackJarCollections = true + ).also { + lastClassLoader = currentClassLoader + lastClassPath = it + } + } else lastClassPath!! + updateClasspath(classPath) + } + + val compilationConfiguration = scriptDefinition.compilationConfiguration.with { + jvm { + if (System.getProperty(KOTLIN_JSR223_RESOLVE_FROM_CLASSLOADER_PROPERTY) == "true") { + dependencies(JvmDependencyFromClassLoader { Thread.currentThread().contextClassLoader }) + } else { + dependenciesFromCurrentContext() + } + } + } + + val scriptEngineFactory = object : KotlinJsr223JvmScriptEngineFactoryBase() { + override fun getScriptEngine(): ScriptEngine { + return KotlinJsr223ScriptEngineImpl( + this, + compilationConfiguration, + scriptDefinition.evaluationConfiguration + ) { ScriptArgsWithTypes(arrayOf(it.getBindings(ScriptContext.ENGINE_SCOPE).orEmpty()), arrayOf(Bindings::class)) } + } + } override fun run(code: String): Any? { val wrappedCode = wrapCode(code) log.info( @@ -117,7 +172,7 @@ open class KotlinInterpreter( | ${wrappedCode.trimIndent().replace("\n", "\n\t")} |""".trimMargin() ) - val scriptEngineFactory = KotlinJsr223DefaultScriptEngineFactory() + return scriptEngineFactory.scriptEngine.eval(wrappedCode) } @@ -125,6 +180,7 @@ open class KotlinInterpreter( val out = ArrayList() val (imports, otherCode) = code.split("\n").partition { it.trim().startsWith("import ") } out.addAll(imports) + //out.add("import kotlin") defs.entrySet().forEach { (key, value) -> val uuid = storageMap.getOrPut(value) { UUID.randomUUID() } retrievalIndex.put(uuid, WeakReference(value)) diff --git a/kotlin/src/test/kotlin/com/simiacryptus/skyenet/KotlinInterpreterTest.kt b/kotlin/src/test/kotlin/com/simiacryptus/skyenet/KotlinInterpreterTest.kt index 2c5cbfc5..e21a6f02 100644 --- a/kotlin/src/test/kotlin/com/simiacryptus/skyenet/KotlinInterpreterTest.kt +++ b/kotlin/src/test/kotlin/com/simiacryptus/skyenet/KotlinInterpreterTest.kt @@ -1,7 +1,9 @@ package com.simiacryptus.skyenet.heart import com.simiacryptus.skyenet.HeartTestBase +import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test +import java.util.Map class KotlinInterpreterTest : HeartTestBase() { @@ -12,4 +14,19 @@ class KotlinInterpreterTest : HeartTestBase() { // TODO: This test is failing due to a bug with supplied primitives (e.g. Integer) } + + @Test + fun `test run with kotlin println`() { + val interpreter = newInterpreter(mapOf() as Map) + val result = interpreter.run("""println("Hello World")""") + Assertions.assertEquals(null, result) + } + + @Test + fun `test validate with kotlin println`() { + val interpreter = newInterpreter(mapOf() as Map) + val result = interpreter.validate("""println("Hello World")""") + Assertions.assertEquals(null, result) + } + } \ No newline at end of file