From 5b1ef560a9d93e70d54fe4801aa84aaa3b34dcef Mon Sep 17 00:00:00 2001 From: Andrew Charneski Date: Tue, 16 Apr 2024 22:41:48 -0400 Subject: [PATCH] 1.0.63 (#68) * 1.0.63 * docs * wip * wip * Update InterpreterTestBase.kt --- core/build.gradle.kts | 2 +- core/src/main/dev_documentation.md | 2 +- .../aicoder/actions/generic/CommonRoot.kt | 38 +- .../actions/generic/GetModuleRootForFile.kt | 14 +- .../skyenet/core/actors/ActorSystem.kt | 98 +- .../skyenet/core/actors/BaseActor.kt | 24 +- .../skyenet/core/actors/CodingActor.kt | 769 +-- .../skyenet/core/actors/ImageActor.kt | 2 +- .../skyenet/core/actors/MultiExeption.kt | 5 +- .../skyenet/core/actors/ParsedActor.kt | 194 +- .../skyenet/core/actors/SimpleActor.kt | 6 +- .../skyenet/core/actors/TextToSpeechActor.kt | 88 +- .../core/actors/opt/ActorOptimization.kt | 16 +- .../skyenet/core/actors/opt/Expectation.kt | 4 +- .../actors/record/CodingActorInterceptor.kt | 10 +- .../actors/record/ImageActorInterceptor.kt | 8 +- .../actors/record/ParsedActorInterceptor.kt | 42 +- .../actors/record/SimpleActorInterceptor.kt | 8 +- .../record/TextToSpeechActorInterceptor.kt | 20 +- .../skyenet/core/actors/test/ActorTestBase.kt | 4 +- .../core/actors/test/CodingActorTestBase.kt | 6 +- .../core/actors/test/ImageActorTestBase.kt | 2 +- .../core/platform/ApplicationServices.kt | 526 +- .../skyenet/core/platform/AwsPlatform.kt | 130 +- .../skyenet/core/platform/ClientManager.kt | 219 +- .../skyenet/core/platform/file/DataStorage.kt | 49 +- .../core/platform/file/UsageManager.kt | 345 +- .../test/AuthenticationInterfaceTest.kt | 70 +- .../test/AuthorizationInterfaceTest.kt | 18 +- .../skyenet/core/platform/test/UsageTest.kt | 40 +- .../core/platform/test/UserSettingsTest.kt | 64 +- .../core/util/ClasspathRelationships.kt | 713 ++- .../simiacryptus/skyenet/core/util/Ears.kt | 4 +- .../skyenet/core/util/FunctionWrapper.kt | 67 +- .../skyenet/core/util/RuleTreeBuilder.kt | 210 +- .../skyenet/core/util/Selenium.kt | 10 +- .../skyenet/core/util/StringSplitter.kt | 51 +- .../skyenet/interpreter/Interpreter.kt | 2 + .../interpreter/InterpreterTestBase.kt | 4 +- .../skyenet/core/actors/ActorOptTest.kt | 80 +- .../platform/AuthenticationManagerTest.kt | 2 +- .../core/platform/AuthorizationManagerTest.kt | 2 +- .../skyenet/core/util/RuleTreeBuilderTest.kt | 74 +- .../core_user_documentation.md | 0 .../webui_documentation.md | 4 +- gradle.properties | 2 +- .../skyenet/groovy/GroovyInterpreter.kt | 6 +- .../skyenet/groovy/GroovyInterpreterTest.kt | 3 +- .../skyenet/kotlin/KotlinInterpreter.kt | 218 +- webui/build.gradle.kts | 2 +- webui/src/compiled_documentation.md | 2 +- .../simiacryptus/diff/AddApplyDiffLinks.kt | 14 +- .../diff/AddApplyFileDiffLinks.kt | 154 +- .../github/simiacryptus/diff/AddSaveLinks.kt | 54 +- .../github/simiacryptus/diff/ApxPatchUtil.kt | 199 +- .../simiacryptus/diff/DiffMatchPatch.kt | 4546 ++++++++--------- .../com/github/simiacryptus/diff/DiffUtil.kt | 207 +- .../simiacryptus/diff/IterativePatchUtil.kt | 30 +- .../com/simiacryptus/skyenet/Acceptable.kt | 290 +- .../com/simiacryptus/skyenet/AgentPatterns.kt | 64 +- .../com/simiacryptus/skyenet/Retryable.kt | 36 +- .../com/simiacryptus/skyenet/TabbedDisplay.kt | 102 +- .../skyenet/apps/coding/CodingAgent.kt | 526 +- .../skyenet/apps/coding/CodingSubAgent.kt | 69 - .../skyenet/apps/coding/ShellToolAgent.kt | 799 +-- .../skyenet/apps/coding/ToolAgent.kt | 681 +-- .../skyenet/apps/general/WebDevApp.kt | 699 ++- .../skyenet/interpreter/ProcessInterpreter.kt | 79 +- .../webui/application/ApplicationDirectory.kt | 369 +- .../webui/application/ApplicationInterface.kt | 60 +- .../application/ApplicationSocketManager.kt | 66 +- .../skyenet/webui/chat/ChatServer.kt | 5 +- .../skyenet/webui/chat/ChatSocket.kt | 3 +- .../skyenet/webui/chat/ChatSocketManager.kt | 109 +- .../skyenet/webui/servlet/ApiKeyServlet.kt | 329 +- .../webui/servlet/CancelThreadsServlet.kt | 82 +- .../skyenet/webui/servlet/CorsFilter.kt | 5 +- .../webui/servlet/DeleteSessionServlet.kt | 4 +- .../skyenet/webui/servlet/FileServlet.kt | 310 +- .../skyenet/webui/servlet/LogoutServlet.kt | 3 +- .../skyenet/webui/servlet/OAuthGoogle.kt | 174 +- .../skyenet/webui/servlet/ProxyHttpServlet.kt | 330 +- .../skyenet/webui/servlet/SessionIdFilter.kt | 44 +- .../webui/servlet/SessionListServlet.kt | 10 +- .../webui/servlet/SessionSettingsServlet.kt | 2 +- .../webui/servlet/SessionShareServlet.kt | 199 +- .../webui/servlet/SessionThreadsServlet.kt | 18 +- .../skyenet/webui/servlet/ToolServlet.kt | 348 +- .../skyenet/webui/servlet/UsageServlet.kt | 10 +- .../webui/servlet/UserSettingsServlet.kt | 49 +- .../skyenet/webui/servlet/WelcomeServlet.kt | 107 +- .../skyenet/webui/servlet/ZipServlet.kt | 3 +- .../skyenet/webui/session/SessionTask.kt | 325 +- .../webui/session/SocketManagerBase.kt | 428 +- .../skyenet/webui/test/CodingActorTestApp.kt | 27 +- .../skyenet/webui/test/ImageActorTestApp.kt | 11 +- .../skyenet/webui/test/ParsedActorTestApp.kt | 8 +- .../skyenet/webui/test/SimpleActorTestApp.kt | 10 +- .../skyenet/webui/util/EncryptFiles.kt | 14 +- .../skyenet/webui/util/MarkdownUtil.kt | 35 +- .../skyenet/webui/util/OpenAPI.kt | 125 +- .../skyenet/webui/util/Selenium2S3.kt | 822 +-- .../skyenet/webui/util/TensorflowProjector.kt | 13 +- .../github/simiacryptus/diff/DiffUtilTest.kt | 158 +- .../diff/IterativePatchUtilTest.kt | 39 +- .../skyenet/webui/ActorTestAppServer.kt | 48 +- 106 files changed, 8802 insertions(+), 8728 deletions(-) rename core_user_documentation.md => docs/core_user_documentation.md (100%) rename webui_documentation.md => docs/webui_documentation.md (99%) delete mode 100644 webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/CodingSubAgent.kt diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 821e9d2c..40f48621 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -32,7 +32,7 @@ val jackson_version = "2.15.3" dependencies { - implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.0.51") + implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.0.52") implementation("org.apache.commons:commons-text:1.11.0") diff --git a/core/src/main/dev_documentation.md b/core/src/main/dev_documentation.md index 1b22fb13..4041a588 100644 --- a/core/src/main/dev_documentation.md +++ b/core/src/main/dev_documentation.md @@ -3450,7 +3450,7 @@ interface Selenium : AutoCloseable { url: URL, currentFilename: String?, saveRoot: String - ); + ) } ``` diff --git a/core/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/CommonRoot.kt b/core/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/CommonRoot.kt index 2e68c6d7..a727309f 100644 --- a/core/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/CommonRoot.kt +++ b/core/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/CommonRoot.kt @@ -3,28 +3,28 @@ package com.github.simiacryptus.aicoder.actions.generic import java.io.File import java.nio.file.Path -fun Array.commonRoot() : Path = when { - isEmpty() -> error("No paths") - size == 1 && first().toFile().isFile -> first().parent - size == 1 -> first() - else -> this.reduce { a, b -> - when { - a.startsWith(b) -> b - b.startsWith(a) -> a - else -> when (val common = a.commonPrefixWith(b)) { - a -> a - b -> b - else -> common.toAbsolutePath() - } +fun Array.commonRoot(): Path = when { + isEmpty() -> error("No paths") + size == 1 && first().toFile().isFile -> first().parent + size == 1 -> first() + else -> this.reduce { a, b -> + when { + a.startsWith(b) -> b + b.startsWith(a) -> a + else -> when (val common = a.commonPrefixWith(b)) { + a -> a + b -> b + else -> common.toAbsolutePath() + } + } } - } } private fun Path.commonPrefixWith(b: Path): Path { - val a = this - val aParts = a.toAbsolutePath().toString().split(File.separator) - val bParts = b.toAbsolutePath().toString().split(File.separator) - val common = aParts.zip(bParts).takeWhile { (a, b) -> a == b }.map { it.first } - return File(File.separator + common.joinToString(File.separator)).toPath() + val a = this + val aParts = a.toAbsolutePath().toString().split(File.separator) + val bParts = b.toAbsolutePath().toString().split(File.separator) + val common = aParts.zip(bParts).takeWhile { (a, b) -> a == b }.map { it.first } + return File(File.separator + common.joinToString(File.separator)).toPath() } diff --git a/core/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/GetModuleRootForFile.kt b/core/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/GetModuleRootForFile.kt index ba54d09e..e239392e 100644 --- a/core/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/GetModuleRootForFile.kt +++ b/core/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/GetModuleRootForFile.kt @@ -3,12 +3,12 @@ package com.github.simiacryptus.aicoder.actions.generic import java.io.File fun getModuleRootForFile(file: File): File { - var current = file - while (current.parentFile != null) { - if (current.resolve(".git").exists()) { - return current + var current = file + while (current.parentFile != null) { + if (current.resolve(".git").exists()) { + return current + } + current = current.parentFile } - current = current.parentFile - } - return file + return file } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ActorSystem.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ActorSystem.kt index 04c6b6f3..9b981351 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ActorSystem.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ActorSystem.kt @@ -10,66 +10,66 @@ import com.simiacryptus.skyenet.core.util.JsonFunctionRecorder import java.io.File open class ActorSystem>( - val actors: Map>, - val dataStorage: StorageInterface, - val user: User?, - val session: Session + val actors: Map>, + val dataStorage: StorageInterface, + val user: User?, + val session: Session ) { - private val sessionDir = dataStorage.getSessionDir(user, session) - protected val pool by lazy { ApplicationServices.clientManager.getPool(session, user, dataStorage) } + private val sessionDir = dataStorage.getSessionDir(user, session) + protected val pool by lazy { ApplicationServices.clientManager.getPool(session, user, dataStorage) } - private val actorMap = mutableMapOf>() + private val actorMap = mutableMapOf>() - fun getActor(actor: T): BaseActor<*, *> { - return synchronized(actorMap) { - actorMap.computeIfAbsent(actor) { innerActor -> - try { - val wrapper = getWrapper(actor.name) - when (val baseActor = actors[actor.name]) { - null -> throw RuntimeException("No actor for $actor") - is SimpleActor -> SimpleActorInterceptor( - inner = baseActor as SimpleActor, - functionInterceptor = wrapper - ) + fun getActor(actor: T): BaseActor<*, *> { + return synchronized(actorMap) { + actorMap.computeIfAbsent(actor) { innerActor -> + try { + val wrapper = getWrapper(actor.name) + when (val baseActor = actors[actor.name]) { + null -> throw RuntimeException("No actor for $actor") + is SimpleActor -> SimpleActorInterceptor( + inner = baseActor, + functionInterceptor = wrapper + ) - is ParsedActor<*> -> ParsedActorInterceptor( - inner = (baseActor as ParsedActor<*>), - functionInterceptor = wrapper - ) + is ParsedActor<*> -> ParsedActorInterceptor( + inner = baseActor, + functionInterceptor = wrapper + ) - is CodingActor -> CodingActorInterceptor( - inner = baseActor as CodingActor, - functionInterceptor = wrapper - ) + is CodingActor -> CodingActorInterceptor( + inner = baseActor, + functionInterceptor = wrapper + ) - is ImageActor -> ImageActorInterceptor( - inner = baseActor as ImageActor, - functionInterceptor = wrapper - ) + is ImageActor -> ImageActorInterceptor( + inner = baseActor, + functionInterceptor = wrapper + ) - is TextToSpeechActor -> TextToSpeechActorInterceptor( - inner = baseActor as TextToSpeechActor, - functionInterceptor = wrapper - ) + is TextToSpeechActor -> TextToSpeechActorInterceptor( + inner = baseActor, + functionInterceptor = wrapper + ) - else -> throw RuntimeException("Unknown actor type: ${baseActor.javaClass}") - } - } catch (e: Throwable) { - log.warn("Error creating actor $actor", e) - actors[actor.name]!! + else -> throw RuntimeException("Unknown actor type: ${baseActor.javaClass}") + } + } catch (e: Throwable) { + log.warn("Error creating actor $actor", e) + actors[actor.name]!! + } + } } - } } - } - private val wrapperMap = mutableMapOf() - private fun getWrapper(name: String) = synchronized(wrapperMap) { - wrapperMap.getOrPut(name) { - FunctionWrapper(JsonFunctionRecorder(File(sessionDir, ".sys/$session/actors/$name"))) + private val wrapperMap = mutableMapOf() + private fun getWrapper(name: String) = synchronized(wrapperMap) { + wrapperMap.getOrPut(name) { + FunctionWrapper(JsonFunctionRecorder(File(sessionDir, ".sys/$session/actors/$name"))) + } } - } - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(ActorSystem::class.java) - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(ActorSystem::class.java) + } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/BaseActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/BaseActor.kt index 05fb4616..9b6595a3 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/BaseActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/BaseActor.kt @@ -7,23 +7,25 @@ import com.simiacryptus.jopenai.models.ChatModels import com.simiacryptus.jopenai.models.OpenAIModel import com.simiacryptus.jopenai.models.OpenAITextModel -abstract class BaseActor( +abstract class BaseActor( open val prompt: String, val name: String? = null, val model: ChatModels, val temperature: Double = 0.3, ) { abstract fun respond(input: I, api: API, vararg messages: ApiModel.ChatMessage): R - open fun response(vararg input: ApiModel.ChatMessage, model: OpenAIModel = this.model, api: API) = (api as OpenAIClient).chat( - ApiModel.ChatRequest( - messages = ArrayList(input.toList()), - temperature = temperature, - model = this.model.modelName, - ), - model = this.model - ) - open fun answer(input: I, api: API): R = respond(input=input, api = api, *chatMessages(input)) + open fun response(vararg input: ApiModel.ChatMessage, model: OpenAIModel = this.model, api: API) = + (api as OpenAIClient).chat( + ApiModel.ChatRequest( + messages = ArrayList(input.toList()), + temperature = temperature, + model = this.model.modelName, + ), + model = this.model + ) + + open fun answer(input: I, api: API): R = respond(input = input, api = api, *chatMessages(input)) abstract fun chatMessages(questions: I): Array - abstract fun withModel(model: ChatModels): BaseActor + abstract fun withModel(model: ChatModels): BaseActor } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/CodingActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/CodingActor.kt index 4c78d26b..1b13cac8 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/CodingActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/CodingActor.kt @@ -14,58 +14,59 @@ import javax.script.ScriptException import kotlin.reflect.KClass open class CodingActor( - val interpreterClass: KClass, - val symbols: Map = mapOf(), - val describer: TypeDescriber = AbbrevWhitelistYamlDescriber( - "com.simiacryptus", - "com.github.simiacryptus" - ), - name: String? = interpreterClass.simpleName, - val details: String? = null, - model: ChatModels, - val fallbackModel: ChatModels = ChatModels.GPT4Turbo, - temperature: Double = 0.1, - val runtimeSymbols: Map = mapOf() + val interpreterClass: KClass, + val symbols: Map = mapOf(), + val describer: TypeDescriber = AbbrevWhitelistYamlDescriber( + "com.simiacryptus", + "com.github.simiacryptus" + ), + name: String? = interpreterClass.simpleName, + val details: String? = null, + model: ChatModels, + val fallbackModel: ChatModels = ChatModels.GPT4Turbo, + temperature: Double = 0.1, + val runtimeSymbols: Map = mapOf() ) : BaseActor( - prompt = "", - name = name, - model = model, - temperature = temperature, + prompt = "", + name = name, + model = model, + temperature = temperature, ) { - val interpreter: Interpreter - get() = interpreterClass.java.getConstructor(Map::class.java).newInstance(symbols + runtimeSymbols) - - data class CodeRequest( - val messages: List>, - val codePrefix: String = "", - val autoEvaluate: Boolean = false, - val fixIterations: Int = 1, - val fixRetries: Int = 1, - ) - - interface CodeResult { - enum class Status { - Coding, Correcting, Success, Failure + val interpreter: Interpreter + get() = interpreterClass.java.getConstructor(Map::class.java).newInstance(symbols + runtimeSymbols) + + data class CodeRequest( + val messages: List>, + val codePrefix: String = "", + val autoEvaluate: Boolean = false, + val fixIterations: Int = 1, + val fixRetries: Int = 1, + ) + + interface CodeResult { + enum class Status { + Coding, Correcting, Success, Failure + } + + val code: String + val status: Status + val result: ExecutionResult + val renderedResponse: String? } - val code: String - val status: Status - val result: ExecutionResult - val renderedResponse: String? - } - - data class ExecutionResult( - val resultValue: String, - val resultOutput: String - ) - - var evalFormat = true - override val prompt: String - get() { - val formatInstructions = if(evalFormat) """Code should be structured as appropriately parameterized function(s) + data class ExecutionResult( + val resultValue: String, + val resultOutput: String + ) + + var evalFormat = true + override val prompt: String + get() { + val formatInstructions = + if (evalFormat) """Code should be structured as appropriately parameterized function(s) |with the final line invoking the function with the appropriate request parameters.""" else "" - return if (symbols.isNotEmpty()) { - """ + return if (symbols.isNotEmpty()) { + """ |You are a coding assistant allows users actions to be enacted using $language and the script context. |Your role is to translate natural language instructions into code as well as interpret the results and converse with the user. |Use ``` code blocks labeled with $language where appropriate. (i.e. ```$language) @@ -80,7 +81,7 @@ open class CodingActor( | |${details ?: ""} |""".trimMargin().trim() - } else """ + } else """ |You are a coding assistant allows users actions to be enacted using $language and the script context. |Your role is to translate natural language instructions into code as well as interpret the results and converse with the user. |Use ``` code blocks labeled with $language where appropriate. (i.e. ```$language) @@ -89,399 +90,407 @@ open class CodingActor( | |${details ?: ""} |""".trimMargin().trim() - } + } - open val apiDescription: String - get() = this.symbols.map { (name, utilityObj) -> - """ + open val apiDescription: String + get() = this.symbols.map { (name, utilityObj) -> + """ |$name: | ${this.describer.describe(utilityObj.javaClass).indent(" ")} |""".trimMargin().trim() - }.joinToString("\n") - - - val language: String by lazy { interpreter.getLanguage() } - - override fun chatMessages(questions: CodeRequest): Array { - var chatMessages = arrayOf( - ChatMessage( - role = Role.system, - content = prompt.toContentList() - ), - ) + questions.messages.map { - ChatMessage( - role = it.second, - content = it.first.toContentList() - ) - } - if (questions.codePrefix.isNotBlank()) { - chatMessages = (chatMessages.dropLast(1) + listOf( - ChatMessage(Role.assistant, questions.codePrefix.toContentList()) - ) + chatMessages.last()).toTypedArray() + }.joinToString("\n") + + + val language: String by lazy { interpreter.getLanguage() } + + override fun chatMessages(questions: CodeRequest): Array { + var chatMessages = arrayOf( + ChatMessage( + role = Role.system, + content = prompt.toContentList() + ), + ) + questions.messages.map { + ChatMessage( + role = it.second, + content = it.first.toContentList() + ) + } + if (questions.codePrefix.isNotBlank()) { + chatMessages = (chatMessages.dropLast(1) + listOf( + ChatMessage(Role.assistant, questions.codePrefix.toContentList()) + ) + chatMessages.last()).toTypedArray() + } + return chatMessages + } - return chatMessages - - } - - override fun respond( - input: CodeRequest, - api: API, - vararg messages: ChatMessage, - ): CodeResult { - var result = CodeResultImpl( - *messages, - input = input, - api = (api as OpenAIClient) - ) - if (!input.autoEvaluate) return result - for (i in 0..input.fixIterations) try { - require(result.result.resultValue.length > -1) - return result - } catch (ex: Throwable) { - if (i == input.fixIterations) { - log.info( - "Failed to implement ${ - messages.map { it.content?.joinToString("\n") { it.text ?: "" } }.joinToString("\n") - }" + + override fun respond( + input: CodeRequest, + api: API, + vararg messages: ChatMessage, + ): CodeResult { + var result = CodeResultImpl( + *messages, + input = input, + api = (api as OpenAIClient) ) - throw ex - } - val respondWithCode = fixCommand(api, result.code, ex, *messages, model = model) - val blocks = extractTextBlocks(respondWithCode) - val renderedResponse = getRenderedResponse(blocks) - val codedInstruction = getCode(language, blocks) - log.debug("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin()) - log.debug("New Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin()) - result = CodeResultImpl( - *messages, - input = input, - api = api, - givenCode = codedInstruction, - givenResponse = renderedResponse - ) + if (!input.autoEvaluate) return result + for (i in 0..input.fixIterations) try { + require(result.result.resultValue.length > -1) + return result + } catch (ex: Throwable) { + if (i == input.fixIterations) { + log.info( + "Failed to implement ${ + messages.map { it.content?.joinToString("\n") { it.text ?: "" } }.joinToString("\n") + }" + ) + throw ex + } + val respondWithCode = fixCommand(api, result.code, ex, *messages, model = model) + val blocks = extractTextBlocks(respondWithCode) + val renderedResponse = getRenderedResponse(blocks) + val codedInstruction = getCode(language, blocks) + log.debug("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin()) + log.debug("New Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin()) + result = CodeResultImpl( + *messages, + input = input, + api = api, + givenCode = codedInstruction, + givenResponse = renderedResponse + ) + } + throw IllegalStateException() } - throw IllegalStateException() - } - - open fun execute(prefix: String, code: String): ExecutionResult { - //language=HTML - log.debug("Running $code") - OutputInterceptor.clearGlobalOutput() - val result = try { - interpreter.run((prefix + "\n" + code).sortCode()) - } catch (e: Exception) { - when { - e is FailedToImplementException -> throw e - e is ScriptException -> throw FailedToImplementException( - cause = e, - message = errorMessage(e, code), - language = language, - code = code, - prefix = prefix, - ) - e.cause is ScriptException -> throw FailedToImplementException( - cause = e, - message = errorMessage(e.cause!! as ScriptException, code), - language = language, - code = code, - prefix = prefix, - ) + open fun execute(prefix: String, code: String): ExecutionResult { + //language=HTML + log.debug("Running $code") + OutputInterceptor.clearGlobalOutput() + val result = try { + interpreter.run((prefix + "\n" + code).sortCode()) + } catch (e: Exception) { + when { + e is FailedToImplementException -> throw e + e is ScriptException -> throw FailedToImplementException( + cause = e, + message = errorMessage(e, code), + language = language, + code = code, + prefix = prefix, + ) - else -> throw e - } - } - log.debug("Result: $result") - //language=HTML - val executionResult = ExecutionResult(result.toString(), OutputInterceptor.getThreadOutput()) - OutputInterceptor.clearThreadOutput() - return executionResult - } - - inner class CodeResultImpl( - vararg val messages: ChatMessage, - private val input: CodeRequest, - private val api: OpenAIClient, - private val givenCode: String? = null, - private val givenResponse: String? = null, - ) : CodeResult { - private val implementation by lazy { - if (!givenCode.isNullOrBlank() && !givenResponse.isNullOrBlank()) (givenCode to givenResponse) else try { - implement(model) - } catch (ex: FailedToImplementException) { - if (fallbackModel != model) { - try { - implement(fallbackModel) - } catch (ex: FailedToImplementException) { - log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") - _status = CodeResult.Status.Failure - throw ex - } - } else { - log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") - _status = CodeResult.Status.Failure - throw ex + e.cause is ScriptException -> throw FailedToImplementException( + cause = e, + message = errorMessage(e.cause!! as ScriptException, code), + language = language, + code = code, + prefix = prefix, + ) + + else -> throw e + } } - } + log.debug("Result: $result") + //language=HTML + val executionResult = ExecutionResult(result.toString(), OutputInterceptor.getThreadOutput()) + OutputInterceptor.clearThreadOutput() + return executionResult } - private var _status = CodeResult.Status.Coding - - override val status get() = _status - - override val renderedResponse: String = givenResponse ?: implementation.second - override val code: String = givenCode ?: implementation.first - - private fun implement( - model: ChatModels, - ): Pair { - val request = ChatRequest(messages = ArrayList(this.messages.toList())) - for (codingAttempt in 0..input.fixRetries) { - try { - val codeBlocks = extractTextBlocks(chat(api, request, model)) - val renderedResponse = getRenderedResponse(codeBlocks) - val codedInstruction = getCode(language, codeBlocks) - log.debug("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin()) - log.debug("New Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin()) - var workingCode = codedInstruction - var workingRenderedResponse = renderedResponse - for (fixAttempt in 0..input.fixIterations) { - try { - val validate = interpreter.validate((input.codePrefix + "\n" + workingCode).sortCode()) - if (validate != null) throw validate - log.debug("Validation succeeded") - _status = CodeResult.Status.Success - return workingCode to workingRenderedResponse - } catch (ex: Throwable) { - if (fixAttempt == input.fixIterations) - throw if (ex is FailedToImplementException) ex else FailedToImplementException( - cause = ex, - message = """ + inner class CodeResultImpl( + vararg val messages: ChatMessage, + private val input: CodeRequest, + private val api: OpenAIClient, + private val givenCode: String? = null, + private val givenResponse: String? = null, + ) : CodeResult { + private val implementation by lazy { + if (!givenCode.isNullOrBlank() && !givenResponse.isNullOrBlank()) (givenCode to givenResponse) else try { + implement(model) + } catch (ex: FailedToImplementException) { + if (fallbackModel != model) { + try { + implement(fallbackModel) + } catch (ex: FailedToImplementException) { + log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") + _status = CodeResult.Status.Failure + throw ex + } + } else { + log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") + _status = CodeResult.Status.Failure + throw ex + } + } + } + + private var _status = CodeResult.Status.Coding + + override val status get() = _status + + override val renderedResponse: String = givenResponse ?: implementation.second + override val code: String = givenCode ?: implementation.first + + private fun implement( + model: ChatModels, + ): Pair { + val request = ChatRequest(messages = ArrayList(this.messages.toList())) + for (codingAttempt in 0..input.fixRetries) { + try { + val codeBlocks = extractTextBlocks(chat(api, request, model)) + val renderedResponse = getRenderedResponse(codeBlocks) + val codedInstruction = getCode(language, codeBlocks) + log.debug("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin()) + log.debug("New Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin()) + var workingCode = codedInstruction + var workingRenderedResponse = renderedResponse + for (fixAttempt in 0..input.fixIterations) { + try { + val validate = interpreter.validate((input.codePrefix + "\n" + workingCode).sortCode()) + if (validate != null) throw validate + log.debug("Validation succeeded") + _status = CodeResult.Status.Success + return workingCode to workingRenderedResponse + } catch (ex: Throwable) { + if (fixAttempt == input.fixIterations) + throw if (ex is FailedToImplementException) ex else FailedToImplementException( + cause = ex, + message = """ |**ERROR** | |```text |${ex.stackTraceToString()} |``` |""".trimMargin().trim(), - language = language, - code = workingCode, - prefix = input.codePrefix - ) - log.debug("Validation failed - ${ex.message}") - _status = CodeResult.Status.Correcting - val respondWithCode = fixCommand(api, workingCode, ex, *messages, model = model) - val codeBlocks = extractTextBlocks(respondWithCode) - workingRenderedResponse = getRenderedResponse(codeBlocks) - workingCode = getCode(language, codeBlocks) - log.debug("Response: \n\t${workingRenderedResponse.replace("\n", "\n\t", false)}".trimMargin()) - log.debug("New Code: \n\t${workingCode.replace("\n", "\n\t", false)}".trimMargin()) + language = language, + code = workingCode, + prefix = input.codePrefix + ) + log.debug("Validation failed - ${ex.message}") + _status = CodeResult.Status.Correcting + val respondWithCode = fixCommand(api, workingCode, ex, *messages, model = model) + val codeBlocks = extractTextBlocks(respondWithCode) + workingRenderedResponse = getRenderedResponse(codeBlocks) + workingCode = getCode(language, codeBlocks) + log.debug( + "Response: \n\t${ + workingRenderedResponse.replace( + "\n", + "\n\t", + false + ) + }".trimMargin() + ) + log.debug("New Code: \n\t${workingCode.replace("\n", "\n\t", false)}".trimMargin()) + } + } + } catch (ex: FailedToImplementException) { + if (codingAttempt == input.fixRetries) { + log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") + throw ex + } + log.debug("Retry failed to implement ${messages.map { it.content }.joinToString("\n")}") + _status = CodeResult.Status.Correcting + } } - } - } catch (ex: FailedToImplementException) { - if (codingAttempt == input.fixRetries) { - log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") - throw ex - } - log.debug("Retry failed to implement ${messages.map { it.content }.joinToString("\n")}") - _status = CodeResult.Status.Correcting + throw IllegalStateException() } - } - throw IllegalStateException() - } - private val executionResult by lazy { execute(input.codePrefix, code) } - - override val result get() = executionResult - } - - private fun fixCommand( - api: OpenAIClient, - previousCode: String, - error: Throwable, - vararg promptMessages: ChatMessage, - model: ChatModels - ): String = chat( - api = api, - request = ChatRequest( - messages = ArrayList( - promptMessages.toList() + listOf( - ChatMessage( - Role.assistant, - """ + private val executionResult by lazy { execute(input.codePrefix, code) } + + override val result get() = executionResult + } + + private fun fixCommand( + api: OpenAIClient, + previousCode: String, + error: Throwable, + vararg promptMessages: ChatMessage, + model: ChatModels + ): String = chat( + api = api, + request = ChatRequest( + messages = ArrayList( + promptMessages.toList() + listOf( + ChatMessage( + Role.assistant, + """ |```${language.lowercase()} - |${previousCode?.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} + |${previousCode.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} |``` |""".trimMargin().trim().toContentList() - ), - ChatMessage( - Role.system, - """ + ), + ChatMessage( + Role.system, + """ |The previous code failed with the following error: | |``` - |${error.message?.trim() ?: ""?.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} + |${error.message?.trim() ?: "".let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} |``` | |Correct the code and try again. |""".trimMargin().trim().toContentList() - ) - ) - ) - ), - model = model - ) + ) + ) + ) + ), + model = model + ) - private fun chat(api: OpenAIClient, request: ChatRequest, model: ChatModels) = - api.chat(request.copy(model = model.modelName, temperature = temperature), model) - .choices.first().message?.content.orEmpty().trim() + private fun chat(api: OpenAIClient, request: ChatRequest, model: ChatModels) = + api.chat(request.copy(model = model.modelName, temperature = temperature), model) + .choices.first().message?.content.orEmpty().trim() + + + override fun withModel(model: ChatModels): CodingActor = CodingActor( + interpreterClass = interpreterClass, + symbols = symbols, + describer = describer, + name = name, + details = details, + model = model, + fallbackModel = fallbackModel, + temperature = temperature, + runtimeSymbols = runtimeSymbols + ) + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(CodingActor::class.java) - override fun withModel(model: ChatModels): CodingActor = CodingActor( - interpreterClass = interpreterClass, - symbols = symbols, - describer = describer, - name = name, - details = details, - model = model, - fallbackModel = fallbackModel, - temperature = temperature, - runtimeSymbols = runtimeSymbols - ) + fun String.indent(indent: String = " ") = this.replace("\n", "\n$indent") - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(CodingActor::class.java) + fun extractTextBlocks(response: String): List> { + val codeBlockRegex = Regex("(?s)```(.*?)\\n(.*?)```") + val languageRegex = Regex("([a-zA-Z0-9-_]+)") - fun String.indent(indent: String = " ") = this.replace("\n", "\n$indent") + val result = mutableListOf>() + var startIndex = 0 - fun extractTextBlocks(response: String): List> { - val codeBlockRegex = Regex("(?s)```(.*?)\\n(.*?)```") - val languageRegex = Regex("([a-zA-Z0-9-_]+)") + val matches = codeBlockRegex.findAll(response) + if (matches.count() == 0) return listOf(Pair("text", response)) + for (match in matches) { + // Add non-code block before the current match as "text" + if (startIndex < match.range.first) { + result.add(Pair("text", response.substring(startIndex, match.range.first))) + } - val result = mutableListOf>() - var startIndex = 0 + // Extract language and code + val languageMatch = languageRegex.find(match.groupValues[1]) + val language = languageMatch?.groupValues?.get(0) ?: "code" + val code = match.groupValues[2] - val matches = codeBlockRegex.findAll(response) - if (matches.count() == 0) return listOf(Pair("text", response)) - for (match in matches) { - // Add non-code block before the current match as "text" - if (startIndex < match.range.first) { - result.add(Pair("text", response.substring(startIndex, match.range.first))) - } + // Add code block to the result + result.add(Pair(language, code)) - // Extract language and code - val languageMatch = languageRegex.find(match.groupValues[1]) - val language = languageMatch?.groupValues?.get(0) ?: "code" - val code = match.groupValues[2] + // Update the start index + startIndex = match.range.last + 1 + } - // Add code block to the result - result.add(Pair(language, code)) + // Add any remaining non-code text after the last code block as "text" + if (startIndex < response.length) { + result.add(Pair("text", response.substring(startIndex))) + } - // Update the start index - startIndex = match.range.last + 1 - } + return result + } - // Add any remaining non-code text after the last code block as "text" - if (startIndex < response.length) { - result.add(Pair("text", response.substring(startIndex))) - } + fun getRenderedResponse(respondWithCode: List>, defaultLanguage: String = "") = + respondWithCode.joinToString("\n") { + when (it.first) { + "code" -> "```$defaultLanguage\n${it.second.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```" + "text" -> it.second.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }.toString() + else -> "```${it.first}\n${it.second.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```" + } + } - return result - } + fun getCode(language: String, textSegments: List>): String { + if (textSegments.size == 1) return textSegments.joinToString("\n") { it.second } + return textSegments.joinToString("\n") { + if (it.first.lowercase() == "code" || it.first.lowercase() == language.lowercase()) { + it.second.trimMargin().trim() + } else { + "" + } + } + } - fun getRenderedResponse(respondWithCode: List>, defaultLanguage: String = "") = - respondWithCode.joinToString("\n") { - when (it.first) { - "code" -> "```$defaultLanguage\n${it.second?.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```" - "text" -> it.second?.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }.toString() - else -> "```${it.first}\n${it.second?.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```" + fun String.sortCode(bodyWrapper: (String) -> String = { it }): String { + val (imports, otherCode) = this.split("\n").partition { it.trim().startsWith("import ") } + return imports.distinct().sorted().joinToString("\n") + "\n\n" + bodyWrapper(otherCode.joinToString("\n")) } - } - - fun getCode(language: String, textSegments: List>): String { - if (textSegments.size == 1) return textSegments.joinToString("\n") { it.second } - return textSegments.joinToString("\n") { - if (it.first.lowercase() == "code" || it.first.lowercase() == language.lowercase()) { - it.second.trimMargin().trim() - } else { - "" + + fun String.camelCase(locale: Locale = Locale.getDefault()): String { + val words = fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() } + return words.first().lowercase(locale) + words.drop(1).joinToString("") { + it.replaceFirstChar { c -> + when { + c.isLowerCase() -> c.titlecase(locale) + else -> c.toString() + } + } + } } - } - } - fun String.sortCode(bodyWrapper: (String) -> String = { it }): String { - val (imports, otherCode) = this.split("\n").partition { it.trim().startsWith("import ") } - return imports.distinct().sorted().joinToString("\n") + "\n\n" + bodyWrapper(otherCode.joinToString("\n")) - } + fun String.pascalCase(locale: Locale = Locale.getDefault()): String = + fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }.joinToString("") { + it.replaceFirstChar { c -> + when { + c.isLowerCase() -> c.titlecase(locale) + else -> c.toString() + } + } + } - fun String.camelCase(locale: Locale = Locale.getDefault()): String { - val words = fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() } - return words.first().lowercase(locale) + words.drop(1).joinToString("") { - it.replaceFirstChar { c -> - when { - c.isLowerCase() -> c.titlecase(locale) - else -> c.toString() - } + // Detect changes in the case of the first letter and prepend a space + private fun String.fromPascalCase(): String = buildString { + var lastChar = ' ' + for (c in this@fromPascalCase) { + if (c.isUpperCase() && lastChar.isLowerCase()) append(' ') + append(c) + lastChar = c + } } - } - } - fun String.pascalCase(locale: Locale = Locale.getDefault()): String = - fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }.joinToString("") { - it.replaceFirstChar { c -> - when { - c.isLowerCase() -> c.titlecase(locale) - else -> c.toString() - } + fun String.upperSnakeCase(locale: Locale = Locale.getDefault()): String = + fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }.joinToString("_") { + it.replaceFirstChar { c -> + when { + c.isLowerCase() -> c.titlecase(locale) + else -> c.toString() + } + } + }.uppercase(locale) + + fun String.imports(): List { + return this.split("\n").filter { it.trim().startsWith("import ") }.distinct().sorted() } - } - - // Detect changes in the case of the first letter and prepend a space - private fun String.fromPascalCase(): String = buildString { - var lastChar = ' ' - for (c in this@fromPascalCase) { - if (c.isUpperCase() && lastChar.isLowerCase()) append(' ') - append(c) - lastChar = c - } - } - fun String.upperSnakeCase(locale: Locale = Locale.getDefault()): String = - fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }.joinToString("_") { - it.replaceFirstChar { c -> - when { - c.isLowerCase() -> c.titlecase(locale) - else -> c.toString() - } + fun String.stripImports(): String { + return this.split("\n").filter { !it.trim().startsWith("import ") }.joinToString("\n") } - }.uppercase(locale) - fun String.imports(): List { - return this.split("\n").filter { it.trim().startsWith("import ") }.distinct().sorted() - } - - fun String.stripImports(): String { - return this.split("\n").filter { !it.trim().startsWith("import ") }.joinToString("\n") - } - - fun errorMessage(ex: ScriptException, code: String) = try { - """ + fun errorMessage(ex: ScriptException, code: String) = try { + """ |```text |${ex.message ?: ""} at line ${ex.lineNumber} column ${ex.columnNumber} | ${if (ex.lineNumber > 0) code.split("\n")[ex.lineNumber - 1] else ""} | ${if (ex.columnNumber > 0) " ".repeat(ex.columnNumber - 1) + "^" else ""} |``` """.trimMargin().trim() - } catch (_: Exception) { - ex.message ?: "" + } catch (_: Exception) { + ex.message ?: "" + } } - } - - class FailedToImplementException( - cause: Throwable? = null, - message: String = "Failed to implement", - val language: String? = null, - val code: String? = null, - val prefix: String? = null, - ) : RuntimeException(message, cause) + + class FailedToImplementException( + cause: Throwable? = null, + message: String = "Failed to implement", + val language: String? = null, + val code: String? = null, + val prefix: String? = null, + ) : RuntimeException(message, cause) } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageActor.kt index e0174f1b..43292f41 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageActor.kt @@ -63,7 +63,7 @@ open class ImageActor( return ImageIO.read(URL(url)) } - override fun respond(input: List, api: API, vararg messages: ChatMessage): ImageResponse { + override fun respond(input: List, api: API, vararg messages: ChatMessage): ImageResponse { var text = response(*messages, api = api).choices.first().message?.content ?: throw RuntimeException("No response") while (imageModel.maxPrompt <= text.length) { diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/MultiExeption.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/MultiExeption.kt index 44e3d02c..02c61be5 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/MultiExeption.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/MultiExeption.kt @@ -3,6 +3,5 @@ package com.simiacryptus.skyenet.core.actors import com.simiacryptus.skyenet.core.actors.CodingActor.Companion.indent class MultiExeption(exceptions: Collection) : RuntimeException( - exceptions.joinToString("\n\n") { "```text\n${/*escapeHtml4*/(it.stackTraceToString())/*.indent(" ")*/}\n```" } -) { -} + exceptions.joinToString("\n\n") { "```text\n${/*escapeHtml4*/(it.stackTraceToString())/*.indent(" ")*/}\n```" } +) diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedActor.kt index 6e484465..7b106808 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedActor.kt @@ -12,55 +12,56 @@ import com.simiacryptus.skyenet.core.actors.CodingActor.Companion.indent import java.util.function.Function open class ParsedActor( - var resultClass: Class? = null, - val exampleInstance: T? = resultClass?.getConstructor()?.newInstance(), - prompt: String = "", - name: String? = resultClass?.simpleName, - model: ChatModels = ChatModels.GPT4Turbo, - temperature: Double = 0.3, - val parsingModel: ChatModels = ChatModels.GPT35Turbo, - val deserializerRetries: Int = 2, - open val describer: TypeDescriber = object : AbbrevWhitelistYamlDescriber( - "com.simiacryptus", "com.github.simiacryptus" - ) { - override val includeMethods: Boolean get() = false - }, + var resultClass: Class? = null, + val exampleInstance: T? = resultClass?.getConstructor()?.newInstance(), + prompt: String = "", + name: String? = resultClass?.simpleName, + model: ChatModels = ChatModels.GPT4Turbo, + temperature: Double = 0.3, + val parsingModel: ChatModels = ChatModels.GPT35Turbo, + val deserializerRetries: Int = 2, + open val describer: TypeDescriber = object : AbbrevWhitelistYamlDescriber( + "com.simiacryptus", "com.github.simiacryptus" + ) { + override val includeMethods: Boolean get() = false + }, ) : BaseActor, ParsedResponse>( - prompt = prompt, - name = name, - model = model, - temperature = temperature, + prompt = prompt, + name = name, + model = model, + temperature = temperature, ) { - init { - requireNotNull(resultClass) { - "Result class is required" + init { + requireNotNull(resultClass) { + "Result class is required" + } } - } - override fun chatMessages(questions: List) = arrayOf( - ApiModel.ChatMessage( - role = ApiModel.Role.system, - content = prompt.toContentList() - ), - ) + questions.map { - ApiModel.ChatMessage( - role = ApiModel.Role.user, - content = it.toContentList() - ) - } - private inner class ParsedResponseImpl(vararg messages: ApiModel.ChatMessage, api: API) : - ParsedResponse(resultClass!!) { - override val text = - response(*messages, api = api).choices.first().message?.content ?: throw RuntimeException("No response") - private val _obj: T by lazy { getParser(api).apply(text) } - override val obj get() = _obj - } + override fun chatMessages(questions: List) = arrayOf( + ApiModel.ChatMessage( + role = ApiModel.Role.system, + content = prompt.toContentList() + ), + ) + questions.map { + ApiModel.ChatMessage( + role = ApiModel.Role.user, + content = it.toContentList() + ) + } - fun getParser(api: API) = Function { input -> - describer.coverMethods = false - val describe = resultClass?.let { describer.describe(it) } ?: "" - val exceptions = mutableListOf() - val prompt = """ + private inner class ParsedResponseImpl(vararg messages: ApiModel.ChatMessage, api: API) : + ParsedResponse(resultClass!!) { + override val text = + response(*messages, api = api).choices.first().message?.content ?: throw RuntimeException("No response") + private val _obj: T by lazy { getParser(api).apply(text) } + override val obj get() = _obj + } + + fun getParser(api: API) = Function { input -> + describer.coverMethods = false + val describe = resultClass?.let { describer.describe(it) } ?: "" + val exceptions = mutableListOf() + val prompt = """ |Parse the user's message into a json object described by: | |```yaml @@ -73,62 +74,67 @@ open class ParsedActor( |``` | """.trimMargin() - for (i in 0 until deserializerRetries) { - try { - val content = (api as OpenAIClient).chat( - ApiModel.ChatRequest( - messages = listOf( - ApiModel.ChatMessage(role = ApiModel.Role.system, content = prompt.toContentList()), - ApiModel.ChatMessage(role = ApiModel.Role.user, content = input.toContentList()), - ), - temperature = temperature, - model = model.modelName, - ), - model = model, - ).choices.first().message?.content - var contentUnwrapped = content?.trim() ?: throw RuntimeException("No response") + for (i in 0 until deserializerRetries) { + try { + val content = (api as OpenAIClient).chat( + ApiModel.ChatRequest( + messages = listOf( + ApiModel.ChatMessage(role = ApiModel.Role.system, content = prompt.toContentList()), + ApiModel.ChatMessage(role = ApiModel.Role.user, content = input.toContentList()), + ), + temperature = temperature, + model = model.modelName, + ), + model = model, + ).choices.first().message?.content + var contentUnwrapped = content?.trim() ?: throw RuntimeException("No response") - // If Plaintext is found before the { or ```, strip it - if (!contentUnwrapped.startsWith("{") && !contentUnwrapped.startsWith("```")) { - val start = contentUnwrapped.indexOf("{").coerceAtMost(contentUnwrapped.indexOf("```")) - val end = contentUnwrapped.lastIndexOf("}").coerceAtLeast(contentUnwrapped.lastIndexOf("```") + 2) + 1 - if (start < end && start >= 0) contentUnwrapped = contentUnwrapped.substring(start, end) - } + // If Plaintext is found before the { or ```, strip it + if (!contentUnwrapped.startsWith("{") && !contentUnwrapped.startsWith("```")) { + val start = contentUnwrapped.indexOf("{").coerceAtMost(contentUnwrapped.indexOf("```")) + val end = + contentUnwrapped.lastIndexOf("}").coerceAtLeast(contentUnwrapped.lastIndexOf("```") + 2) + 1 + if (start < end && start >= 0) contentUnwrapped = contentUnwrapped.substring(start, end) + } - // if input is wrapped in a ```json block, remove the block - if (contentUnwrapped.startsWith("```json") && contentUnwrapped.endsWith("```")) { - contentUnwrapped = contentUnwrapped.substring(7, contentUnwrapped.length - 3) - } + // if input is wrapped in a ```json block, remove the block + if (contentUnwrapped.startsWith("```json") && contentUnwrapped.endsWith("```")) { + contentUnwrapped = contentUnwrapped.substring(7, contentUnwrapped.length - 3) + } - contentUnwrapped.let { return@Function JsonUtil.fromJson(it, resultClass - ?: throw RuntimeException("Result class undefined")) } - } catch (e: Exception) { - log.info("Failed to parse response", e) - exceptions.add(e) - } + contentUnwrapped.let { + return@Function JsonUtil.fromJson( + it, resultClass + ?: throw RuntimeException("Result class undefined") + ) + } + } catch (e: Exception) { + log.info("Failed to parse response", e) + exceptions.add(e) + } + } + throw MultiExeption(exceptions) } - throw MultiExeption(exceptions) - } - override fun respond(input: List, api: API, vararg messages: ApiModel.ChatMessage): ParsedResponse { - try { - return ParsedResponseImpl(*messages, api = api) - } catch (e: Exception) { - log.info("Failed to parse response", e) - throw e + override fun respond(input: List, api: API, vararg messages: ApiModel.ChatMessage): ParsedResponse { + try { + return ParsedResponseImpl(*messages, api = api) + } catch (e: Exception) { + log.info("Failed to parse response", e) + throw e + } } - } - override fun withModel(model: ChatModels): ParsedActor = ParsedActor( - resultClass = resultClass, - prompt = prompt, - name = name, - model = model, - temperature = temperature, - parsingModel = parsingModel, - ) + override fun withModel(model: ChatModels): ParsedActor = ParsedActor( + resultClass = resultClass, + prompt = prompt, + name = name, + model = model, + temperature = temperature, + parsingModel = parsingModel, + ) - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(ParsedActor::class.java) - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(ParsedActor::class.java) + } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SimpleActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SimpleActor.kt index 54074902..d0344fe2 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SimpleActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SimpleActor.kt @@ -11,14 +11,16 @@ open class SimpleActor( name: String? = null, model: ChatModels, temperature: Double = 0.3, -) : BaseActor,String>( +) : BaseActor, String>( prompt = prompt, name = name, model = model, temperature = temperature, ) { - override fun respond(input: List, api: API, vararg messages: ApiModel.ChatMessage): String = response(*messages, api = api).choices.first().message?.content ?: throw RuntimeException("No response") + override fun respond(input: List, api: API, vararg messages: ApiModel.ChatMessage): String = + response(*messages, api = api).choices.first().message?.content ?: throw RuntimeException("No response") + override fun chatMessages(questions: List) = arrayOf( ApiModel.ChatMessage( role = ApiModel.Role.system, diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/TextToSpeechActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/TextToSpeechActor.kt index 94a6738b..f8504482 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/TextToSpeechActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/TextToSpeechActor.kt @@ -10,54 +10,54 @@ import com.simiacryptus.jopenai.models.ChatModels import com.simiacryptus.jopenai.util.ClientUtil.toContentList open class TextToSpeechActor( - name: String? = null, - val audioModel: AudioModels = AudioModels.TTS_HD, - val voice: String = "alloy", - val speed: Double = 1.0, - val models: ChatModels, + name: String? = null, + val audioModel: AudioModels = AudioModels.TTS_HD, + val voice: String = "alloy", + val speed: Double = 1.0, + val models: ChatModels, ) : BaseActor, SpeechResponse>( - prompt = "", - name = name, - model = models, + prompt = "", + name = name, + model = models, ) { - override fun chatMessages(questions: List) = questions.map { - ChatMessage( - role = ApiModel.Role.user, - content = it.toContentList() - ) - }.toTypedArray() - - inner class SpeechResponseImpl( - val text: String, - private val api: API - ) : SpeechResponse { - private val _image: ByteArray? by lazy { render(text, api) } - override val mp3data: ByteArray? get() = _image - } - - open fun render( - text: String, - api: API, - ): ByteArray = (api as OpenAIClient).createSpeech( - ApiModel.SpeechRequest( - input = text, - model = audioModel.modelName, - voice = voice, - speed = speed, - ) - ) ?: throw RuntimeException("No response") - - override fun respond(input: List, api: API, vararg messages: ChatMessage) = - SpeechResponseImpl( - messages.joinToString("\n") { it.content?.joinToString("\n") { it.text ?: "" } ?: "" }, - api = api - ) - - - override fun withModel(model: ChatModels) = this + override fun chatMessages(questions: List) = questions.map { + ChatMessage( + role = ApiModel.Role.user, + content = it.toContentList() + ) + }.toTypedArray() + + inner class SpeechResponseImpl( + val text: String, + private val api: API + ) : SpeechResponse { + private val _image: ByteArray? by lazy { render(text, api) } + override val mp3data: ByteArray? get() = _image + } + + open fun render( + text: String, + api: API, + ): ByteArray = (api as OpenAIClient).createSpeech( + ApiModel.SpeechRequest( + input = text, + model = audioModel.modelName, + voice = voice, + speed = speed, + ) + ) ?: throw RuntimeException("No response") + + override fun respond(input: List, api: API, vararg messages: ChatMessage) = + SpeechResponseImpl( + messages.joinToString("\n") { it.content?.joinToString("\n") { it.text ?: "" } ?: "" }, + api = api + ) + + + override fun withModel(model: ChatModels) = this } interface SpeechResponse { - val mp3data: ByteArray? + val mp3data: ByteArray? } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/opt/ActorOptimization.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/opt/ActorOptimization.kt index 1e6936f8..d9b7848d 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/opt/ActorOptimization.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/opt/ActorOptimization.kt @@ -34,7 +34,7 @@ open class ActorOptimization( val retries: Int = 3 ) - open fun ,T:Any> runGeneticGenerations( + open fun , T : Any> runGeneticGenerations( prompts: List, testCases: List, actorFactory: (String) -> BaseActor, @@ -48,12 +48,14 @@ open class ActorOptimization( val scores = topPrompts.map { prompt -> prompt to testCases.map { testCase -> val actor = actorFactory(prompt) - val answer = actor.respond(input = listOf(actor.prompt) as I, api = api, *(listOf( - ApiModel.ChatMessage( - role = ApiModel.Role.system, - content = actor.prompt.toContentList() - ), - ) + testCase.userMessages).toTypedArray()) + val answer = actor.respond( + input = listOf(actor.prompt) as I, api = api, *(listOf( + ApiModel.ChatMessage( + role = ApiModel.Role.system, + content = actor.prompt.toContentList() + ), + ) + testCase.userMessages).toTypedArray() + ) testCase.expectations.map { it.score(api, resultMapper(answer)) }.average() }.average() } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/opt/Expectation.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/opt/Expectation.kt index 655f039d..816e2684 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/opt/Expectation.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/opt/Expectation.kt @@ -10,7 +10,8 @@ abstract class Expectation { private val log = LoggerFactory.getLogger(Expectation::class.java) } - open class VectorMatch(private val example: String, private val metric: DistanceType = DistanceType.Cosine) : Expectation() { + open class VectorMatch(private val example: String, private val metric: DistanceType = DistanceType.Cosine) : + Expectation() { override fun matches(api: OpenAIClient, response: String): Boolean { return true } @@ -43,6 +44,7 @@ abstract class Expectation { if (!critical) return true return _matches(response) } + override fun score(api: OpenAIClient, response: String): Double { return if (_matches(response)) 1.0 else 0.0 } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/CodingActorInterceptor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/CodingActorInterceptor.kt index 98a9bd8c..9e958d5e 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/CodingActorInterceptor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/CodingActorInterceptor.kt @@ -32,11 +32,11 @@ class CodingActorInterceptor( } - override fun respond( input: CodeRequest, api: API, vararg messages: ChatMessage) = inner.CodeResultImpl( - messages = messages, - input = input, - api = api as com.simiacryptus.jopenai.OpenAIClient, - givenCode = super.respond(input = input, api = api, *messages, ).code, + override fun respond(input: CodeRequest, api: API, vararg messages: ChatMessage) = inner.CodeResultImpl( + messages = messages, + input = input, + api = api as com.simiacryptus.jopenai.OpenAIClient, + givenCode = super.respond(input = input, api = api, *messages).code, ) override fun execute(prefix: String, code: String) = diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/ImageActorInterceptor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/ImageActorInterceptor.kt index 8ffe6a22..1a80a4ec 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/ImageActorInterceptor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/ImageActorInterceptor.kt @@ -23,10 +23,12 @@ class ImageActorInterceptor( vararg input: com.simiacryptus.jopenai.ApiModel.ChatMessage, model: OpenAIModel, api: API - ) = functionInterceptor.wrap(input.toList().toTypedArray(), model) { - messages: Array, + ) = functionInterceptor.wrap( + input.toList().toTypedArray(), + model + ) { messages: Array, model: OpenAIModel -> - inner.response(*messages, model = model, api = api) + inner.response(*messages, model = model, api = api) } override fun render(text: String, api: API): BufferedImage = functionInterceptor.wrap(text) { diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/ParsedActorInterceptor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/ParsedActorInterceptor.kt index 994c36bc..385d4f09 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/ParsedActorInterceptor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/ParsedActorInterceptor.kt @@ -8,35 +8,39 @@ import com.simiacryptus.skyenet.core.util.FunctionWrapper import java.util.function.Function class ParsedActorInterceptor( - val inner: ParsedActor<*>, - private val functionInterceptor: FunctionWrapper, + val inner: ParsedActor<*>, + private val functionInterceptor: FunctionWrapper, ) : ParsedActor( - resultClass = inner.resultClass as Class, - exampleInstance = inner.exampleInstance, - prompt = inner.prompt, - name = inner.name, - model = inner.model, - temperature = inner.temperature, - parsingModel = inner.parsingModel, + resultClass = inner.resultClass as Class, + exampleInstance = inner.exampleInstance, + prompt = inner.prompt, + name = inner.name, + model = inner.model, + temperature = inner.temperature, + parsingModel = inner.parsingModel, ) { - override fun respond(input: List, api: API, vararg messages: com.simiacryptus.jopenai.ApiModel.ChatMessage, ) = - object : ParsedResponse(resultClass!!) { - private val parser: Function = getParser(api) + override fun respond( + input: List, + api: API, + vararg messages: com.simiacryptus.jopenai.ApiModel.ChatMessage, + ) = + object : ParsedResponse(resultClass!!) { + private val parser: Function = getParser(api) - private val _obj: Any by lazy { parse() } + private val _obj: Any by lazy { parse() } - private fun parse(): Any = functionInterceptor.inner.intercept(text, resultClass!!) { parser.apply(text) } - override val text get() = super@ParsedActorInterceptor.respond(input = input, api = api, *messages, ).text - override val obj get() = _obj - } + private fun parse(): Any = functionInterceptor.inner.intercept(text, resultClass!!) { parser.apply(text) } + override val text get() = super@ParsedActorInterceptor.respond(input = input, api = api, *messages).text + override val obj get() = _obj + } override fun response( vararg input: com.simiacryptus.jopenai.ApiModel.ChatMessage, model: OpenAIModel, api: API - ) = functionInterceptor.wrap(input.toList().toTypedArray(), model) { - messages, model -> inner.response(*messages, model = model, api = api) + ) = functionInterceptor.wrap(input.toList().toTypedArray(), model) { messages, model -> + inner.response(*messages, model = model, api = api) } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/SimpleActorInterceptor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/SimpleActorInterceptor.kt index 57d786a3..d0f3dd95 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/SimpleActorInterceptor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/SimpleActorInterceptor.kt @@ -19,10 +19,12 @@ class SimpleActorInterceptor( vararg input: com.simiacryptus.jopenai.ApiModel.ChatMessage, model: OpenAIModel, api: API - ) = functionInterceptor.wrap(input.toList().toTypedArray(), model) { - messages: Array, + ) = functionInterceptor.wrap( + input.toList().toTypedArray(), + model + ) { messages: Array, model: OpenAIModel -> - inner.response(*messages, model = model, api = api) + inner.response(*messages, model = model, api = api) } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/TextToSpeechActorInterceptor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/TextToSpeechActorInterceptor.kt index e405f939..3a783312 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/TextToSpeechActorInterceptor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/record/TextToSpeechActorInterceptor.kt @@ -10,23 +10,25 @@ class TextToSpeechActorInterceptor( val inner: TextToSpeechActor, private val functionInterceptor: FunctionWrapper, ) : TextToSpeechActor( - name = inner.name, - audioModel = inner.audioModel, - "alloy", - 1.0, - ChatModels.GPT35Turbo, + name = inner.name, + audioModel = inner.audioModel, + "alloy", + 1.0, + ChatModels.GPT35Turbo, ) { override fun response( vararg input: com.simiacryptus.jopenai.ApiModel.ChatMessage, model: OpenAIModel, api: API - ) = functionInterceptor.wrap(input.toList().toTypedArray(), model) { - messages: Array, + ) = functionInterceptor.wrap( + input.toList().toTypedArray(), + model + ) { messages: Array, model: OpenAIModel -> - inner.response(*messages, model = model, api = api) + inner.response(*messages, model = model, api = api) } - override fun render(text: String, api: API) : ByteArray = functionInterceptor.wrap(text) { + override fun render(text: String, api: API): ByteArray = functionInterceptor.wrap(text) { inner.render(it, api = api) } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/ActorTestBase.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/ActorTestBase.kt index e5d6bd83..cc204c44 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/ActorTestBase.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/ActorTestBase.kt @@ -9,7 +9,7 @@ import com.simiacryptus.skyenet.core.actors.opt.ActorOptimization import org.slf4j.LoggerFactory import org.slf4j.event.Level -abstract class ActorTestBase { +abstract class ActorTestBase { open val api = OpenAIClient(logLevel = Level.DEBUG) @@ -59,7 +59,7 @@ abstract class ActorTestBase { open fun answer(messages: Array): R = actor.respond( input = (messages.map { it.content?.first()?.text }) as I, - api=api, + api = api, *messages ) diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/CodingActorTestBase.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/CodingActorTestBase.kt index 61df37a4..6e311e7f 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/CodingActorTestBase.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/CodingActorTestBase.kt @@ -1,10 +1,10 @@ package com.simiacryptus.skyenet.core.actors.test import com.simiacryptus.jopenai.models.ChatModels -import com.simiacryptus.skyenet.interpreter.Interpreter import com.simiacryptus.skyenet.core.actors.BaseActor import com.simiacryptus.skyenet.core.actors.CodingActor import com.simiacryptus.skyenet.core.actors.CodingActor.CodeResult +import com.simiacryptus.skyenet.interpreter.Interpreter import kotlin.reflect.KClass abstract class CodingActorTestBase : ActorTestBase() { @@ -15,7 +15,9 @@ abstract class CodingActorTestBase : ActorTestBase): String = (actor as CodingActor).details!! + override fun getPrompt(actor: BaseActor): String = + (actor as CodingActor).details!! + override fun resultMapper(result: CodeResult): String = result.code } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/ImageActorTestBase.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/ImageActorTestBase.kt index 90850bb7..b310cf64 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/ImageActorTestBase.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/test/ImageActorTestBase.kt @@ -4,7 +4,7 @@ import com.simiacryptus.jopenai.models.ChatModels import com.simiacryptus.skyenet.core.actors.ImageActor import com.simiacryptus.skyenet.core.actors.ImageResponse -abstract class ImageActorTestBase() : ActorTestBase,ImageResponse>() { +abstract class ImageActorTestBase : ActorTestBase, ImageResponse>() { override fun actorFactory(prompt: String) = ImageActor( prompt = prompt, textModel = ChatModels.GPT35Turbo diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ApplicationServices.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ApplicationServices.kt index 434eb85f..997a04f9 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ApplicationServices.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ApplicationServices.kt @@ -17,314 +17,314 @@ import java.util.concurrent.atomic.AtomicInteger import kotlin.random.Random object ApplicationServices { - var isLocked: Boolean = false - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var authorizationManager: AuthorizationInterface = AuthorizationManager() - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var userSettingsManager: UserSettingsInterface = UserSettingsManager() - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var authenticationManager: AuthenticationInterface = AuthenticationManager() - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var dataStorageFactory: (File) -> StorageInterface = { DataStorage(it) } - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var dataStorageRoot: File = File(System.getProperty("user.home"), ".skyenet") - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var clientManager: ClientManager = ClientManager() - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - - var cloud: CloudPlatformInterface? = AwsPlatform.get() - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - - - var seleniumFactory: ((ThreadPoolExecutor, Array?) -> Selenium)? = null - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var usageManager: UsageInterface = UsageManager(File(dataStorageRoot, ".skyenet/usage")) - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } + var isLocked: Boolean = false + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var authorizationManager: AuthorizationInterface = AuthorizationManager() + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var userSettingsManager: UserSettingsInterface = UserSettingsManager() + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var authenticationManager: AuthenticationInterface = AuthenticationManager() + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var dataStorageFactory: (File) -> StorageInterface = { DataStorage(it) } + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var dataStorageRoot: File = File(System.getProperty("user.home"), ".skyenet") + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var clientManager: ClientManager = ClientManager() + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + + var cloud: CloudPlatformInterface? = AwsPlatform.get() + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + + + var seleniumFactory: ((ThreadPoolExecutor, Array?) -> Selenium)? = null + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var usageManager: UsageInterface = UsageManager(File(dataStorageRoot, ".skyenet/usage")) + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } } interface AuthenticationInterface { - fun getUser(accessToken: String?): User? + fun getUser(accessToken: String?): User? - fun putUser(accessToken: String, user: User): User - fun logout(accessToken: String, user: User) + fun putUser(accessToken: String, user: User): User + fun logout(accessToken: String, user: User) - companion object { - const val AUTH_COOKIE = "sessionId" - } + companion object { + const val AUTH_COOKIE = "sessionId" + } } interface AuthorizationInterface { - enum class OperationType { - Read, - Write, - Public, - Share, - Execute, - Delete, - Admin, - GlobalKey, - } - - fun isAuthorized( - applicationClass: Class<*>?, - user: User?, - operationType: OperationType, - ): Boolean + enum class OperationType { + Read, + Write, + Public, + Share, + Execute, + Delete, + Admin, + GlobalKey, + } + + fun isAuthorized( + applicationClass: Class<*>?, + user: User?, + operationType: OperationType, + ): Boolean } interface StorageInterface { - fun getJson( - user: User?, - session: Session, - filename: String, - clazz: Class - ): T? - - fun getMessages( - user: User?, - session: Session - ): LinkedHashMap - - fun getSessionDir( - user: User?, - session: Session - ): File - - fun getSessionName( - user: User?, - session: Session - ): String - - fun getSessionTime( - user: User?, - session: Session - ): Date? - - fun listSessions( - user: User? - ): List - - fun setJson( - user: User?, - session: Session, - filename: String, - settings: T - ): T - - fun updateMessage( - user: User?, - session: Session, - messageId: String, - value: String - ) - - fun listSessions(dir: File): List - fun userRoot(user: User?): File - fun deleteSession(user: User?, session: Session) - fun getMessageIds( - user: User?, - session: Session - ): List - - fun setMessageIds( - user: User?, - session: Session, - ids: List - ) - - companion object { - - fun validateSessionId( - session: Session - ) { - if (!session.sessionId.matches("""([GU]-)?\d{8}-[\w+-.]{4}""".toRegex())) { - throw IllegalArgumentException("Invalid session ID: $session") - } - } - - fun newGlobalID(): Session { - val yyyyMMdd = java.time.LocalDate.now().toString().replace("-", "") - //log.debug("New ID: $yyyyMMdd-$uuid") - return Session("G-$yyyyMMdd-${id2()}") - } + fun getJson( + user: User?, + session: Session, + filename: String, + clazz: Class + ): T? + + fun getMessages( + user: User?, + session: Session + ): LinkedHashMap + + fun getSessionDir( + user: User?, + session: Session + ): File + + fun getSessionName( + user: User?, + session: Session + ): String + + fun getSessionTime( + user: User?, + session: Session + ): Date? + + fun listSessions( + user: User? + ): List + + fun setJson( + user: User?, + session: Session, + filename: String, + settings: T + ): T + + fun updateMessage( + user: User?, + session: Session, + messageId: String, + value: String + ) - fun long64() = Base64.getEncoder().encodeToString(ByteBuffer.allocate(8).putLong(Random.nextLong()).array()) - .toString().replace("=", "").replace("/", ".").replace("+", "-") + fun listSessions(dir: File): List + fun userRoot(user: User?): File + fun deleteSession(user: User?, session: Session) + fun getMessageIds( + user: User?, + session: Session + ): List + + fun setMessageIds( + user: User?, + session: Session, + ids: List + ) - fun newUserID(): Session { - val yyyyMMdd = java.time.LocalDate.now().toString().replace("-", "") - //log.debug("New ID: $yyyyMMdd-$uuid") - return Session("U-$yyyyMMdd-${id2()}") - } + companion object { + + fun validateSessionId( + session: Session + ) { + if (!session.sessionId.matches("""([GU]-)?\d{8}-[\w+-.]{4}""".toRegex())) { + throw IllegalArgumentException("Invalid session ID: $session") + } + } + + fun newGlobalID(): Session { + val yyyyMMdd = java.time.LocalDate.now().toString().replace("-", "") + //log.debug("New ID: $yyyyMMdd-$uuid") + return Session("G-$yyyyMMdd-${id2()}") + } + + fun long64() = Base64.getEncoder().encodeToString(ByteBuffer.allocate(8).putLong(Random.nextLong()).array()) + .toString().replace("=", "").replace("/", ".").replace("+", "-") + + fun newUserID(): Session { + val yyyyMMdd = java.time.LocalDate.now().toString().replace("-", "") + //log.debug("New ID: $yyyyMMdd-$uuid") + return Session("U-$yyyyMMdd-${id2()}") + } + + private fun id2() = long64().filter { + when (it) { + in 'a'..'z' -> true + in 'A'..'Z' -> true + in '0'..'9' -> true + else -> false + } + }.take(4) + + fun parseSessionID(sessionID: String): Session { + val session = Session(sessionID) + validateSessionId(session) + return session + } - private fun id2() = long64().filter { - when (it) { - in 'a'..'z' -> true - in 'A'..'Z' -> true - in '0'..'9' -> true - else -> false - } - }.take(4) - - fun parseSessionID(sessionID: String): Session { - val session = Session(sessionID) - validateSessionId(session) - return session } - - } } interface UserSettingsInterface { - data class UserSettings( - val apiKeys: Map = APIProvider.values().associateWith { "" }, - val apiBase: Map = APIProvider.values().associateWith { it.base ?: "" }, - ) + data class UserSettings( + val apiKeys: Map = APIProvider.values().associateWith { "" }, + val apiBase: Map = APIProvider.values().associateWith { it.base ?: "" }, + ) - fun getUserSettings(user: User): UserSettings + fun getUserSettings(user: User): UserSettings - fun updateUserSettings(user: User, settings: UserSettings) + fun updateUserSettings(user: User, settings: UserSettings) } interface UsageInterface { - fun incrementUsage(session: Session, user: User?, model: OpenAIModel, tokens: ApiModel.Usage) = incrementUsage( - session, when (user) { - null -> null - else -> { - val userSettings = ApplicationServices.userSettingsManager.getUserSettings(user) - userSettings.apiKeys[if (model is ChatModels) { - model.provider - } else { - APIProvider.OpenAI - }] - } - }, model, tokens - ) - - fun incrementUsage(session: Session, apiKey: String?, model: OpenAIModel, tokens: ApiModel.Usage) - - fun getUserUsageSummary(user: User): Map = getUserUsageSummary( - ApplicationServices.userSettingsManager.getUserSettings(user).apiKeys[APIProvider.OpenAI]!! // TODO: Support other providers - ) - - fun getUserUsageSummary(apiKey: String): Map - - fun getSessionUsageSummary(session: Session): Map - fun clear() - - data class UsageKey( - val session: Session, - val apiKey: String?, - val model: OpenAIModel, - ) - - class UsageValues( - val inputTokens: AtomicInteger = AtomicInteger(), - val outputTokens: AtomicInteger = AtomicInteger(), - val cost: AtomicDouble = AtomicDouble(), - ) { - fun addAndGet(tokens: ApiModel.Usage) { - inputTokens.addAndGet(tokens.prompt_tokens) - outputTokens.addAndGet(tokens.completion_tokens) - cost.addAndGet(tokens.cost ?: 0.0) - } + fun incrementUsage(session: Session, user: User?, model: OpenAIModel, tokens: ApiModel.Usage) = incrementUsage( + session, when (user) { + null -> null + else -> { + val userSettings = ApplicationServices.userSettingsManager.getUserSettings(user) + userSettings.apiKeys[if (model is ChatModels) { + model.provider + } else { + APIProvider.OpenAI + }] + } + }, model, tokens + ) + + fun incrementUsage(session: Session, apiKey: String?, model: OpenAIModel, tokens: ApiModel.Usage) - fun toUsage() = ApiModel.Usage( - prompt_tokens = inputTokens.get(), - completion_tokens = outputTokens.get(), - cost = cost.get() + fun getUserUsageSummary(user: User): Map = getUserUsageSummary( + ApplicationServices.userSettingsManager.getUserSettings(user).apiKeys[APIProvider.OpenAI]!! // TODO: Support other providers ) - } - class UsageCounters( - val tokensPerModel: HashMap = HashMap(), - ) + fun getUserUsageSummary(apiKey: String): Map + + fun getSessionUsageSummary(session: Session): Map + fun clear() + + data class UsageKey( + val session: Session, + val apiKey: String?, + val model: OpenAIModel, + ) + + class UsageValues( + val inputTokens: AtomicInteger = AtomicInteger(), + val outputTokens: AtomicInteger = AtomicInteger(), + val cost: AtomicDouble = AtomicDouble(), + ) { + fun addAndGet(tokens: ApiModel.Usage) { + inputTokens.addAndGet(tokens.prompt_tokens) + outputTokens.addAndGet(tokens.completion_tokens) + cost.addAndGet(tokens.cost ?: 0.0) + } + + fun toUsage() = ApiModel.Usage( + prompt_tokens = inputTokens.get(), + completion_tokens = outputTokens.get(), + cost = cost.get() + ) + } + + class UsageCounters( + val tokensPerModel: HashMap = HashMap(), + ) } data class User( - @get:JsonProperty("email") val email: String, - @get:JsonProperty("name") val name: String? = null, - @get:JsonProperty("id") val id: String? = null, - @get:JsonProperty("picture") val picture: String? = null, - @get:JsonIgnore val credential: Any? = null, + @get:JsonProperty("email") val email: String, + @get:JsonProperty("name") val name: String? = null, + @get:JsonProperty("id") val id: String? = null, + @get:JsonProperty("picture") val picture: String? = null, + @get:JsonIgnore val credential: Any? = null, ) { - override fun toString() = email + override fun toString() = email - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (javaClass != other?.javaClass) return false + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false - other as User + other as User - return email == other.email - } + return email == other.email + } - override fun hashCode(): Int { - return email.hashCode() - } + override fun hashCode(): Int { + return email.hashCode() + } } data class Session( - internal val sessionId: String + internal val sessionId: String ) { - init { - StorageInterface.validateSessionId(this) - } + init { + StorageInterface.validateSessionId(this) + } - override fun toString() = sessionId - fun isGlobal(): Boolean = sessionId.startsWith("G-") + override fun toString() = sessionId + fun isGlobal(): Boolean = sessionId.startsWith("G-") } interface CloudPlatformInterface { - val shareBase: String - - fun upload( - path: String, - contentType: String, - bytes: ByteArray - ): String - - fun upload( - path: String, - contentType: String, - request: String - ): String - - fun encrypt(fileBytes: ByteArray, keyId: String): String? - fun decrypt(encryptedData: ByteArray): String + val shareBase: String + + fun upload( + path: String, + contentType: String, + bytes: ByteArray + ): String + + fun upload( + path: String, + contentType: String, + request: String + ): String + + fun encrypt(fileBytes: ByteArray, keyId: String): String? + fun decrypt(encryptedData: ByteArray): String } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt index 1297d02c..0c001593 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt @@ -13,78 +13,78 @@ import java.nio.charset.StandardCharsets import java.util.* open class AwsPlatform( - private val bucket: String = System.getProperty("share_bucket", "share.simiacrypt.us"), - override val shareBase: String = System.getProperty("share_base", "https://share.simiacrypt.us"), - private val region: Region? = Region.US_EAST_1 + private val bucket: String = System.getProperty("share_bucket", "share.simiacrypt.us"), + override val shareBase: String = System.getProperty("share_base", "https://share.simiacrypt.us"), + private val region: Region? = Region.US_EAST_1 ) : CloudPlatformInterface { - protected open val kmsClient: KmsClient by lazy { - KmsClient.builder().region(Region.US_EAST_1) - //.credentialsProvider(ProfileCredentialsProvider.create("data")) - .build() - } + protected open val kmsClient: KmsClient by lazy { + KmsClient.builder().region(Region.US_EAST_1) + //.credentialsProvider(ProfileCredentialsProvider.create("data")) + .build() + } - protected open val s3Client: S3Client by lazy { - S3Client.builder() - .region(region) - .build() - } + protected open val s3Client: S3Client by lazy { + S3Client.builder() + .region(region) + .build() + } - override fun upload( - path: String, - contentType: String, - bytes: ByteArray - ): String { - s3Client.putObject( - PutObjectRequest.builder() - .bucket(bucket).key(path.replace("/{2,}".toRegex(), "/").removePrefix("/")) - .contentType(contentType) - .build(), - RequestBody.fromBytes(bytes) - ) - return "$shareBase/$path" - } + override fun upload( + path: String, + contentType: String, + bytes: ByteArray + ): String { + s3Client.putObject( + PutObjectRequest.builder() + .bucket(bucket).key(path.replace("/{2,}".toRegex(), "/").removePrefix("/")) + .contentType(contentType) + .build(), + RequestBody.fromBytes(bytes) + ) + return "$shareBase/$path" + } - override fun upload( - path: String, - contentType: String, - request: String - ): String { - s3Client.putObject( - PutObjectRequest.builder() - .bucket(bucket).key(path.replace("/{2,}".toRegex(), "/").removePrefix("/")) - .contentType(contentType) - .build(), - RequestBody.fromString(request) - ) - return "$shareBase/$path" - } + override fun upload( + path: String, + contentType: String, + request: String + ): String { + s3Client.putObject( + PutObjectRequest.builder() + .bucket(bucket).key(path.replace("/{2,}".toRegex(), "/").removePrefix("/")) + .contentType(contentType) + .build(), + RequestBody.fromString(request) + ) + return "$shareBase/$path" + } - override fun encrypt(fileBytes: ByteArray, keyId: String): String? = - Base64.getEncoder().encodeToString( - kmsClient.encrypt( - EncryptRequest.builder() - .keyId(keyId) - .plaintext(SdkBytes.fromByteArray(fileBytes)) - .build() - ).ciphertextBlob().asByteArray() - ) + override fun encrypt(fileBytes: ByteArray, keyId: String): String? = + Base64.getEncoder().encodeToString( + kmsClient.encrypt( + EncryptRequest.builder() + .keyId(keyId) + .plaintext(SdkBytes.fromByteArray(fileBytes)) + .build() + ).ciphertextBlob().asByteArray() + ) - override fun decrypt(encryptedData: ByteArray): String = String( - kmsClient.decrypt( - DecryptRequest.builder() - .ciphertextBlob(SdkBytes.fromByteArray(Base64.getDecoder().decode(encryptedData))) - .build() - ).plaintext().asByteArray(), StandardCharsets.UTF_8 - ) + override fun decrypt(encryptedData: ByteArray): String = String( + kmsClient.decrypt( + DecryptRequest.builder() + .ciphertextBlob(SdkBytes.fromByteArray(Base64.getDecoder().decode(encryptedData))) + .build() + ).plaintext().asByteArray(), StandardCharsets.UTF_8 + ) - companion object { - val log = LoggerFactory.getLogger(AwsPlatform::class.java) - fun get() = try { - AwsPlatform() - } catch (e: Throwable) { - log.info("Error initializing AWS platform", e) - null + companion object { + val log = LoggerFactory.getLogger(AwsPlatform::class.java) + fun get() = try { + AwsPlatform() + } catch (e: Throwable) { + log.info("Error initializing AWS platform", e) + null + } } - } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ClientManager.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ClientManager.kt index b4ce14bd..2705471c 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ClientManager.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ClientManager.kt @@ -20,127 +20,126 @@ import java.util.concurrent.TimeUnit open class ClientManager { - private data class SessionKey(val session: Session, val user: User?) + private data class SessionKey(val session: Session, val user: User?) - private val clientCache = mutableMapOf() - private val poolCache = mutableMapOf() + private val clientCache = mutableMapOf() + private val poolCache = mutableMapOf() - fun getClient( - session: Session, - user: User?, - dataStorage: StorageInterface?, - ): OpenAIClient { - val key = SessionKey(session, user) - return if (null == dataStorage) clientCache[key] ?: throw IllegalStateException("No data storage") - else clientCache.getOrPut(key) { createClient(session, user, dataStorage)!! } - } + fun getClient( + session: Session, + user: User?, + dataStorage: StorageInterface?, + ): OpenAIClient { + val key = SessionKey(session, user) + return if (null == dataStorage) clientCache[key] ?: throw IllegalStateException("No data storage") + else clientCache.getOrPut(key) { createClient(session, user, dataStorage)!! } + } - protected open fun createPool(session: Session, user: User?, dataStorage: StorageInterface?) = - ThreadPoolExecutor( - 0, Integer.MAX_VALUE, - 500, TimeUnit.MILLISECONDS, - SynchronousQueue(), - RecordingThreadFactory(session, user) - ) + protected open fun createPool(session: Session, user: User?, dataStorage: StorageInterface?) = + ThreadPoolExecutor( + 0, Integer.MAX_VALUE, + 500, TimeUnit.MILLISECONDS, + SynchronousQueue(), + RecordingThreadFactory(session, user) + ) - fun getPool( - session: Session, - user: User?, - dataStorage: StorageInterface?, - ): ThreadPoolExecutor { - val key = SessionKey(session, user) - return poolCache.getOrPut(key) { - createPool(session, user, dataStorage) + fun getPool( + session: Session, + user: User?, + dataStorage: StorageInterface?, + ): ThreadPoolExecutor { + val key = SessionKey(session, user) + return poolCache.getOrPut(key) { + createPool(session, user, dataStorage) + } } - } - inner class RecordingThreadFactory( - session: Session, - user: User? - ) : ThreadFactory { - private val inner = ThreadFactoryBuilder().setNameFormat("Session $session; User $user; #%d").build() - val threads = mutableSetOf() - override fun newThread(r: Runnable): Thread { - inner.newThread(r).also { - threads.add(it) - return it - } + inner class RecordingThreadFactory( + session: Session, + user: User? + ) : ThreadFactory { + private val inner = ThreadFactoryBuilder().setNameFormat("Session $session; User $user; #%d").build() + val threads = mutableSetOf() + override fun newThread(r: Runnable): Thread { + inner.newThread(r).also { + threads.add(it) + return it + } + } } - } - protected open fun createClient( - session: Session, - user: User?, - dataStorage: StorageInterface?, - ): OpenAIClient? { - if (user != null) { - val userSettings = ApplicationServices.userSettingsManager.getUserSettings(user) - val logfile = dataStorage?.getSessionDir(user, session)?.resolve(".sys/$session/openai.log") - logfile?.parentFile?.mkdirs() - val userApi = - if (userSettings.apiKeys.isNotEmpty()) - MonitoredClient( - key = userSettings.apiKeys, - apiBase = userSettings.apiBase, - logfile = logfile, - session = session, - user = user, - workPool = getPool(session, user, dataStorage), - ) else null - if (userApi != null) return userApi + protected open fun createClient( + session: Session, + user: User?, + dataStorage: StorageInterface?, + ): OpenAIClient? { + if (user != null) { + val userSettings = ApplicationServices.userSettingsManager.getUserSettings(user) + val logfile = dataStorage?.getSessionDir(user, session)?.resolve(".sys/$session/openai.log") + logfile?.parentFile?.mkdirs() + val userApi = + if (userSettings.apiKeys.isNotEmpty()) + MonitoredClient( + key = userSettings.apiKeys, + apiBase = userSettings.apiBase, + logfile = logfile, + session = session, + user = user, + workPool = getPool(session, user, dataStorage), + ) else null + if (userApi != null) return userApi + } + val canUseGlobalKey = ApplicationServices.authorizationManager.isAuthorized( + null, user, OperationType.GlobalKey + ) + if (!canUseGlobalKey) throw RuntimeException("No API key") + val logfile = dataStorage?.getSessionDir(user, session)?.resolve(".sys/$session/openai.log") + logfile?.parentFile?.mkdirs() + return (if (ClientUtil.keyMap.isNotEmpty()) { + MonitoredClient( + key = ClientUtil.keyMap.mapKeys { APIProvider.valueOf(it.key) }, + logfile = logfile, + session = session, + user = user, + workPool = getPool(session, user, dataStorage), + ) + } else { + null + })!! } - val canUseGlobalKey = ApplicationServices.authorizationManager.isAuthorized( - null, user, OperationType.GlobalKey - ) - if (!canUseGlobalKey) throw RuntimeException("No API key") - val logfile = dataStorage?.getSessionDir(user, session)?.resolve(".sys/$session/openai.log") - logfile?.parentFile?.mkdirs() - return (if (ClientUtil.keyMap.isNotEmpty()) { - MonitoredClient( - key = ClientUtil.keyMap.mapKeys { APIProvider.valueOf(it.key) }, - logfile = logfile, - session = session, - user = user, - workPool = getPool(session, user, dataStorage), - ) - } else { - null - })!! - } - inner class MonitoredClient( - key: Map, - logfile: File?, - private val session: Session, - private val user: User?, - apiBase: Map = APIProvider.values().associate { it to (it.base ?: "") }, - scheduledPool: ListeningScheduledExecutorService = HttpClientManager.scheduledPool, - workPool: ThreadPoolExecutor = HttpClientManager.workPool, - client: CloseableHttpClient = HttpClientManager.client - ) : OpenAIClient( - key = key, - logLevel = Level.DEBUG, - logStreams = listOfNotNull( - logfile?.outputStream()?.buffered() - ).toMutableList(), - scheduledPool = scheduledPool, - workPool = workPool, - client = client, - apiBase = apiBase, - ) { - var budget = 2.00 - override fun authorize(request: HttpRequest, apiProvider: APIProvider) { - require(budget > 0.0) { "Budget Exceeded" } - super.authorize(request, ClientUtil.defaultApiProvider) - } + inner class MonitoredClient( + key: Map, + logfile: File?, + private val session: Session, + private val user: User?, + apiBase: Map = APIProvider.values().associate { it to (it.base ?: "") }, + scheduledPool: ListeningScheduledExecutorService = HttpClientManager.scheduledPool, + workPool: ThreadPoolExecutor = HttpClientManager.workPool, + client: CloseableHttpClient = HttpClientManager.client + ) : OpenAIClient( + key = key, + logLevel = Level.DEBUG, + logStreams = listOfNotNull( + logfile?.outputStream()?.buffered() + ).toMutableList(), + scheduledPool = scheduledPool, + workPool = workPool, + client = client, + apiBase = apiBase, + ) { + var budget = 2.00 + override fun authorize(request: HttpRequest, apiProvider: APIProvider) { + require(budget > 0.0) { "Budget Exceeded" } + super.authorize(request, ClientUtil.defaultApiProvider) + } - override fun onUsage(model: OpenAIModel?, tokens: ApiModel.Usage) { - ApplicationServices.usageManager.incrementUsage(session, user, model!!, tokens) - budget -= tokens.cost ?: 0.0 - super.onUsage(model, tokens) + override fun onUsage(model: OpenAIModel?, tokens: ApiModel.Usage) { + ApplicationServices.usageManager.incrementUsage(session, user, model!!, tokens) + budget -= tokens.cost ?: 0.0 + super.onUsage(model, tokens) + } } - } - companion object { - } + companion object } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/DataStorage.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/DataStorage.kt index 2ba8fc11..b5b1f66f 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/DataStorage.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/DataStorage.kt @@ -50,7 +50,7 @@ open class DataStorage( user: User?, session: Session ): File { - if(sessionPaths.containsKey(session)) { + if (sessionPaths.containsKey(session)) { return sessionPaths[session]!! } validateSessionId(session) @@ -89,8 +89,8 @@ open class DataStorage( ): String { validateSessionId(session) val sessionDir = getSessionDir(user, session) - val settings = getJson(sessionDir, "settings.json", Map::class.java) ?: mapOf() - if(settings.containsKey("name")) return settings["name"] as String + val settings = getJson(sessionDir, "settings.json", Map::class.java) ?: mapOf() + if (settings.containsKey("name")) return settings["name"] as String val userMessage = messageFiles(session, sessionDir).entries.minByOrNull { it.key.lastModified() }?.value return if (null != userMessage) { @@ -109,8 +109,8 @@ open class DataStorage( ): List { validateSessionId(session) val sessionDir = getSessionDir(user, session) - val settings = getJson(sessionDir, ".sys/$session/internal.json", Map::class.java) ?: mapOf() - if(settings.containsKey("ids")) return settings["ids"].toString().split(",").toList() + val settings = getJson(sessionDir, ".sys/$session/internal.json", Map::class.java) ?: mapOf() + if (settings.containsKey("ids")) return settings["ids"].toString().split(",").toList() val ids = messageFiles(session, sessionDir).entries.sortedBy { it.key.lastModified() } .map { it.key.nameWithoutExtension }.toList() setJson(sessionDir, ".sys/$session/internal.json", settings.plus("ids" to ids.joinToString(","))) @@ -128,15 +128,15 @@ open class DataStorage( setJson(sessionDir, ".sys/$session/internal.json", settings.plus("ids" to ids.joinToString(","))) } -override fun getSessionTime( + override fun getSessionTime( user: User?, session: Session ): Date? { validateSessionId(session) val sessionDir = getSessionDir(user, session) - val settings = getJson(sessionDir, ".sys/$session/internal.json", Map::class.java) ?: mapOf() + val settings = getJson(sessionDir, ".sys/$session/internal.json", Map::class.java) ?: mapOf() val dateFormat = SimpleDateFormat.getDateTimeInstance() - if(settings.containsKey("time")) return dateFormat.parse(settings["time"] as String) + if (settings.containsKey("time")) return dateFormat.parse(settings["time"] as String) val file = messageFiles(session, sessionDir).entries.minByOrNull { it.key.lastModified() }?.key return if (null != file) { val date = Date(file.lastModified()) @@ -148,24 +148,25 @@ override fun getSessionTime( } } - private fun messageFiles(session: Session, sessionDir: File) = File(sessionDir, MESSAGE_DIR + "/$session").listFiles() - ?.filter { file -> file.isFile } - ?.map { messageFile -> - val fileText = messageFile.readText() - val split = fileText.split("

") - if (split.size < 2) { - //log.debug("Session {}: No messages", session) - messageFile to "" - } else { - val stringList = split[1].split("

") - if (stringList.isEmpty()) { + private fun messageFiles(session: Session, sessionDir: File) = + File(sessionDir, MESSAGE_DIR + "/$session").listFiles() + ?.filter { file -> file.isFile } + ?.map { messageFile -> + val fileText = messageFile.readText() + val split = fileText.split("

") + if (split.size < 2) { //log.debug("Session {}: No messages", session) messageFile to "" } else { - messageFile to stringList.first() + val stringList = split[1].split("

") + if (stringList.isEmpty()) { + //log.debug("Session {}: No messages", session) + messageFile to "" + } else { + messageFile to stringList.first() + } } - } - }?.filter { it.second.isNotEmpty() }?.toList()?.toMap() ?: mapOf() + }?.filter { it.second.isNotEmpty() }?.toList()?.toMap() ?: mapOf() override fun listSessions( user: User? @@ -197,7 +198,7 @@ override fun getSessionTime( ) { validateSessionId(session) val file = File(File(this.getSessionDir(user, session), MESSAGE_DIR + "/$session"), "$messageId.json") - if(!file.exists()) { + if (!file.exists()) { file.parentFile.mkdirs() addMessageID(user, session, messageId) } @@ -205,7 +206,7 @@ override fun getSessionTime( JsonUtil.objectMapper().writeValue(file, value) } - open protected fun addMessageID( + protected open fun addMessageID( user: User?, session: Session, messageId: String diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UsageManager.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UsageManager.kt index 6e6ac531..956633c2 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UsageManager.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UsageManager.kt @@ -14,183 +14,184 @@ import java.util.concurrent.TimeUnit open class UsageManager(val root: File) : UsageInterface { - private val scheduler = Executors.newSingleThreadScheduledExecutor() - private val txLogFile = File(root, "log.csv") - @Volatile - private var txLogFileWriter: FileWriter? - private val usagePerSession = ConcurrentHashMap() - private val sessionsByUser = ConcurrentHashMap>() - private val usersBySession = ConcurrentHashMap>() - - init { - txLogFile.parentFile.mkdirs() - loadFromLog(txLogFile) - txLogFileWriter = FileWriter(txLogFile, true) - scheduler.scheduleAtFixedRate({ saveCounters() }, 1, 1, TimeUnit.HOURS) - } - - @Suppress("MemberVisibilityCanBePrivate") - private fun loadFromLog(file: File) { - if (file.exists()) { - try { - file.readLines().forEach { line -> - val (sessionId, user, model, value, direction) = line.split(",") - val modelEnum = listOf( - ChatModels.values(), - CompletionModels.values(), - EditModels.values(), - EmbeddingModels.values() - ).flatMap { it.values }.find { model == it.modelName } - ?: throw RuntimeException("Unknown model $model") - when (direction) { - "input" -> incrementUsage( - Session(sessionId), - User(email = user), - modelEnum, - com.simiacryptus.jopenai.ApiModel.Usage(prompt_tokens = value.toInt()) - ) - - "output" -> incrementUsage( - Session(sessionId), - User(email = user), - modelEnum, - com.simiacryptus.jopenai.ApiModel.Usage(completion_tokens = value.toInt()) - ) - - "cost" -> incrementUsage( - session = Session(sessionId = sessionId), - user = User(email = user), - model = modelEnum, - tokens = com.simiacryptus.jopenai.ApiModel.Usage(cost = value.toDouble()) - ) - - else -> throw RuntimeException("Unknown direction $direction") - } - } - } catch (e: Exception) { - log.warn("Error loading log file", e) - } + private val scheduler = Executors.newSingleThreadScheduledExecutor() + private val txLogFile = File(root, "log.csv") + + @Volatile + private var txLogFileWriter: FileWriter? + private val usagePerSession = ConcurrentHashMap() + private val sessionsByUser = ConcurrentHashMap>() + private val usersBySession = ConcurrentHashMap>() + + init { + txLogFile.parentFile.mkdirs() + loadFromLog(txLogFile) + txLogFileWriter = FileWriter(txLogFile, true) + scheduler.scheduleAtFixedRate({ saveCounters() }, 1, 1, TimeUnit.HOURS) } - } - - @Suppress("MemberVisibilityCanBePrivate") - private fun writeCompactLog(file: File) { - FileWriter(file).use { writer -> - usagePerSession.forEach { (sessionId, usage) -> - val apiKey = usersBySession[sessionId]?.firstOrNull() - usage.tokensPerModel.forEach { (model, counter) -> - writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.inputTokens.get()},input\n") - writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.outputTokens.get()},output\n") - writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.cost.get()},cost\n") + + @Suppress("MemberVisibilityCanBePrivate") + private fun loadFromLog(file: File) { + if (file.exists()) { + try { + file.readLines().forEach { line -> + val (sessionId, user, model, value, direction) = line.split(",") + val modelEnum = listOf( + ChatModels.values(), + CompletionModels.values(), + EditModels.values(), + EmbeddingModels.values() + ).flatMap { it.values }.find { model == it.modelName } + ?: throw RuntimeException("Unknown model $model") + when (direction) { + "input" -> incrementUsage( + Session(sessionId), + User(email = user), + modelEnum, + com.simiacryptus.jopenai.ApiModel.Usage(prompt_tokens = value.toInt()) + ) + + "output" -> incrementUsage( + Session(sessionId), + User(email = user), + modelEnum, + com.simiacryptus.jopenai.ApiModel.Usage(completion_tokens = value.toInt()) + ) + + "cost" -> incrementUsage( + session = Session(sessionId = sessionId), + user = User(email = user), + model = modelEnum, + tokens = com.simiacryptus.jopenai.ApiModel.Usage(cost = value.toDouble()) + ) + + else -> throw RuntimeException("Unknown direction $direction") + } + } + } catch (e: Exception) { + log.warn("Error loading log file", e) + } } - } - writer.flush() - } - } - - private fun saveCounters() { - txLogFileWriter = FileWriter(txLogFile, true) - val timedFile = File(txLogFile.absolutePath + "." + System.currentTimeMillis()) - writeCompactLog(timedFile) - val swapFile = File(txLogFile.absolutePath + ".old") - synchronized(txLogFile) { - try { - txLogFileWriter?.close() - } catch (e: Exception) { - log.warn("Error closing log file", e) - } - try { - txLogFile.renameTo(swapFile) - } catch (e: Exception) { - log.warn("Error renaming log file", e) - } - try { - timedFile.renameTo(txLogFile) - } catch (e: Exception) { - log.warn("Error renaming log file", e) - } - try { - swapFile.renameTo(timedFile) - } catch (e: Exception) { - log.warn("Error renaming log file", e) - } - txLogFileWriter = FileWriter(txLogFile, true) } - val text = JsonUtil.toJson(usagePerSession) - File(root, "counters.json").writeText(text) - val toClean = txLogFile.parentFile.listFiles() - ?.filter { it.name.startsWith(txLogFile.name) && it.name != txLogFile.absolutePath } - ?.sortedBy { it.lastModified() } // oldest first - ?.dropLast(2) // keep 2 newest - ?.drop(2) // keep 2 oldest - toClean?.forEach { it.delete() } - } - - override fun incrementUsage( - session: Session, - apiKey: String?, - model: OpenAIModel, - tokens: com.simiacryptus.jopenai.ApiModel.Usage - ) { - usagePerSession.computeIfAbsent(session) { UsageCounters() } - .tokensPerModel.computeIfAbsent(UsageKey(session, apiKey, model)) { UsageValues() } - .addAndGet(tokens) - if (apiKey != null) { - sessionsByUser.computeIfAbsent(apiKey) { HashSet() }.add(session) + + @Suppress("MemberVisibilityCanBePrivate") + private fun writeCompactLog(file: File) { + FileWriter(file).use { writer -> + usagePerSession.forEach { (sessionId, usage) -> + val apiKey = usersBySession[sessionId]?.firstOrNull() + usage.tokensPerModel.forEach { (model, counter) -> + writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.inputTokens.get()},input\n") + writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.outputTokens.get()},output\n") + writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.cost.get()},cost\n") + } + } + writer.flush() + } } - try { - val txLogFileWriter = txLogFileWriter - if (null != txLogFileWriter) { + + private fun saveCounters() { + txLogFileWriter = FileWriter(txLogFile, true) + val timedFile = File(txLogFile.absolutePath + "." + System.currentTimeMillis()) + writeCompactLog(timedFile) + val swapFile = File(txLogFile.absolutePath + ".old") synchronized(txLogFile) { - txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.prompt_tokens},input\n") - txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.completion_tokens},output\n") - txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.cost},cost\n") - txLogFileWriter.flush() + try { + txLogFileWriter?.close() + } catch (e: Exception) { + log.warn("Error closing log file", e) + } + try { + txLogFile.renameTo(swapFile) + } catch (e: Exception) { + log.warn("Error renaming log file", e) + } + try { + timedFile.renameTo(txLogFile) + } catch (e: Exception) { + log.warn("Error renaming log file", e) + } + try { + swapFile.renameTo(timedFile) + } catch (e: Exception) { + log.warn("Error renaming log file", e) + } + txLogFileWriter = FileWriter(txLogFile, true) } - } - } catch (e: Exception) { - log.warn("Error incrementing usage", e) + val text = JsonUtil.toJson(usagePerSession) + File(root, "counters.json").writeText(text) + val toClean = txLogFile.parentFile.listFiles() + ?.filter { it.name.startsWith(txLogFile.name) && it.name != txLogFile.absolutePath } + ?.sortedBy { it.lastModified() } // oldest first + ?.dropLast(2) // keep 2 newest + ?.drop(2) // keep 2 oldest + toClean?.forEach { it.delete() } + } + + override fun incrementUsage( + session: Session, + apiKey: String?, + model: OpenAIModel, + tokens: com.simiacryptus.jopenai.ApiModel.Usage + ) { + usagePerSession.computeIfAbsent(session) { UsageCounters() } + .tokensPerModel.computeIfAbsent(UsageKey(session, apiKey, model)) { UsageValues() } + .addAndGet(tokens) + if (apiKey != null) { + sessionsByUser.computeIfAbsent(apiKey) { HashSet() }.add(session) + } + try { + val txLogFileWriter = txLogFileWriter + if (null != txLogFileWriter) { + synchronized(txLogFile) { + txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.prompt_tokens},input\n") + txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.completion_tokens},output\n") + txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.cost},cost\n") + txLogFileWriter.flush() + } + } + } catch (e: Exception) { + log.warn("Error incrementing usage", e) + } + } + + override fun getUserUsageSummary(apiKey: String): Map { + return sessionsByUser[apiKey]?.flatMap { sessionId -> + val usage = usagePerSession[sessionId] + usage?.tokensPerModel?.entries?.map { (model, counter) -> + model.model to counter.toUsage() + } ?: emptyList() + }?.groupBy { it.first }?.mapValues { + it.value.map { it.second }.reduce { a, b -> + com.simiacryptus.jopenai.ApiModel.Usage( + prompt_tokens = a.prompt_tokens + b.prompt_tokens, + completion_tokens = a.completion_tokens + b.completion_tokens, + cost = (a.cost ?: 0.0) + (b.cost ?: 0.0) + ) + } + } ?: emptyMap() + } + + override fun getSessionUsageSummary(session: Session): Map = + usagePerSession[session]?.tokensPerModel?.entries?.map { (model, counter) -> + model.model to counter.toUsage() + }?.groupBy { it.first }?.mapValues { + it.value.map { it.second }.reduce { a, b -> + com.simiacryptus.jopenai.ApiModel.Usage( + prompt_tokens = a.prompt_tokens + b.prompt_tokens, + completion_tokens = a.completion_tokens + b.completion_tokens, + cost = (a.cost ?: 0.0) + (b.cost ?: 0.0) + ) + } + } ?: emptyMap() + + override fun clear() { + usagePerSession.clear() + sessionsByUser.clear() + usersBySession.clear() + saveCounters() + } + + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(UsageManager::class.java) } - } - - override fun getUserUsageSummary(apiKey: String): Map { - return sessionsByUser[apiKey]?.flatMap { sessionId -> - val usage = usagePerSession[sessionId] - usage?.tokensPerModel?.entries?.map { (model, counter) -> - model.model to counter.toUsage() - } ?: emptyList() - }?.groupBy { it.first }?.mapValues { - it.value.map { it.second }.reduce { a, b -> - com.simiacryptus.jopenai.ApiModel.Usage( - prompt_tokens = a.prompt_tokens + b.prompt_tokens, - completion_tokens = a.completion_tokens + b.completion_tokens, - cost = (a.cost ?: 0.0) + (b.cost ?: 0.0) - ) - } - } ?: emptyMap() - } - - override fun getSessionUsageSummary(session: Session): Map = - usagePerSession[session]?.tokensPerModel?.entries?.map { (model, counter) -> - model.model to counter.toUsage() - }?.groupBy { it.first }?.mapValues { - it.value.map { it.second }.reduce { a, b -> - com.simiacryptus.jopenai.ApiModel.Usage( - prompt_tokens = a.prompt_tokens + b.prompt_tokens, - completion_tokens = a.completion_tokens + b.completion_tokens, - cost = (a.cost ?: 0.0) + (b.cost ?: 0.0) - ) - } - } ?: emptyMap() - - override fun clear() { - usagePerSession.clear() - sessionsByUser.clear() - usersBySession.clear() - saveCounters() - } - - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(UsageManager::class.java) - } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthenticationInterfaceTest.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthenticationInterfaceTest.kt index 4df51b50..1da68fe1 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthenticationInterfaceTest.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthenticationInterfaceTest.kt @@ -1,4 +1,3 @@ - import com.simiacryptus.skyenet.core.platform.AuthenticationInterface import com.simiacryptus.skyenet.core.platform.User import org.junit.jupiter.api.Assertions.* @@ -6,39 +5,44 @@ import org.junit.jupiter.api.Test import java.util.* open class AuthenticationInterfaceTest( - private val authInterface: AuthenticationInterface + private val authInterface: AuthenticationInterface ) { - private val validAccessToken = UUID.randomUUID().toString() - private val newUser = User(email = "newuser@example.com", name = "Jane Smith", id = "2", picture = "http://example.com/newpicture.jpg") - - @Test - fun `getUser should return null when no user is associated with access token`() { - val user = authInterface.getUser(validAccessToken) - assertNull(user) - } - - @Test - fun `putUser should add a new user and return the user`() { - val returnedUser = authInterface.putUser(validAccessToken, newUser) - assertEquals(newUser, returnedUser) - } - - @Test - fun `getUser should return User after putUser is called`() { - authInterface.putUser(validAccessToken, newUser) - val user: User? = authInterface.getUser(validAccessToken) - assertNotNull(user) - assertEquals(newUser, user) - } - - @Test - fun `logout should remove the user associated with the access token`() { - authInterface.putUser(validAccessToken, newUser) - assertNotNull(authInterface.getUser(validAccessToken)) - - authInterface.logout(validAccessToken, newUser) - assertNull(authInterface.getUser(validAccessToken)) - } + private val validAccessToken = UUID.randomUUID().toString() + private val newUser = User( + email = "newuser@example.com", + name = "Jane Smith", + id = "2", + picture = "http://example.com/newpicture.jpg" + ) + + @Test + fun `getUser should return null when no user is associated with access token`() { + val user = authInterface.getUser(validAccessToken) + assertNull(user) + } + + @Test + fun `putUser should add a new user and return the user`() { + val returnedUser = authInterface.putUser(validAccessToken, newUser) + assertEquals(newUser, returnedUser) + } + + @Test + fun `getUser should return User after putUser is called`() { + authInterface.putUser(validAccessToken, newUser) + val user: User? = authInterface.getUser(validAccessToken) + assertNotNull(user) + assertEquals(newUser, user) + } + + @Test + fun `logout should remove the user associated with the access token`() { + authInterface.putUser(validAccessToken, newUser) + assertNotNull(authInterface.getUser(validAccessToken)) + + authInterface.logout(validAccessToken, newUser) + assertNull(authInterface.getUser(validAccessToken)) + } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthorizationInterfaceTest.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthorizationInterfaceTest.kt index 2b64a491..bf271eb8 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthorizationInterfaceTest.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthorizationInterfaceTest.kt @@ -1,4 +1,3 @@ - import com.simiacryptus.skyenet.core.platform.AuthorizationInterface import com.simiacryptus.skyenet.core.platform.User import org.junit.jupiter.api.Assertions.assertFalse @@ -6,14 +5,19 @@ import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test open class AuthorizationInterfaceTest( - private val authInterface: AuthorizationInterface + private val authInterface: AuthorizationInterface ) { - open val user = User(email = "newuser@example.com", name = "Jane Smith", id = "2", picture = "http://example.com/newpicture.jpg") + open val user = User( + email = "newuser@example.com", + name = "Jane Smith", + id = "2", + picture = "http://example.com/newpicture.jpg" + ) - @Test - fun `newUser has admin`() { - assertFalse(authInterface.isAuthorized(this.javaClass, user, AuthorizationInterface.OperationType.Admin)) - } + @Test + fun `newUser has admin`() { + assertFalse(authInterface.isAuthorized(this.javaClass, user, AuthorizationInterface.OperationType.Admin)) + } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UsageTest.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UsageTest.kt index 5f98942d..e00d7550 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UsageTest.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UsageTest.kt @@ -10,25 +10,25 @@ import org.junit.jupiter.api.Test import kotlin.random.Random abstract class UsageTest(private val impl: UsageInterface) { - private val testUser = User( - email = "test@example.com", - name = "Test User", - id = Random.nextInt().toString() - ) - - @Test - fun `incrementUsage should increment usage for session`() { - val model = ChatModels.GPT35Turbo - val session = StorageInterface.newGlobalID() - val usage = ApiModel.Usage( - prompt_tokens = 10, - completion_tokens = 20, - cost = 30.0, + private val testUser = User( + email = "test@example.com", + name = "Test User", + id = Random.nextInt().toString() ) - impl.incrementUsage(session, testUser, model, usage) - val usageSummary = impl.getSessionUsageSummary(session) - Assertions.assertEquals(usage, usageSummary[model]) - val userUsageSummary = impl.getUserUsageSummary(testUser) - Assertions.assertEquals(usage, userUsageSummary[model]) - } + + @Test + fun `incrementUsage should increment usage for session`() { + val model = ChatModels.GPT35Turbo + val session = StorageInterface.newGlobalID() + val usage = ApiModel.Usage( + prompt_tokens = 10, + completion_tokens = 20, + cost = 30.0, + ) + impl.incrementUsage(session, testUser, model, usage) + val usageSummary = impl.getSessionUsageSummary(session) + Assertions.assertEquals(usage, usageSummary[model]) + val userUsageSummary = impl.getUserUsageSummary(testUser) + Assertions.assertEquals(usage, userUsageSummary[model]) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UserSettingsTest.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UserSettingsTest.kt index d766266a..42b7045e 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UserSettingsTest.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UserSettingsTest.kt @@ -10,36 +10,36 @@ import java.util.* abstract class UserSettingsTest(private val userSettings: UserSettingsInterface) { - @Test - fun `updateUserSettings should store custom settings for user`() { - val id = UUID.randomUUID().toString() - val testUser = User( - email = "$id@example.com", - name = "Test User", - id = id - ) - - val newSettings = UserSettingsInterface.UserSettings(apiKeys = mapOf(APIProvider.OpenAI to "12345")) - userSettings.updateUserSettings(testUser, newSettings) - val settings = userSettings.getUserSettings(testUser) - Assertions.assertEquals("12345", settings.apiKeys[APIProvider.OpenAI]) - } - - @Test - fun `getUserSettings should return updated settings after updateUserSettings is called`() { - val id = UUID.randomUUID().toString() - val testUser = User( - email = "$id@example.com", - name = "Test User", - id = id - ) - val initialSettings = userSettings.getUserSettings(testUser) - Assertions.assertEquals("", initialSettings.apiKeys[APIProvider.OpenAI]) - - val updatedSettings = UserSettingsInterface.UserSettings(apiKeys = mapOf(APIProvider.OpenAI to "67890")) - userSettings.updateUserSettings(testUser, updatedSettings) - - val settingsAfterUpdate = userSettings.getUserSettings(testUser) - Assertions.assertEquals("67890", settingsAfterUpdate.apiKeys[APIProvider.OpenAI]) - } + @Test + fun `updateUserSettings should store custom settings for user`() { + val id = UUID.randomUUID().toString() + val testUser = User( + email = "$id@example.com", + name = "Test User", + id = id + ) + + val newSettings = UserSettingsInterface.UserSettings(apiKeys = mapOf(APIProvider.OpenAI to "12345")) + userSettings.updateUserSettings(testUser, newSettings) + val settings = userSettings.getUserSettings(testUser) + Assertions.assertEquals("12345", settings.apiKeys[APIProvider.OpenAI]) + } + + @Test + fun `getUserSettings should return updated settings after updateUserSettings is called`() { + val id = UUID.randomUUID().toString() + val testUser = User( + email = "$id@example.com", + name = "Test User", + id = id + ) + val initialSettings = userSettings.getUserSettings(testUser) + Assertions.assertEquals("", initialSettings.apiKeys[APIProvider.OpenAI]) + + val updatedSettings = UserSettingsInterface.UserSettings(apiKeys = mapOf(APIProvider.OpenAI to "67890")) + userSettings.updateUserSettings(testUser, updatedSettings) + + val settingsAfterUpdate = userSettings.getUserSettings(testUser) + Assertions.assertEquals("67890", settingsAfterUpdate.apiKeys[APIProvider.OpenAI]) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/ClasspathRelationships.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/ClasspathRelationships.kt index 0feec9cb..3c8776ac 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/ClasspathRelationships.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/ClasspathRelationships.kt @@ -3,444 +3,371 @@ package com.simiacryptus.skyenet.core.util import org.objectweb.asm.* import org.objectweb.asm.signature.SignatureReader import org.objectweb.asm.signature.SignatureVisitor -import java.io.File import java.util.jar.JarFile object ClasspathRelationships { - sealed class Relation { - open val from_method : String = "" - open val to_method : String = "" - } - data object INHERITANCE : Relation() // When a class extends another class - data object INTERFACE_IMPLEMENTATION : Relation() // When a class implements an interface - data class FIELD_TYPE(override val from_method: String) : Relation() // When a class has a field of another class type - data class METHOD_PARAMETER( - override val from_method: String - ) : Relation() // When a class has a method that takes another class as a parameter - - data object METHOD_RETURN_TYPE : Relation() // When a class has a method that returns another class - data class LOCAL_VARIABLE(override val from_method: String) : - Relation() // When a method within a class declares a local variable of another class - - data class EXCEPTION_TYPE(override val from_method: String) : - Relation() // When a method declares that it throws an exception of another class - - data class ANNOTATION(override val from_method: String) : - Relation() // When a class, method, or field is annotated with another class (annotation) - - data class INSTANCE_CREATION(override val from_method: String) : Relation() // When a class creates an instance of another class - data class METHOD_REFERENCE( - override val from_method: String, - override val to_method: String - ) : Relation() // When a method references another class's method - data class METHOD_SIGNATURE( - override val from_method: String, - override val to_method: String - ) : Relation() // When a method signature references another class - - data class FIELD_REFERENCE(override val from_method: String) : Relation() // When a method references another class's field - data class DYNAMIC_BINDING(override val from_method: String) : - Relation() // When a class uses dynamic binding (e.g., invoke dynamic) related to another class - - data object OUTER_CLASS : Relation() // When a class references its outer class - - data object UNKNOWN : Relation() // A fallback for unknown or unclassified dependencies - - - class DependencyClassVisitor( - val dependencies: MutableMap> = mutableMapOf(), - var access: Int = 0, - var methods: MutableMap = mutableMapOf(), - ) : ClassVisitor(Opcodes.ASM9) { - - override fun visit( - version: Int, - access: Int, - name: String, - signature: String?, - superName: String?, - interfaces: Array? - ) { - this.access = access - // Add superclass dependency - superName?.let { addDep(it, INHERITANCE) } - // Add interface dependencies - interfaces?.forEach { addDep(it, INTERFACE_IMPLEMENTATION) } - visitSignature(name, signature) - super.visit(version, access, name, signature, superName, interfaces) + sealed class Relation { + open val from_method: String = "" + open val to_method: String = "" } - override fun visitField( - access: Int, - name: String?, - desc: String?, - signature: String?, - value: Any? - ): FieldVisitor? { - visitSignature(name, signature) - // Add field type dependency - addType(desc, FIELD_TYPE(from_method = "")) - return DependencyFieldVisitor(dependencies) - } + data object INHERITANCE : Relation() // When a class extends another class + data object INTERFACE_IMPLEMENTATION : Relation() // When a class implements an interface + data class FIELD_TYPE(override val from_method: String) : + Relation() // When a class has a field of another class type + + data class METHOD_PARAMETER( + override val from_method: String + ) : Relation() // When a class has a method that takes another class as a parameter + + data object METHOD_RETURN_TYPE : Relation() // When a class has a method that returns another class + data class LOCAL_VARIABLE(override val from_method: String) : + Relation() // When a method within a class declares a local variable of another class + + data class EXCEPTION_TYPE(override val from_method: String) : + Relation() // When a method declares that it throws an exception of another class + + data class ANNOTATION(override val from_method: String) : + Relation() // When a class, method, or field is annotated with another class (annotation) + + data class INSTANCE_CREATION(override val from_method: String) : + Relation() // When a class creates an instance of another class + + data class METHOD_REFERENCE( + override val from_method: String, + override val to_method: String + ) : Relation() // When a method references another class's method + + data class METHOD_SIGNATURE( + override val from_method: String, + override val to_method: String + ) : Relation() // When a method signature references another class + + data class FIELD_REFERENCE(override val from_method: String) : + Relation() // When a method references another class's field + + data class DYNAMIC_BINDING(override val from_method: String) : + Relation() // When a class uses dynamic binding (e.g., invoke dynamic) related to another class + + + class DependencyClassVisitor( + val dependencies: MutableMap> = mutableMapOf(), + var access: Int = 0, + var methods: MutableMap = mutableMapOf(), + ) : ClassVisitor(Opcodes.ASM9) { + + override fun visit( + version: Int, + access: Int, + name: String, + signature: String?, + superName: String?, + interfaces: Array? + ) { + this.access = access + // Add superclass dependency + superName?.let { addDep(it, INHERITANCE) } + // Add interface dependencies + interfaces?.forEach { addDep(it, INTERFACE_IMPLEMENTATION) } + visitSignature(name, signature) + super.visit(version, access, name, signature, superName, interfaces) + } - override fun visitMethod( - access: Int, - name: String?, - desc: String?, - signature: String?, - exceptions: Array? - ): MethodVisitor { - visitSignature(name, signature) - // Add method return type and parameter types dependencies - addMethodDescriptor(desc, METHOD_PARAMETER(from_method = name ?: ""), METHOD_RETURN_TYPE) - // Add exception types dependencies - exceptions?.forEach { addDep(it, EXCEPTION_TYPE(from_method = name ?: "")) } - val methodVisitor = DependencyMethodVisitor(name ?: "", dependencies) - methods[methodVisitor.name] = methodVisitor - return methodVisitor - } + override fun visitField( + access: Int, + name: String?, + desc: String?, + signature: String?, + value: Any? + ): FieldVisitor { + visitSignature(name, signature) + // Add field type dependency + addType(desc, FIELD_TYPE(from_method = "")) + return DependencyFieldVisitor(dependencies) + } - private fun visitSignature(name: String?, signature: String?) { - // Check if the name indicates an inner class or property accessor - if (name?.contains("$") == true) { - // NOTE: This isn't a typically required dependency - // addDep(name.substringBefore("$"), OUTER_CLASS) - } - if (name?.contains("baseClassLoader") == true) { - signature?.let { - val signatureReader = SignatureReader(it) - signatureReader.accept(object : SignatureVisitor(Opcodes.ASM9) { - override fun visitClassType(name: String?) { - name?.let { addDep(it, METHOD_PARAMETER(from_method = "")) } + override fun visitMethod( + access: Int, + name: String?, + desc: String?, + signature: String?, + exceptions: Array? + ): MethodVisitor { + visitSignature(name, signature) + // Add method return type and parameter types dependencies + addMethodDescriptor(desc, METHOD_PARAMETER(from_method = name ?: ""), METHOD_RETURN_TYPE) + // Add exception types dependencies + exceptions?.forEach { addDep(it, EXCEPTION_TYPE(from_method = name ?: "")) } + val methodVisitor = DependencyMethodVisitor(name ?: "", dependencies) + methods[methodVisitor.name] = methodVisitor + return methodVisitor + } + + private fun visitSignature(name: String?, signature: String?) { + // Check if the name indicates an inner class or property accessor + if (name?.contains("$") == true) { + // NOTE: This isn't a typically required dependency + // addDep(name.substringBefore("$"), OUTER_CLASS) + } + if (name?.contains("baseClassLoader") == true) { + signature?.let { + val signatureReader = SignatureReader(it) + signatureReader.accept(object : SignatureVisitor(Opcodes.ASM9) { + override fun visitClassType(name: String?) { + name?.let { addDep(it, METHOD_PARAMETER(from_method = "")) } + } + }) + } + return + } + signature?.let { + val signatureReader = SignatureReader(it) + signatureReader.accept(object : SignatureVisitor(Opcodes.ASM9) { + override fun visitClassType(name: String?) { + name?.let { addDep(it, METHOD_PARAMETER(from_method = "")) } + } + }) } - }) } - return - } - signature?.let { - val signatureReader = SignatureReader(it) - signatureReader.accept(object : SignatureVisitor(Opcodes.ASM9) { - override fun visitClassType(name: String?) { - name?.let { addDep(it, METHOD_PARAMETER(from_method = "")) } - } - }) - } - } - override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { - // Add annotation type dependency - addType(descriptor, ANNOTATION(from_method = "")) - return super.visitAnnotation(descriptor, visible) - } + override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { + // Add annotation type dependency + addType(descriptor, ANNOTATION(from_method = "")) + return super.visitAnnotation(descriptor, visible) + } - private fun addDep(internalName: String, relationType: Relation) { - val typeName = internalName.replace('/', '.') - dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) - } + private fun addDep(internalName: String, relationType: Relation) { + val typeName = internalName.replace('/', '.') + dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) + } - private fun addType(type: String?, relationType: Relation) { - type?.let { - val typeName = Type.getType(it).className - addDep(typeName, relationType) - } - } + private fun addType(type: String?, relationType: Relation) { + type?.let { + val typeName = Type.getType(it).className + addDep(typeName, relationType) + } + } - private fun addMethodDescriptor( - descriptor: String?, - paramRelationType: Relation, - returnRelationType: Relation - ) { - descriptor?.let { - val methodType = Type.getMethodType(it) - // Add return type dependency - addType(methodType.returnType.descriptor, returnRelationType) - // Add parameter types dependencies - methodType.argumentTypes.forEach { argType -> - addType(argType.descriptor, paramRelationType) + private fun addMethodDescriptor( + descriptor: String?, + paramRelationType: Relation, + returnRelationType: Relation + ) { + descriptor?.let { + val methodType = Type.getMethodType(it) + // Add return type dependency + addType(methodType.returnType.descriptor, returnRelationType) + // Add parameter types dependencies + methodType.argumentTypes.forEach { argType -> + addType(argType.descriptor, paramRelationType) + } + } } - } + } - } + class DependencyFieldVisitor( + val dependencies: MutableMap> + ) : FieldVisitor(Opcodes.ASM9) { - class DependencyFieldVisitor( - val dependencies: MutableMap> - ) : FieldVisitor(Opcodes.ASM9) { + override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { + descriptor?.let { addType(it, ANNOTATION(from_method = "")) } + return super.visitAnnotation(descriptor, visible) + } - override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { - descriptor?.let { addType(it, ANNOTATION(from_method = "")) } - return super.visitAnnotation(descriptor, visible) - } + override fun visitAttribute(attribute: Attribute?) { + super.visitAttribute(attribute) + } - override fun visitAttribute(attribute: Attribute?) { - super.visitAttribute(attribute) - } + override fun visitTypeAnnotation( + typeRef: Int, + typePath: TypePath?, + descriptor: String?, + visible: Boolean + ): AnnotationVisitor? { + descriptor?.let { addType(it, ANNOTATION(from_method = "")) } + return super.visitTypeAnnotation(typeRef, typePath, descriptor, visible) + } - override fun visitTypeAnnotation( - typeRef: Int, - typePath: TypePath?, - descriptor: String?, - visible: Boolean - ): AnnotationVisitor? { - descriptor?.let { addType(it, ANNOTATION(from_method = "")) } - return super.visitTypeAnnotation(typeRef, typePath, descriptor, visible) - } + private fun addDep(internalName: String, relationType: Relation) { + val typeName = internalName.replace('/', '.') + dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) + } - private fun addDep(internalName: String, relationType: Relation) { - val typeName = internalName.replace('/', '.') - dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) - } + private fun addType(type: String, relationType: Relation) { + addDep(getTypeName(type) ?: return, relationType) + } - private fun addType(type: String, relationType: Relation) { - addDep(getTypeName(type) ?: return, relationType) } - } - - class DependencyMethodVisitor( - val name: String, - val dependencies: MutableMap>, - var access: Int = 0, - ) : MethodVisitor(Opcodes.ASM9) { - - - override fun visitMethodInsn( - opcode: Int, - owner: String?, - name: String?, - descriptor: String?, - isInterface: Boolean - ) { - access = opcode - // Add method reference dependency - owner?.let { addDep(it, METHOD_REFERENCE(from_method = this.name, to_method = name ?: "")) } - // Add method descriptor dependencies (for parameter and return types) - descriptor?.let { addMethodDescriptor(it, METHOD_SIGNATURE(from_method = this.name, to_method = name ?: "")) } - super.visitMethodInsn(opcode, owner, name, descriptor, isInterface) - } + class DependencyMethodVisitor( + val name: String, + val dependencies: MutableMap>, + var access: Int = 0, + ) : MethodVisitor(Opcodes.ASM9) { + + + override fun visitMethodInsn( + opcode: Int, + owner: String?, + name: String?, + descriptor: String?, + isInterface: Boolean + ) { + access = opcode + // Add method reference dependency + owner?.let { addDep(it, METHOD_REFERENCE(from_method = this.name, to_method = name ?: "")) } + // Add method descriptor dependencies (for parameter and return types) + descriptor?.let { + addMethodDescriptor( + it, + METHOD_SIGNATURE(from_method = this.name, to_method = name ?: "") + ) + } + super.visitMethodInsn(opcode, owner, name, descriptor, isInterface) + } - override fun visitParameter(name: String?, access: Int) { - // Add method parameter type dependency - name?.let { addType(it, METHOD_PARAMETER(from_method = this.name)) } - super.visitParameter(name, access) - } + override fun visitParameter(name: String?, access: Int) { + // Add method parameter type dependency + name?.let { addType(it, METHOD_PARAMETER(from_method = this.name)) } + super.visitParameter(name, access) + } - override fun visitFieldInsn(opcode: Int, owner: String?, name: String?, descriptor: String?) { - // Add field reference dependency - owner?.let { addDep(it, FIELD_REFERENCE(from_method = this.name)) } - // Add field type dependency - descriptor?.let { addType(it, FIELD_TYPE(from_method = this.name)) } - super.visitFieldInsn(opcode, owner, name, descriptor) - } + override fun visitFieldInsn(opcode: Int, owner: String?, name: String?, descriptor: String?) { + // Add field reference dependency + owner?.let { addDep(it, FIELD_REFERENCE(from_method = this.name)) } + // Add field type dependency + descriptor?.let { addType(it, FIELD_TYPE(from_method = this.name)) } + super.visitFieldInsn(opcode, owner, name, descriptor) + } - override fun visitTypeInsn(opcode: Int, type: String?) { - // Add instance creation or local variable dependency based on opcode - type?.let { - val dependencyType = when (opcode) { - Opcodes.NEW -> INSTANCE_CREATION(from_method = this.name) - else -> LOCAL_VARIABLE(from_method = this.name) + override fun visitTypeInsn(opcode: Int, type: String?) { + // Add instance creation or local variable dependency based on opcode + type?.let { + val dependencyType = when (opcode) { + Opcodes.NEW -> INSTANCE_CREATION(from_method = this.name) + else -> LOCAL_VARIABLE(from_method = this.name) + } + addType(it, dependencyType) + } + super.visitTypeInsn(opcode, type) } - addType(it, dependencyType) - } - super.visitTypeInsn(opcode, type) - } - override fun visitLdcInsn(value: Any?) { - // Add class literal dependency - if (value is Type) { - addType(value.descriptor, LOCAL_VARIABLE(from_method = this.name)) - } - super.visitLdcInsn(value) - } + override fun visitLdcInsn(value: Any?) { + // Add class literal dependency + if (value is Type) { + addType(value.descriptor, LOCAL_VARIABLE(from_method = this.name)) + } + super.visitLdcInsn(value) + } - override fun visitMultiANewArrayInsn(descriptor: String?, numDimensions: Int) { - // Add local variable dependency for multi-dimensional arrays - descriptor?.let { addType(it, LOCAL_VARIABLE(from_method = this.name)) } - super.visitMultiANewArrayInsn(descriptor, numDimensions) - } + override fun visitMultiANewArrayInsn(descriptor: String?, numDimensions: Int) { + // Add local variable dependency for multi-dimensional arrays + descriptor?.let { addType(it, LOCAL_VARIABLE(from_method = this.name)) } + super.visitMultiANewArrayInsn(descriptor, numDimensions) + } - override fun visitInvokeDynamicInsn( - name: String?, - descriptor: String?, - bootstrapMethodHandle: Handle?, - vararg bootstrapMethodArguments: Any? - ) { - // Add dynamic binding dependency - descriptor?.let { addMethodDescriptor(it, DYNAMIC_BINDING(from_method = this.name)) } - super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, *bootstrapMethodArguments) - } + override fun visitInvokeDynamicInsn( + name: String?, + descriptor: String?, + bootstrapMethodHandle: Handle?, + vararg bootstrapMethodArguments: Any? + ) { + // Add dynamic binding dependency + descriptor?.let { addMethodDescriptor(it, DYNAMIC_BINDING(from_method = this.name)) } + super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, *bootstrapMethodArguments) + } - override fun visitLocalVariable( - name: String?, - descriptor: String?, - signature: String?, - start: Label?, - end: Label?, - index: Int - ) { - // Add local variable dependency - descriptor?.let { addType(it, LOCAL_VARIABLE(from_method = this.name)) } - super.visitLocalVariable(name, descriptor, signature, start, end, index) - } + override fun visitLocalVariable( + name: String?, + descriptor: String?, + signature: String?, + start: Label?, + end: Label?, + index: Int + ) { + // Add local variable dependency + descriptor?.let { addType(it, LOCAL_VARIABLE(from_method = this.name)) } + super.visitLocalVariable(name, descriptor, signature, start, end, index) + } - override fun visitTryCatchBlock(start: Label?, end: Label?, handler: Label?, type: String?) { - // Add exception type dependency - type?.let { addType(it, EXCEPTION_TYPE(from_method = this.name)) } - super.visitTryCatchBlock(start, end, handler, type) - } + override fun visitTryCatchBlock(start: Label?, end: Label?, handler: Label?, type: String?) { + // Add exception type dependency + type?.let { addType(it, EXCEPTION_TYPE(from_method = this.name)) } + super.visitTryCatchBlock(start, end, handler, type) + } - override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { - // Add annotation type dependency - descriptor?.let { addType(it, ANNOTATION(from_method = this.name)) } - return super.visitAnnotation(descriptor, visible) - } + override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { + // Add annotation type dependency + descriptor?.let { addType(it, ANNOTATION(from_method = this.name)) } + return super.visitAnnotation(descriptor, visible) + } - private fun addDep(internalName: String, relationType: Relation) { - val typeName = internalName.replace('/', '.') - dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) - } + private fun addDep(internalName: String, relationType: Relation) { + val typeName = internalName.replace('/', '.') + dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) + } - private fun addType(type: String, relationType: Relation): Unit { - addDep(getTypeName(type) ?: return, relationType) - } + private fun addType(type: String, relationType: Relation): Unit { + addDep(getTypeName(type) ?: return, relationType) + } - private fun addMethodDescriptor( - descriptor: String, - relationType: Relation - ) { - val methodType = Type.getMethodType(descriptor) - // Add return type dependency - addType(methodType.returnType.descriptor, relationType) - // Add parameter types dependencies - methodType.argumentTypes.forEach { addType(it.descriptor, relationType) } + private fun addMethodDescriptor( + descriptor: String, + relationType: Relation + ) { + val methodType = Type.getMethodType(descriptor) + // Add return type dependency + addType(methodType.returnType.descriptor, relationType) + // Add parameter types dependencies + methodType.argumentTypes.forEach { addType(it.descriptor, relationType) } + } } - } - - private fun getTypeName(type: String): String? = try { - val name = when { - // For array types, get the class name - type.startsWith("L") && type.endsWith(";") -> getTypeName(type.substring(1, type.length - 1)) - // Handle the case where the descriptor appears to be a plain class name - !type.startsWith("[") && !type.startsWith("L") && !type.endsWith(";") -> Type.getObjectType(type.classToPath).className - // Handle the case where the descriptor is missing 'L' and ';' - type.contains("/") && !type.startsWith("L") && !type.endsWith(";") -> Type.getObjectType(type).className - // For primitive types, use the descriptor directly - type.length == 1 && "BCDFIJSZ".contains(type[0]) -> type - type.endsWith("$") -> type.substring(0, type.length - 1) - else -> Type.getType(type).className + + private fun getTypeName(type: String): String? = try { + val name = when { + // For array types, get the class name + type.startsWith("L") && type.endsWith(";") -> getTypeName(type.substring(1, type.length - 1)) + // Handle the case where the descriptor appears to be a plain class name + !type.startsWith("[") && !type.startsWith("L") && !type.endsWith(";") -> Type.getObjectType(type.classToPath).className + // Handle the case where the descriptor is missing 'L' and ';' + type.contains("/") && !type.startsWith("L") && !type.endsWith(";") -> Type.getObjectType(type).className + // For primitive types, use the descriptor directly + type.length == 1 && "BCDFIJSZ".contains(type[0]) -> type + type.endsWith("$") -> type.substring(0, type.length - 1) + else -> Type.getType(type).className + } + name + } catch (e: Exception) { + println("Error adding type: $type (${e.message})") + null } - name - } catch (e: Exception) { - println("Error adding type: $type (${e.message})") - null - } - val String.classToPath - get() = removeSuffix(".class").replace('.', '/') + val String.classToPath + get() = removeSuffix(".class").replace('.', '/') + + data class Reference( + val from: String, + val to: String, + val relation: Relation + ) - val String.jarFiles - get() = File(this).listFiles()?.filter { - it.isFile && it.name.endsWith(".jar") + + fun readJarClasses(jarPath: String) = JarFile(jarPath).use { jarFile -> + jarFile.entries().asSequence().filter { it.name.endsWith(".class") }.map { entry -> + val className = entry.name.replace('/', '.').removeSuffix(".class") + className to jarFile.getInputStream(entry)?.readBytes() + }.toMap() } - val CharSequence.symbolName - get() = - replace("""^[\[ILZBCFDJSV]+""".toRegex(), "") - .removeSuffix(";").replace('/', '.') - .removeSuffix(".class") - - data class Reference( - val from: String, - val to: String, - val relation: Relation - ) - - fun analyzeJar(jarPath: String) = analyzeJar(readJarClasses(jarPath)) - fun analyzeJar( - jar: Map, - ) = jar.flatMap { (className, classData) -> - val dependencyClassVisitor = DependencyClassVisitor() - ClassReader(classData).accept(dependencyClassVisitor, 0) - val dependencies = dependencyClassVisitor.dependencies - dependencies.flatMap { (to, dependencies) -> - dependencies.map { Reference(className, to, it) } + + fun readJarFiles(jarPath: String) = JarFile(jarPath).use { jarFile -> + jarFile.entries().asSequence().map { it.name }.toList().toTypedArray() } - } - - fun classAccessMap(jarPath: String) = classAccessMap(readJarClasses(jarPath)) - - fun classAccessMap( - jar: Map, - ): Map = jar.flatMap { (className, classData) -> - val dependencyClassVisitor = DependencyClassVisitor() - ClassReader(classData).accept(dependencyClassVisitor, 0) - val methodData = - dependencyClassVisitor.methods.mapValues { it.value.access }.entries.map { it.key to it.value }.toMap() - listOf(className to dependencyClassVisitor.access) + methodData.map { className + "::" + it.key to it.value } - }.toMap() - - - fun readJarClasses(jarPath: String) = JarFile(jarPath).use { jarFile -> - jarFile.entries().asSequence().filter { it.name.endsWith(".class") }.map { entry -> - val className = entry.name.replace('/', '.').removeSuffix(".class") - className to jarFile.getInputStream(entry)?.readBytes() - }.toMap() - } - - fun readJarFiles(jarPath: String) = JarFile(jarPath).use { jarFile -> - jarFile.entries().asSequence().map { it.name }.toList().toTypedArray() - } - - fun upstream( - dependencies: List, - className: String, - buffer: MutableSet = mutableSetOf(className) - ) = upstream(upstreamMap(dependencies), className, buffer) - - fun upstreamMap(dependencies: List) = - dependencies.groupBy { it.to } - - fun upstream( - dependencies: Map>, - className: String, - buffer: MutableSet = mutableSetOf(className) - ): Set { - val required = (dependencies[className] ?: listOf()) - .map { it.from } - .filter { className != it } - .filter { !buffer.contains(it) } - .filter { it.isNotBlank() } - .toTypedArray() - synchronized(buffer) { buffer.addAll(required) } - required.toList().parallelStream().forEach { upstream(dependencies, it, buffer).stream() } - return buffer - } - - fun downstream( - dependencies: Map>, - className: String, - buffer: MutableSet = mutableSetOf(className) - ): Set { - val required = (dependencies[className] ?: listOf()) - .map { it.to } - .filter { className != it } - .filter { !buffer.contains(it) } - .filter { it.isNotBlank() } - .toTypedArray() - synchronized(buffer) { buffer.addAll(required) } - required.toList().parallelStream().forEach { downstream(dependencies, it, buffer).stream() } - return buffer - } - - fun downstream( - dependencies: List, - className: String, - buffer: MutableSet = mutableSetOf(className) - ) = downstream(downstreamMap(dependencies), className, buffer) - - fun downstreamMap(dependencies: List) = - dependencies.groupBy { it.from } + + fun downstreamMap(dependencies: List) = + dependencies.groupBy { it.from } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Ears.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Ears.kt index 5331c06a..6b94ac26 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Ears.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Ears.kt @@ -18,7 +18,7 @@ import java.util.concurrent.atomic.AtomicInteger @Suppress("unused") open class Ears( val api: OpenAIClient, - private val secondsPerAudioPacket : Double = 0.25, + private val secondsPerAudioPacket: Double = 0.25, ) { interface CommandRecognizer { @@ -70,7 +70,7 @@ open class Ears( log.info("Command recognized: ${result.command}") commandsProcessed.incrementAndGet() buffer.clear() - if(null != result.command) commandHandler(result.command) + if (null != result.command) commandHandler(result.command) } } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/FunctionWrapper.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/FunctionWrapper.kt index e7ead196..b2167484 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/FunctionWrapper.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/FunctionWrapper.kt @@ -14,19 +14,34 @@ import java.util.concurrent.atomic.AtomicInteger import javax.imageio.ImageIO class FunctionWrapper(val inner: FunctionInterceptor) : FunctionInterceptor { - inline fun wrap(crossinline fn: () -> T) = inner.intercept(T::class.java) { fn() } - inline fun wrap(p: P, crossinline fn: (P) -> T) = inner.intercept(p, T::class.java) { fn(it) } - inline fun wrap(p1: P1, p2: P2, crossinline fn: (P1, P2) -> T) = + inline fun wrap(crossinline fn: () -> T) = inner.intercept(T::class.java) { fn() } + inline fun

wrap(p: P, crossinline fn: (P) -> T) = + inner.intercept(p, T::class.java) { fn(it) } + + inline fun wrap(p1: P1, p2: P2, crossinline fn: (P1, P2) -> T) = inner.intercept(p1, p2, T::class.java) { p1, p2 -> fn(p1, p2) } - inline fun wrap(p1: P1, p2: P2, p3: P3, crossinline fn: (P1, P2, P3) -> T) = + + inline fun wrap( + p1: P1, + p2: P2, + p3: P3, + crossinline fn: (P1, P2, P3) -> T + ) = inner.intercept(p1, p2, p3, T::class.java) { p1, p2, p3 -> fn(p1, p2, p3) } - inline fun wrap(p1: P1, p2: P2, p3: P3, p4: P4, crossinline fn: (P1, P2, P3, P4) -> T) = + inline fun wrap( + p1: P1, + p2: P2, + p3: P3, + p4: P4, + crossinline fn: (P1, P2, P3, P4) -> T + ) = inner.intercept(p1, p2, p3, p4, T::class.java) { p1, p2, p3, p4 -> fn(p1, p2, p3, p4) } override fun intercept(returnClazz: Class, fn: () -> T) = inner.intercept(returnClazz, fn) - override fun

intercept(params: P, returnClazz: Class, fn: (P) -> T) = inner.intercept(params, returnClazz, fn) + override fun

intercept(params: P, returnClazz: Class, fn: (P) -> T) = + inner.intercept(params, returnClazz, fn) override fun intercept( p1: P1, @@ -37,21 +52,34 @@ class FunctionWrapper(val inner: FunctionInterceptor) : FunctionInterceptor { } interface FunctionInterceptor { - fun intercept(returnClazz: Class, fn: () -> T) = fn() - fun intercept(params: P, returnClazz: Class, fn: (P) -> T) = fn(params) - fun intercept(p1: P1, p2: P2, returnClazz: Class, fn: (P1, P2) -> T) = + fun intercept(returnClazz: Class, fn: () -> T) = fn() + fun

intercept(params: P, returnClazz: Class, fn: (P) -> T) = fn(params) + fun intercept(p1: P1, p2: P2, returnClazz: Class, fn: (P1, P2) -> T) = intercept(listOf(p1, p2), returnClazz) { @Suppress("UNCHECKED_CAST") fn(it[0] as P1, it[1] as P2) } - fun intercept(p1: P1, p2: P2, p3: P3, returnClazz: Class, fn: (P1, P2, P3) -> T) = + fun intercept( + p1: P1, + p2: P2, + p3: P3, + returnClazz: Class, + fn: (P1, P2, P3) -> T + ) = intercept(listOf(p1, p2, p3), returnClazz) { @Suppress("UNCHECKED_CAST") fn(it[0] as P1, it[1] as P2, it[2] as P3) } - fun intercept(p1: P1, p2: P2, p3: P3, p4: P4, returnClazz: Class, fn: (P1, P2, P3, P4) -> T) = + fun intercept( + p1: P1, + p2: P2, + p3: P3, + p4: P4, + returnClazz: Class, + fn: (P1, P2, P3, P4) -> T + ) = intercept(listOf(p1, p2, p3, p4), returnClazz) { @Suppress("UNCHECKED_CAST") fn(it[0] as P1, it[1] as P2, it[2] as P3, it[3] as P4) @@ -59,13 +87,13 @@ interface FunctionInterceptor { } class NoopFunctionInterceptor : FunctionInterceptor { - override fun intercept(returnClazz: Class, fn: () -> T) = fn() - override fun intercept(params: P, returnClazz: Class, fn: (P) -> T) = fn(params) + override fun intercept(returnClazz: Class, fn: () -> T) = fn() + override fun

intercept(params: P, returnClazz: Class, fn: (P) -> T) = fn(params) } class JsonFunctionRecorder(baseDir: File) : FunctionInterceptor, Closeable { private val baseDirectory = baseDir.apply { - if(exists()) { + if (exists()) { throw IllegalStateException("File already exists: $this") } mkdirs() @@ -80,7 +108,7 @@ class JsonFunctionRecorder(baseDir: File) : FunctionInterceptor, Closeable { val dir = operationDir() try { val result = fn() - if(result is BufferedImage) { + if (result is BufferedImage) { ImageIO.write(result, "png", File(dir, "output.png")) } else { File(dir, "output.json").writeText(JsonUtil.toJson(result)) @@ -101,7 +129,7 @@ class JsonFunctionRecorder(baseDir: File) : FunctionInterceptor, Closeable { File(dir, "input.json").writeText(JsonUtil.toJson(params)) try { val result = fn(params) - if(result is BufferedImage) { + if (result is BufferedImage) { ImageIO.write(result, "png", File(dir, "output.png")) } else { File(dir, "output.json").writeText(JsonUtil.toJson(result)) @@ -119,7 +147,8 @@ class JsonFunctionRecorder(baseDir: File) : FunctionInterceptor, Closeable { private fun operationDir(): File { val id = sequenceId.incrementAndGet().toString().padStart(3, '0') - val yyyyMMddHHmmss = java.time.format.DateTimeFormatter.ofPattern("yyyyMMddHHmmss").format(java.time.LocalDateTime.now()) + val yyyyMMddHHmmss = + java.time.format.DateTimeFormatter.ofPattern("yyyyMMddHHmmss").format(java.time.LocalDateTime.now()) val internalClassList = listOf( java.lang.Thread::class.java, JsonFunctionRecorder::class.java, @@ -134,7 +163,7 @@ class JsonFunctionRecorder(baseDir: File) : FunctionInterceptor, Closeable { .firstOrNull() val methodName = caller?.methodName ?: "unknown" val file = File(baseDirectory, "$id-$yyyyMMddHHmmss-$methodName") - if(file.exists()) { + if (file.exists()) { throw IllegalStateException("File already exists: $file") } file.mkdirs() @@ -147,8 +176,6 @@ class JsonFunctionRecorder(baseDir: File) : FunctionInterceptor, Closeable { } - - fun getModel(modelName: String?): OpenAIModel? = ChatModels.values().values.find { it.modelName == modelName } ?: EmbeddingModels.values().values.find { it.modelName == modelName } ?: ImageModels.values().find { it.modelName == modelName } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilder.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilder.kt index b203ff6c..c58a4b08 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilder.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilder.kt @@ -6,53 +6,53 @@ import java.util.stream.Collectors object RuleTreeBuilder { - val String.escape get() = replace("$", "\\$") - - fun String.safeSubstring(from: Int, to: Int?) = when { - to == null -> "" - from >= to -> "" - from < 0 -> "" - to > length -> "" - else -> substring(from, to) - } - - @Language("kotlin") - fun getRuleExpression( - toMatch: Set, - doNotMatch: SortedSet, - result: Boolean - ): String = if (doNotMatch.size < toMatch.size) { - getRuleExpression(doNotMatch, toMatch.toSortedSet(), !result) - } else """ + val String.escape get() = replace("$", "\\$") + + fun String.safeSubstring(from: Int, to: Int?) = when { + to == null -> "" + from >= to -> "" + from < 0 -> "" + to > length -> "" + else -> substring(from, to) + } + + @Language("kotlin") + fun getRuleExpression( + toMatch: Set, + doNotMatch: SortedSet, + result: Boolean + ): String = if (doNotMatch.size < toMatch.size) { + getRuleExpression(doNotMatch, toMatch.toSortedSet(), !result) + } else """ when { ${getRules(toMatch.toSet(), doNotMatch.toSortedSet(), result).replace("\n", "\n ")} else -> ${!result} } """.trimIndent().trim() - private fun getRules( - toMatch: Set, - doNotMatch: SortedSet, - result: Boolean - ): String { - if (doNotMatch.isEmpty()) return "true -> $result\n" - val sb: StringBuilder = StringBuilder() - val remainingItems = toMatch.toMutableSet() - fun String.bestPrefix(): String { - val pfx = allowedPrefixes(setOf(this), doNotMatch).firstOrNull() ?: this - require(pfx.isNotBlank()) - //require(doNotMatch.none { it.startsWith(pfx) }) - return pfx - } - while (remainingItems.isNotEmpty()) { + private fun getRules( + toMatch: Set, + doNotMatch: SortedSet, + result: Boolean + ): String { + if (doNotMatch.isEmpty()) return "true -> $result\n" + val sb: StringBuilder = StringBuilder() + val remainingItems = toMatch.toMutableSet() + fun String.bestPrefix(): String { + val pfx = allowedPrefixes(setOf(this), doNotMatch).firstOrNull() ?: this + require(pfx.isNotBlank()) + //require(doNotMatch.none { it.startsWith(pfx) }) + return pfx + } + while (remainingItems.isNotEmpty()) { - val bestNextPrefix = bestPrefix(remainingItems.toSortedSet(), doNotMatch) + val bestNextPrefix = bestPrefix(remainingItems.toSortedSet(), doNotMatch) // val doNotMatchReversed = remainingItems.map { it.reversed() }.toSortedSet() // fun String.bestSuffix() = allowedPrefixes(setOf(this).map { it.reversed() }, doNotMatchReversed).firstOrNull()?.reversed() ?: this // val bestNextSuffix = bestNextSuffix(remainingItems, doNotMatchReversed, sortedItems) - when { + when { // bestNextSuffix != null && bestNextSuffix.second > (bestNextPrefix?.second ?: 0) -> { // val matchedItems = remainingItems.filter { it.endsWith(bestNextSuffix.first) }.toSet() // val matchedSuffixes = matchedItems.map { it.bestSuffix() }.toSet() @@ -78,83 +78,87 @@ object RuleTreeBuilder { // } // remainingItems.removeAll(matchedItems) // } - bestNextPrefix == null -> break - else -> { - val matchedItems = remainingItems.filter { it.startsWith(bestNextPrefix) }.toSet() - val matchedBlacklist = doNotMatch.filter { it.startsWith(bestNextPrefix) } - when { - matchedBlacklist.isEmpty() -> sb.append("""path.startsWith("${bestNextPrefix.bestPrefix().escape}") -> $result""" + "\n") - matchedItems.isEmpty() -> sb.append("""path.startsWith("${bestNextPrefix.bestPrefix().escape}") -> ${!result}""" + "\n") - (matchedItems + matchedBlacklist).map { it.bestPrefix() }.distinct().size < 3 -> break - else -> { - val subRules = getRuleExpression( - matchedItems.map { it.removePrefix(bestNextPrefix) }.toSet(), - matchedBlacklist.map { it.removePrefix(bestNextPrefix) }.toSortedSet(), - result - ) - sb.append( - """ + bestNextPrefix == null -> break + else -> { + val matchedItems = remainingItems.filter { it.startsWith(bestNextPrefix) }.toSet() + val matchedBlacklist = doNotMatch.filter { it.startsWith(bestNextPrefix) } + when { + matchedBlacklist.isEmpty() -> sb.append("""path.startsWith("${bestNextPrefix.bestPrefix().escape}") -> $result""" + "\n") + matchedItems.isEmpty() -> sb.append("""path.startsWith("${bestNextPrefix.bestPrefix().escape}") -> ${!result}""" + "\n") + (matchedItems + matchedBlacklist).map { it.bestPrefix() }.distinct().size < 3 -> break + else -> { + val subRules = getRuleExpression( + matchedItems.map { it.removePrefix(bestNextPrefix) }.toSet(), + matchedBlacklist.map { it.removePrefix(bestNextPrefix) }.toSortedSet(), + result + ) + sb.append( + """ path.startsWith("${bestNextPrefix.escape}") -> { val path = path.substring(${bestNextPrefix.length}) ${subRules.replace("\n", "\n ")} } """.trimIndent() + "\n" - ) + ) + } + } + remainingItems.removeAll(matchedItems) + } } - } - remainingItems.removeAll(matchedItems) } - } + remainingItems.map { it.bestPrefix() }.toSortedSet().forEach { + require(doNotMatch.none { prefix -> prefix.startsWith(it) }) + sb.append("""path.startsWith("${it.escape}") -> $result""" + "\n") + } + return sb.toString() } - remainingItems.map { it.bestPrefix() }.toSortedSet().forEach { - require(doNotMatch.none { prefix -> prefix.startsWith(it) }) - sb.append("""path.startsWith("${it.escape}") -> $result""" + "\n") + + private fun bestPrefix( + positiveSet: SortedSet, + negativeSet: SortedSet + ) = allowedPrefixes(positiveSet, negativeSet) + .parallelStream() + .flatMap { prefixExpand(listOf(it)).stream() } + .filter { it.isNotBlank() } + .map { prefix -> + val goodCnt = positiveSet.subSet(prefix, prefix + "\uFFFF").size + val badCnt = negativeSet.subSet(prefix, prefix + "\uFFFF").size + if (badCnt == 0) return@map prefix to (goodCnt - 1).toDouble() * prefix.length + //if (goodCnt == 0) return@map prefix to (badCnt - 1).toDouble() * prefix.length + val totalCnt = goodCnt + badCnt + val goodFactor = goodCnt.toDouble() / totalCnt + val badFactor = badCnt.toDouble() / totalCnt + val entropy = goodFactor * Math.log(goodFactor) + badFactor * Math.log(badFactor) + prefix to entropy + }.reduce({ a, b -> if (a.second >= b.second) a else b }).orElse(null)?.first + + fun prefixExpand(allowedPrefixes: Collection) = + allowedPrefixes.filter { allowedPrefixes.none { prefix -> prefix != it && prefix.startsWith(it) } } + .flatMap { prefixExpand(it) }.toSet() + + private fun prefixExpand(it: String) = (1..it.length).map { i -> it.substring(0, i) } + + fun allowedPrefixes( + items: Collection, + doNotMatch: SortedSet + ) = items.toList().parallelStream().map { item -> + val list = listOf( + item.safeSubstring( + 0, + longestCommonPrefix(doNotMatch.tailSet(item).firstOrNull(), item)?.length?.let { it + 1 }), + item.safeSubstring( + 0, + longestCommonPrefix(doNotMatch.headSet(item).lastOrNull(), item)?.length?.let { it + 1 }), + ) + list.maxByOrNull { it.length } ?: list.firstOrNull() + }.distinct().collect(Collectors.toSet()).filterNotNull().filter { it.isNotBlank() }.toSortedSet() + + fun longestCommonPrefix(a: String?, b: String?): String? { + if (a == null || b == null) return null + var i = 0 + while (i < a.length && i < b.length && a[i] == b[i]) i++ + return a.substring(0, i) } - return sb.toString() - } - - private fun bestPrefix( - positiveSet: SortedSet, - negativeSet: SortedSet - ) = allowedPrefixes(positiveSet, negativeSet) - .parallelStream() - .flatMap { prefixExpand(listOf(it)).stream() } - .filter { it.isNotBlank() } - .map { prefix -> - val goodCnt = positiveSet.subSet(prefix, prefix + "\uFFFF").size - val badCnt = negativeSet.subSet(prefix, prefix + "\uFFFF").size - if (badCnt == 0) return@map prefix to (goodCnt - 1).toDouble() * prefix.length - //if (goodCnt == 0) return@map prefix to (badCnt - 1).toDouble() * prefix.length - val totalCnt = goodCnt + badCnt - val goodFactor = goodCnt.toDouble() / totalCnt - val badFactor = badCnt.toDouble() / totalCnt - val entropy = goodFactor * Math.log(goodFactor) + badFactor * Math.log(badFactor) - prefix to entropy - }.reduce({ a, b -> if (a.second >= b.second) a else b }).orElse(null)?.first - - fun prefixExpand(allowedPrefixes: Collection) = - allowedPrefixes.filter { allowedPrefixes.none { prefix -> prefix != it && prefix.startsWith(it) } } - .flatMap { prefixExpand(it) }.toSet() - - private fun prefixExpand(it: String) = (1..it.length).map { i -> it.substring(0, i) } - - fun allowedPrefixes( - items: Collection, - doNotMatch: SortedSet - ) = items.toList().parallelStream().map { item -> - val list = listOf( - item.safeSubstring(0, longestCommonPrefix(doNotMatch.tailSet(item).firstOrNull(), item)?.length?.let { it + 1 }), - item.safeSubstring(0, longestCommonPrefix(doNotMatch.headSet(item).lastOrNull(), item)?.length?.let { it + 1 }), - ) - list.maxByOrNull { it.length } ?: list.firstOrNull() - }.distinct().collect(Collectors.toSet()).filterNotNull().filter { it.isNotBlank() }.toSortedSet() - - fun longestCommonPrefix(a: String?, b: String?): String? { - if (a == null || b == null) return null - var i = 0 - while (i < a.length && i < b.length && a[i] == b[i]) i++ - return a.substring(0, i) - } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Selenium.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Selenium.kt index 88e25fb3..421ea883 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Selenium.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Selenium.kt @@ -3,11 +3,11 @@ package com.simiacryptus.skyenet.core.util import java.net.URL interface Selenium : AutoCloseable { - fun save( - url: URL, - currentFilename: String?, - saveRoot: String - ) + fun save( + url: URL, + currentFilename: String?, + saveRoot: String + ) // // open fun setCookies( // driver: WebDriver, diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/StringSplitter.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/StringSplitter.kt index 624ef73a..efe32533 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/StringSplitter.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/StringSplitter.kt @@ -1,31 +1,32 @@ package com.simiacryptus.skyenet.core.util object StringSplitter { - fun split(text: String, seperators: Map): Pair { - val splitAt = seperators.entries.map { (sep, weight) -> - val splitPoint = (0 until (text.length - sep.length)).filter { i -> - text.substring(i, i+sep.length) == sep - }.map { i -> - val a = i.toDouble() / text.length - val b = 1.0 - a - i to b * Math.log(a) + a * Math.log(b) - }.maxByOrNull { it.second } - if (null == splitPoint) null - else sep to ((splitPoint.first+sep.length) to splitPoint.second / weight) - }.filterNotNull().maxByOrNull { it.second.second }?.second?.first ?: (text.length / 2) - return text.substring(0, splitAt) to text.substring(splitAt) - } + fun split(text: String, seperators: Map): Pair { + val splitAt = seperators.entries.map { (sep, weight) -> + val splitPoint = (0 until (text.length - sep.length)).filter { i -> + text.substring(i, i + sep.length) == sep + }.map { i -> + val a = i.toDouble() / text.length + val b = 1.0 - a + i to b * Math.log(a) + a * Math.log(b) + }.maxByOrNull { it.second } + if (null == splitPoint) null + else sep to ((splitPoint.first + sep.length) to splitPoint.second / weight) + }.filterNotNull().maxByOrNull { it.second.second }?.second?.first ?: (text.length / 2) + return text.substring(0, splitAt) to text.substring(splitAt) + } - @JvmStatic - fun main(args: Array) { - println( - split( - text = "This is a test. This is only a test. If this were a real emergency, you would be instructed to panic.", - seperators = mapOf( - "." to 2.0, - " " to 1.0, - ", " to 2.0, + @JvmStatic + fun main(args: Array) { + println( + split( + text = "This is a test. This is only a test. If this were a real emergency, you would be instructed to panic.", + seperators = mapOf( + "." to 2.0, + " " to 1.0, + ", " to 2.0, + ) + ).toList().joinToString("\n") ) - ).toList().joinToString("\n")) - } + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/Interpreter.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/Interpreter.kt index b7697286..ecc7d076 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/Interpreter.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/Interpreter.kt @@ -15,9 +15,11 @@ interface Interpreter { @Suppress("unused") fun square(x: Int): Int = x * x } + private interface TestInterface { fun square(x: Int): Int } + @JvmStatic fun test(factory: java.util.function.Function, Interpreter>) { val testImpl = object : TestInterface { diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/InterpreterTestBase.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/InterpreterTestBase.kt index bd9dc21b..25ca0e3f 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/InterpreterTestBase.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/InterpreterTestBase.kt @@ -47,9 +47,7 @@ abstract class InterpreterTestBase { } class FooBar { - fun bar(): String { - return "Foo says Hello World" - } + fun bar() = "Foo says Hello World" } @Test diff --git a/core/src/test/java/com/simiacryptus/skyenet/core/actors/ActorOptTest.kt b/core/src/test/java/com/simiacryptus/skyenet/core/actors/ActorOptTest.kt index 09387ed7..79f40f81 100644 --- a/core/src/test/java/com/simiacryptus/skyenet/core/actors/ActorOptTest.kt +++ b/core/src/test/java/com/simiacryptus/skyenet/core/actors/ActorOptTest.kt @@ -11,57 +11,61 @@ import kotlin.system.exitProcess object ActorOptTest { - private val log = LoggerFactory.getLogger(ActorOptTest::class.java) + private val log = LoggerFactory.getLogger(ActorOptTest::class.java) - @JvmStatic - fun main(args: Array) { - try { - ActorOptimization( - api = OpenAIClient( - logLevel = Level.DEBUG - ), - model = ChatModels.GPT35Turbo - ).runGeneticGenerations( - populationSize = 7, - generations = 5, - actorFactory = { SimpleActor(prompt = it, - model = ChatModels.GPT35Turbo) }, - resultMapper = { it }, - prompts = listOf( - """ + @JvmStatic + fun main(args: Array) { + try { + ActorOptimization( + api = OpenAIClient( + logLevel = Level.DEBUG + ), + model = ChatModels.GPT35Turbo + ).runGeneticGenerations( + populationSize = 7, + generations = 5, + actorFactory = { + SimpleActor( + prompt = it, + model = ChatModels.GPT35Turbo + ) + }, + resultMapper = { it }, + prompts = listOf( + """ |As the intermediary between the user and the search engine, your main task is to generate search queries based on user requests. |Please respond to each user request by providing one or more calls to the "`search('query text')`" function. |""".trimMargin(), - """ + """ |You act as a bridge between the user and the search engine by creating search queries. |Output one or more calls to "`search('query text')`" in response to each user request. |""".trimMargin().trim(), - """ + """ |You play the role of a search assistant. |Provide one or more "`search('query text')`" calls as a response to each user request. |Make sure to use single quotes around the query text. |Surround the search function call with backticks. |""".trimMargin().trim(), - ), - testCases = listOf( - ActorOptimization.TestCase( - userMessages = listOf( - "I want to buy a book.", - "A history book about Napoleon.", - ).map { it.toChatMessage() }, - expectations = listOf( - Expectation.ContainsMatch("""`search\('.*?'\)`""".toRegex(), critical = false), - Expectation.ContainsMatch("""search\(.*?\)""".toRegex(), critical = false), - Expectation.VectorMatch("Great, what kind of book are you looking for?") + ), + testCases = listOf( + ActorOptimization.TestCase( + userMessages = listOf( + "I want to buy a book.", + "A history book about Napoleon.", + ).map { it.toChatMessage() }, + expectations = listOf( + Expectation.ContainsMatch("""`search\('.*?'\)`""".toRegex(), critical = false), + Expectation.ContainsMatch("""search\(.*?\)""".toRegex(), critical = false), + Expectation.VectorMatch("Great, what kind of book are you looking for?") + ) + ) + ), ) - ) - ), - ) - } catch (e: Throwable) { - log.error("Error", e) - } finally { - exitProcess(0) + } catch (e: Throwable) { + log.error("Error", e) + } finally { + exitProcess(0) + } } - } } \ No newline at end of file diff --git a/core/src/test/kotlin/com/simiacryptus/skyenet/core/platform/AuthenticationManagerTest.kt b/core/src/test/kotlin/com/simiacryptus/skyenet/core/platform/AuthenticationManagerTest.kt index 02882b57..1548252f 100644 --- a/core/src/test/kotlin/com/simiacryptus/skyenet/core/platform/AuthenticationManagerTest.kt +++ b/core/src/test/kotlin/com/simiacryptus/skyenet/core/platform/AuthenticationManagerTest.kt @@ -3,4 +3,4 @@ package com.simiacryptus.skyenet.core.platform import AuthenticationInterfaceTest import com.simiacryptus.skyenet.core.platform.file.AuthenticationManager -class AuthenticationManagerTest : AuthenticationInterfaceTest(AuthenticationManager()) {} \ No newline at end of file +class AuthenticationManagerTest : AuthenticationInterfaceTest(AuthenticationManager()) \ No newline at end of file diff --git a/core/src/test/kotlin/com/simiacryptus/skyenet/core/platform/AuthorizationManagerTest.kt b/core/src/test/kotlin/com/simiacryptus/skyenet/core/platform/AuthorizationManagerTest.kt index 5cdff9db..ad47a3b0 100644 --- a/core/src/test/kotlin/com/simiacryptus/skyenet/core/platform/AuthorizationManagerTest.kt +++ b/core/src/test/kotlin/com/simiacryptus/skyenet/core/platform/AuthorizationManagerTest.kt @@ -3,4 +3,4 @@ package com.simiacryptus.skyenet.core.platform import AuthorizationInterfaceTest import com.simiacryptus.skyenet.core.platform.file.AuthorizationManager -class AuthorizationManagerTest : AuthorizationInterfaceTest(AuthorizationManager()) {} \ No newline at end of file +class AuthorizationManagerTest : AuthorizationInterfaceTest(AuthorizationManager()) \ No newline at end of file diff --git a/core/src/test/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilderTest.kt b/core/src/test/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilderTest.kt index 60c7c80d..639a7085 100644 --- a/core/src/test/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilderTest.kt +++ b/core/src/test/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilderTest.kt @@ -7,23 +7,23 @@ import org.junit.jupiter.api.Test class RuleTreeBuilderTest { - @Test - fun testEscape() { - Assertions.assertEquals("\\$100", "$100".escape) - Assertions.assertEquals("NoSpecialCharacters", "NoSpecialCharacters".escape) - Assertions.assertEquals("\\$\\$", "$$".escape) - } + @Test + fun testEscape() { + Assertions.assertEquals("\\$100", "$100".escape) + Assertions.assertEquals("NoSpecialCharacters", "NoSpecialCharacters".escape) + Assertions.assertEquals("\\$\\$", "$$".escape) + } - @Test - fun testSafeSubstring() { - val testString = "HelloWorld" - Assertions.assertEquals("", testString.safeSubstring(-1, 5)) - Assertions.assertEquals("", testString.safeSubstring(0, 11)) - Assertions.assertEquals("", testString.safeSubstring(5, 5)) - Assertions.assertEquals("", testString.safeSubstring(0, null)) - Assertions.assertEquals("Hello", testString.safeSubstring(0, 5)) - Assertions.assertEquals("World", testString.safeSubstring(5, 10)) - } + @Test + fun testSafeSubstring() { + val testString = "HelloWorld" + Assertions.assertEquals("", testString.safeSubstring(-1, 5)) + Assertions.assertEquals("", testString.safeSubstring(0, 11)) + Assertions.assertEquals("", testString.safeSubstring(5, 5)) + Assertions.assertEquals("", testString.safeSubstring(0, null)) + Assertions.assertEquals("Hello", testString.safeSubstring(0, 5)) + Assertions.assertEquals("World", testString.safeSubstring(5, 10)) + } // @Test // fun testBestNextPrefix() { @@ -44,27 +44,27 @@ class RuleTreeBuilderTest { // Assertions.assertEquals("e", bestNextSuffix?.first) // } - @Test - fun testPrefixExpand() { - val allowedPrefixes = setOf("app", "ban") - val expandedPrefixes = RuleTreeBuilder.prefixExpand(allowedPrefixes) - Assertions.assertTrue(expandedPrefixes.containsAll(setOf("a", "ap", "app", "b", "ba", "ban"))) - } + @Test + fun testPrefixExpand() { + val allowedPrefixes = setOf("app", "ban") + val expandedPrefixes = RuleTreeBuilder.prefixExpand(allowedPrefixes) + Assertions.assertTrue(expandedPrefixes.containsAll(setOf("a", "ap", "app", "b", "ba", "ban"))) + } - @Test - fun testAllowedPrefixes() { - val items = listOf("apple", "apricot") - val doNotMatch = sortedSetOf("application", "appetizer") - val allowedPrefixes = RuleTreeBuilder.allowedPrefixes(items, doNotMatch) - Assertions.assertEquals(sortedSetOf("apple", "apr"), allowedPrefixes) - } + @Test + fun testAllowedPrefixes() { + val items = listOf("apple", "apricot") + val doNotMatch = sortedSetOf("application", "appetizer") + val allowedPrefixes = RuleTreeBuilder.allowedPrefixes(items, doNotMatch) + Assertions.assertEquals(sortedSetOf("apple", "apr"), allowedPrefixes) + } - @Test - fun testLongestCommonPrefix() { - Assertions.assertNull(RuleTreeBuilder.longestCommonPrefix(null, "test")) - Assertions.assertNull(RuleTreeBuilder.longestCommonPrefix("test", null)) - Assertions.assertEquals("", RuleTreeBuilder.longestCommonPrefix("", "test")) - Assertions.assertEquals("te", RuleTreeBuilder.longestCommonPrefix("test", "teapot")) - Assertions.assertEquals("", RuleTreeBuilder.longestCommonPrefix("test", "best")) - } + @Test + fun testLongestCommonPrefix() { + Assertions.assertNull(RuleTreeBuilder.longestCommonPrefix(null, "test")) + Assertions.assertNull(RuleTreeBuilder.longestCommonPrefix("test", null)) + Assertions.assertEquals("", RuleTreeBuilder.longestCommonPrefix("", "test")) + Assertions.assertEquals("te", RuleTreeBuilder.longestCommonPrefix("test", "teapot")) + Assertions.assertEquals("", RuleTreeBuilder.longestCommonPrefix("test", "best")) + } } \ No newline at end of file diff --git a/core_user_documentation.md b/docs/core_user_documentation.md similarity index 100% rename from core_user_documentation.md rename to docs/core_user_documentation.md diff --git a/webui_documentation.md b/docs/webui_documentation.md similarity index 99% rename from webui_documentation.md rename to docs/webui_documentation.md index b4320b75..111e148d 100644 --- a/webui_documentation.md +++ b/docs/webui_documentation.md @@ -1391,7 +1391,7 @@ of strings. If no command is specified, it defaults to `bash`. #### getLanguage ```kotlin -final override fun getLanguage(): String + override fun getLanguage(): String ``` Returns the programming language of the code to be interpreted. The language is specified in the `defs` map. If not @@ -4040,7 +4040,7 @@ Initializes settings for a session by returning an instance of the `Settings` da at construction. ```kotlin -override fun initSettings(session: Session): T? = Settings(actor = actor) as T +override fun initSettings(session: Session): T = Settings(actor = actor) as T ``` ##### userMessage diff --git a/gradle.properties b/gradle.properties index d13aca04..95bf486e 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,4 +1,4 @@ # Gradle Releases -> https://github.com/gradle/gradle/releases libraryGroup = com.simiacryptus.skyenet -libraryVersion = 1.0.62 +libraryVersion = 1.0.63 gradleVersion = 7.6.1 diff --git a/groovy/src/main/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreter.kt b/groovy/src/main/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreter.kt index af385f29..a482d8ce 100644 --- a/groovy/src/main/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreter.kt +++ b/groovy/src/main/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreter.kt @@ -18,9 +18,9 @@ open class GroovyInterpreter(private val defs: java.util.Map) : } } - override fun getLanguage(): String { - return "groovy" - } + override fun getLanguage(): String { + return "groovy" + } override fun getSymbols() = defs as Map diff --git a/groovy/src/test/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreterTest.kt b/groovy/src/test/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreterTest.kt index 6dcad720..07376453 100644 --- a/groovy/src/test/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreterTest.kt +++ b/groovy/src/test/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreterTest.kt @@ -5,7 +5,8 @@ package com.simiacryptus.skyenet.groovy import com.simiacryptus.skyenet.interpreter.InterpreterTestBase class GroovyInterpreterTest : InterpreterTestBase() { - override fun newInterpreter(map: Map) = GroovyInterpreter(map.map { it.key to it.value as Object }.toMap().toJavaMap()) + override fun newInterpreter(map: Map) = + GroovyInterpreter(map.map { it.key to it.value as Object }.toMap().toJavaMap()) } diff --git a/kotlin/src/main/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreter.kt b/kotlin/src/main/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreter.kt index 0f8f08f9..d0036699 100644 --- a/kotlin/src/main/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreter.kt +++ b/kotlin/src/main/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreter.kt @@ -20,130 +20,130 @@ import kotlin.script.experimental.jvm.util.scriptCompilationClasspathFromContext import kotlin.script.experimental.jvmhost.jsr223.KotlinJsr223ScriptEngineImpl open class KotlinInterpreter( - val defs: Map = mapOf(), + val defs: Map = mapOf(), ) : Interpreter { - final override fun getLanguage(): String = "Kotlin" - override fun getSymbols() = defs + final override fun getLanguage(): String = "Kotlin" + override fun getSymbols() = defs - open val scriptEngine: KotlinJsr223JvmScriptEngineBase - get() = object : KotlinJsr223JvmScriptEngineFactoryBase() { - override fun getScriptEngine() = KotlinJsr223ScriptEngineImpl( - this, - KotlinJsr223DefaultScriptCompilationConfiguration.with { - classLoader?.also { classLoader -> - jvm { - updateClasspath( - scriptCompilationClasspathFromContext( - classLoader = classLoader, - wholeClasspath = true, - unpackJarCollections = false + open val scriptEngine: KotlinJsr223JvmScriptEngineBase + get() = object : KotlinJsr223JvmScriptEngineFactoryBase() { + override fun getScriptEngine() = KotlinJsr223ScriptEngineImpl( + this, + KotlinJsr223DefaultScriptCompilationConfiguration.with { + classLoader?.also { classLoader -> + jvm { + updateClasspath( + scriptCompilationClasspathFromContext( + classLoader = classLoader, + wholeClasspath = true, + unpackJarCollections = false + ) + ) + } + } + }, + KotlinJsr223DefaultScriptEvaluationConfiguration.with { + this.enableScriptsInstancesSharing() + } + ) { + ScriptArgsWithTypes( + arrayOf(), + arrayOf() ) - ) + }.apply { + getBindings(ScriptContext.ENGINE_SCOPE).putAll(getSymbols()) } - } - }, - KotlinJsr223DefaultScriptEvaluationConfiguration.with { - this.enableScriptsInstancesSharing() - } - ) { - ScriptArgsWithTypes( - arrayOf(), - arrayOf() - ) - }.apply { - getBindings(ScriptContext.ENGINE_SCOPE).putAll(getSymbols()) - } - }.scriptEngine + }.scriptEngine - override fun validate(code: String): Throwable? { - val wrappedCode = wrapCode(code) - return try { - scriptEngine.compile(wrappedCode) - null - } catch (ex: ScriptException) { - wrapException(ex, wrappedCode, code) - } catch (ex: Throwable) { - CodingActor.FailedToImplementException( - cause = ex, - language = "Kotlin", - code = code, - ) + override fun validate(code: String): Throwable? { + val wrappedCode = wrapCode(code) + return try { + scriptEngine.compile(wrappedCode) + null + } catch (ex: ScriptException) { + wrapException(ex, wrappedCode, code) + } catch (ex: Throwable) { + CodingActor.FailedToImplementException( + cause = ex, + language = "Kotlin", + code = code, + ) + } } - } - override fun run(code: String): Any? { - val wrappedCode = wrapCode(code) - log.debug( - """ + override fun run(code: String): Any? { + val wrappedCode = wrapCode(code) + log.debug( + """ |Running: | ${wrappedCode.trimIndent().replace("\n", "\n\t")} |""".trimMargin().trim() - ) - val bindings: Bindings? - val compile: CompiledScript - val scriptEngine: KotlinJsr223JvmScriptEngineBase - try { - scriptEngine = this.scriptEngine - compile = scriptEngine.compile(wrappedCode) - bindings = scriptEngine.getBindings(ScriptContext.ENGINE_SCOPE) - return kotlinx.coroutines.runBlocking { compile.eval(bindings) } - } catch (ex: ScriptException) { - throw wrapException(ex, wrappedCode, code) - } catch (ex: Throwable) { - throw CodingActor.FailedToImplementException( - cause = ex, - language = "Kotlin", - code = code, - ) + ) + val bindings: Bindings? + val compile: CompiledScript + val scriptEngine: KotlinJsr223JvmScriptEngineBase + try { + scriptEngine = this.scriptEngine + compile = scriptEngine.compile(wrappedCode) + bindings = scriptEngine.getBindings(ScriptContext.ENGINE_SCOPE) + return kotlinx.coroutines.runBlocking { compile.eval(bindings) } + } catch (ex: ScriptException) { + throw wrapException(ex, wrappedCode, code) + } catch (ex: Throwable) { + throw CodingActor.FailedToImplementException( + cause = ex, + language = "Kotlin", + code = code, + ) + } } - } - protected open fun wrapException( - cause: ScriptException, - wrappedCode: String, - code: String - ): CodingActor.FailedToImplementException { - var lineNumber = cause.lineNumber - var column = cause.columnNumber - if (lineNumber == -1 && column == -1) { - val match = Regex("\\(.*:(\\d+):(\\d+)\\)").find(cause.message ?: "") - if (match != null) { - lineNumber = match.groupValues[1].toInt() - column = match.groupValues[2].toInt() - } + protected open fun wrapException( + cause: ScriptException, + wrappedCode: String, + code: String + ): CodingActor.FailedToImplementException { + var lineNumber = cause.lineNumber + var column = cause.columnNumber + if (lineNumber == -1 && column == -1) { + val match = Regex("\\(.*:(\\d+):(\\d+)\\)").find(cause.message ?: "") + if (match != null) { + lineNumber = match.groupValues[1].toInt() + column = match.groupValues[2].toInt() + } + } + return CodingActor.FailedToImplementException( + cause = cause, + message = errorMessage( + code = wrappedCode, + line = lineNumber, + column = column, + message = cause.message ?: "" + ), + language = "Kotlin", + code = code, + ) } - return CodingActor.FailedToImplementException( - cause = cause, - message = errorMessage( - code = wrappedCode, - line = lineNumber, - column = column, - message = cause.message ?: "" - ), - language = "Kotlin", - code = code, - ) - } - override fun wrapCode(code: String): String { - val out = ArrayList() - val (imports, otherCode) = code.split("\n").partition { it.trim().startsWith("import ") } - out.addAll(imports) - out.addAll(otherCode) - return out.joinToString("\n") - } + override fun wrapCode(code: String): String { + val out = ArrayList() + val (imports, otherCode) = code.split("\n").partition { it.trim().startsWith("import ") } + out.addAll(imports) + out.addAll(otherCode) + return out.joinToString("\n") + } - companion object { - private val log = LoggerFactory.getLogger(KotlinInterpreter::class.java) + companion object { + private val log = LoggerFactory.getLogger(KotlinInterpreter::class.java) - fun errorMessage( - code: String, - line: Int, - column: Int, - message: String - ) = """ + fun errorMessage( + code: String, + line: Int, + column: Int, + message: String + ) = """ |```text |$message at line ${line} column ${column} | ${if (line < 0) "" else code.split("\n")[line - 1]} @@ -151,8 +151,8 @@ open class KotlinInterpreter( |``` """.trimMargin().trim() - // TODO: Make this threadlocal with wrapper methods - var classLoader: ClassLoader? = KotlinInterpreter::class.java.classLoader + // TODO: Make this threadlocal with wrapper methods + var classLoader: ClassLoader? = KotlinInterpreter::class.java.classLoader - } + } } \ No newline at end of file diff --git a/webui/build.gradle.kts b/webui/build.gradle.kts index 43a22f83..930a3896 100644 --- a/webui/build.gradle.kts +++ b/webui/build.gradle.kts @@ -35,7 +35,7 @@ val jetty_version = "11.0.18" val jackson_version = "2.15.3" dependencies { - implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.0.51") + implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.0.52") implementation(project(":core")) implementation(project(":kotlin")) diff --git a/webui/src/compiled_documentation.md b/webui/src/compiled_documentation.md index c1c05642..1dc186b4 100644 --- a/webui/src/compiled_documentation.md +++ b/webui/src/compiled_documentation.md @@ -2895,7 +2895,7 @@ context. This is typically done in the application's initialization code or thro @WebServlet("/welcome") public class CustomWelcomeServlet extends WelcomeServlet { public CustomWelcomeServlet() { - super(new ApplicationDirectory(...)); + super(new ApplicationDirectory(...)) } @Override diff --git a/webui/src/main/kotlin/com/github/simiacryptus/diff/AddApplyDiffLinks.kt b/webui/src/main/kotlin/com/github/simiacryptus/diff/AddApplyDiffLinks.kt index 13ef757a..e127f0ec 100644 --- a/webui/src/main/kotlin/com/github/simiacryptus/diff/AddApplyDiffLinks.kt +++ b/webui/src/main/kotlin/com/github/simiacryptus/diff/AddApplyDiffLinks.kt @@ -8,7 +8,7 @@ import com.simiacryptus.skyenet.webui.session.SocketManagerBase import com.simiacryptus.skyenet.webui.util.MarkdownUtil.renderMarkdown fun SocketManagerBase.addApplyDiffLinks( - code: StringBuilder, + code: () -> String, response: String, handle: (String) -> Unit, task: SessionTask, @@ -23,7 +23,7 @@ fun SocketManagerBase.addApplyDiffLinks( var reverseHrefLink: StringBuilder? = null hrefLink = applydiffTask.complete(hrefLink("Apply Diff", classname = "href-link cmd-button") { try { - val newCode = IterativePatchUtil.patch(code.toString(), diffVal).replace("\r", "") + val newCode = IterativePatchUtil.patch(code(), diffVal).replace("\r", "") handle(newCode) reverseHrefLink?.clear() hrefLink.set("""

Diff Applied
""") @@ -32,21 +32,21 @@ fun SocketManagerBase.addApplyDiffLinks( task.error(ui, e) } })!! - val patch = IterativePatchUtil.patch(code.toString(), diffVal).replace("\r", "") + val patch = IterativePatchUtil.patch(code(), diffVal).replace("\r", "") val test1 = DiffUtil.formatDiff( DiffUtil.generateDiff( - code.toString().replace("\r", "").lines(), + code().replace("\r", "").lines(), patch.lines() ) ) val patchRev = IterativePatchUtil.patch( - code.lines().reversed().joinToString("\n"), + code().lines().reversed().joinToString("\n"), diffVal.lines().reversed().joinToString("\n") ).replace("\r", "") if (patchRev != patch) { reverseHrefLink = applydiffTask.complete(hrefLink("(Bottom to Top)", classname = "href-link cmd-button") { try { - val reversedCode = code.lines().reversed().joinToString("\n") + val reversedCode = code().lines().reversed().joinToString("\n") val reversedDiff = diffVal.lines().reversed().joinToString("\n") val newReversedCode = IterativePatchUtil.patch(reversedCode, reversedDiff).replace("\r", "") val newCode = newReversedCode.lines().reversed().joinToString("\n") @@ -61,7 +61,7 @@ fun SocketManagerBase.addApplyDiffLinks( } val test2 = DiffUtil.formatDiff( DiffUtil.generateDiff( - code.lines(), + code().lines(), patchRev.lines().reversed() ) ) diff --git a/webui/src/main/kotlin/com/github/simiacryptus/diff/AddApplyFileDiffLinks.kt b/webui/src/main/kotlin/com/github/simiacryptus/diff/AddApplyFileDiffLinks.kt index d4e5c8f2..a9a60d90 100644 --- a/webui/src/main/kotlin/com/github/simiacryptus/diff/AddApplyFileDiffLinks.kt +++ b/webui/src/main/kotlin/com/github/simiacryptus/diff/AddApplyFileDiffLinks.kt @@ -1,5 +1,6 @@ package com.github.simiacryptus.diff +import com.github.simiacryptus.diff.IterativePatchUtil.patch import com.simiacryptus.skyenet.AgentPatterns import com.simiacryptus.skyenet.set import com.simiacryptus.skyenet.webui.application.ApplicationInterface @@ -12,9 +13,9 @@ import kotlin.io.path.readText fun SocketManagerBase.addApplyFileDiffLinks( root: Path, - code: Map, + code: () -> Map, response: String, - handle: (Map) -> Unit, + handle: (Map) -> Unit, ui: ApplicationInterface, ): String { val headerPattern = """(?s)(? val header = headers.lastOrNull { it.first.endInclusive < diffBlock.first.start } - val filename = header?.second ?: "Unknown" + val filename = resolve(root, header?.second ?: "Unknown") val diffVal = diffBlock.second - val newValue = renderDiffBlock(root, filename, code, diffVal, handle, ui) + val newValue = renderDiffBlock(root, filename, code(), diffVal, handle, ui) markdown.replace("```diff\n$diffVal\n```", newValue) } val withSaveLinks = codeblocks.fold(withPatchLinks) { markdown, codeBlock -> val header = headers.lastOrNull { it.first.endInclusive < codeBlock.first.start } - val filename = header?.second ?: "Unknown" - val filepath = path(root, filename) - val prevCode = load(filepath, root, code) + val filename = resolve(root, header?.second ?: "Unknown") + val filepath: Path? = path(root, filename) + val prevCode = load(filepath) val codeLang = codeBlock.second.groupValues[1] val codeValue = codeBlock.second.groupValues[2] val commandTask = ui.newTask(false) lateinit var hrefLink: StringBuilder hrefLink = commandTask.complete(hrefLink("Save File", classname = "href-link cmd-button") { try { - handle( - mapOf( - filename to codeValue - ) - ) + save(filepath, codeValue) hrefLink.set("""
Saved ${filename}
""") commandTask.complete() + handle(mapOf(File(filename).toPath() to codeValue)) //task.complete("""
Saved ${filename}
""") } catch (e: Throwable) { commandTask.error(null, e) @@ -61,25 +59,23 @@ fun SocketManagerBase.addApplyFileDiffLinks( })!! val codeblockRaw = """ - ```${codeLang} - ${codeValue} - ``` - """.trimIndent() + |```${codeLang} + |${codeValue} + |``` + """.trimMargin() markdown.replace( codeblockRaw, AgentPatterns.displayMapInTabs( mapOf( "New" to MarkdownUtil.renderMarkdown(codeblockRaw, ui = ui), - "Old" to MarkdownUtil.renderMarkdown( - """ - |```${codeLang} - |${prevCode} - |``` - """.trimMargin(), ui = ui + "Old" to MarkdownUtil.renderMarkdown(""" + |```${codeLang} + |${prevCode} + |``` + """.trimMargin(), ui = ui ), - "Patch" to MarkdownUtil.renderMarkdown( - """ - |```diff - |${ + "Patch" to MarkdownUtil.renderMarkdown(""" + |```diff + |${ DiffUtil.formatDiff( DiffUtil.generateDiff( prevCode.lines(), @@ -87,8 +83,8 @@ fun SocketManagerBase.addApplyFileDiffLinks( ) ) } - |``` - """.trimMargin(), ui = ui + |``` + """.trimMargin(), ui = ui ), ) ) + "\n" + commandTask.placeholder @@ -97,18 +93,41 @@ fun SocketManagerBase.addApplyFileDiffLinks( return withSaveLinks } +fun resolve(root: Path, filename: String): String { + var filepath = path(root, filename) + if (filepath?.toFile()?.exists() == false) filepath = null + if (null != filepath) return filepath.toString() + val files = root.toFile().recurseFiles().filter { it.name == filename.split('/', '\\').last() } + if (files.size == 1) { + filepath = files.first().toPath() + } + return root.relativize(filepath).toString() +} + +fun File.recurseFiles(): List { + val files = mutableListOf() + if (isDirectory) { + listFiles()?.forEach { + files.addAll(it.recurseFiles()) + } + } else { + files.add(this) + } + return files +} + private fun SocketManagerBase.renderDiffBlock( root: Path, filename: String, - code: Map, + code: Map, diffVal: String, - handle: (Map) -> Unit, + handle: (Map) -> Unit, ui: ApplicationInterface ): String { val filepath = path(root, filename) - val prevCode = load(filepath, root, code) - val newCode = IterativePatchUtil.patch(prevCode, diffVal) + val prevCode = load(filepath) + val newCode = patch(prevCode, diffVal) val echoDiff = try { DiffUtil.formatDiff( DiffUtil.generateDiff( @@ -123,21 +142,16 @@ private fun SocketManagerBase.renderDiffBlock( val applydiffTask = ui.newTask(false) lateinit var hrefLink: StringBuilder lateinit var reverseHrefLink: StringBuilder + val relativize = try { + root.relativize(filepath) + } catch (e: Throwable) { + filepath + } hrefLink = applydiffTask.complete(hrefLink("Apply Diff", classname = "href-link cmd-button") { try { - val relativize = try { - root.relativize(filepath) - } catch (e: Throwable) { - filepath - } - handle( - mapOf( - relativize.toString() to IterativePatchUtil.patch( - prevCode, - diffVal - ) - ) - ) + val newCode = patch(prevCode, diffVal) + handle(mapOf(relativize!! to newCode)) + filepath?.toFile()?.writeText(newCode, Charsets.UTF_8) ?: log.warn("File not found: $filepath") reverseHrefLink.clear() hrefLink.set("""
Diff Applied
""") applydiffTask.complete() @@ -147,14 +161,12 @@ private fun SocketManagerBase.renderDiffBlock( })!! reverseHrefLink = applydiffTask.complete(hrefLink("(Bottom to Top)", classname = "href-link cmd-button") { try { - val reversedCodeMap = code.mapValues { (_, v) -> v.lines().reversed().joinToString("\n") } - val reversedDiff = diffVal.lines().reversed().joinToString("\n") - val newReversedCodeMap = reversedCodeMap.mapValues { (file, prevCode) -> - if (filename == file) { - IterativePatchUtil.patch(prevCode, reversedDiff).lines().reversed().joinToString("\n") - } else prevCode - } - handle(newReversedCodeMap) + val newCode = patch( + prevCode.lines().reversed().joinToString("\n"), + diffVal.lines().reversed().joinToString("\n") + ).lines().reversed().joinToString("\n") + handle(mapOf(relativize!! to newCode)) + filepath?.toFile()?.writeText(newCode, Charsets.UTF_8) ?: log.warn("File not found: $filepath") hrefLink.clear() reverseHrefLink.set("""
Diff Applied (Bottom to Top)
""") applydiffTask.complete() @@ -162,10 +174,10 @@ private fun SocketManagerBase.renderDiffBlock( applydiffTask.error(null, e) } })!! - val diffTask = ui?.newTask(root = false) - val prevCodeTask = ui?.newTask(root = false) - val newCodeTask = ui?.newTask(root = false) - val patchTask = ui?.newTask(root = false) + val diffTask = ui.newTask(root = false) + val prevCodeTask = ui.newTask(root = false) + val newCodeTask = ui.newTask(root = false) + val patchTask = ui.newTask(root = false) val inTabs = AgentPatterns.displayMapInTabs( mapOf( "Diff" to (diffTask?.placeholder ?: ""), @@ -196,19 +208,10 @@ private fun SocketManagerBase.renderDiffBlock( private fun load( - filepath: Path?, - root: Path, - code: Map + filepath: Path? ) = try { if (true != filepath?.toFile()?.exists()) { - log.warn( - """ - |File not found: $filepath - |Root: ${root.toAbsolutePath()} - |Files: - |${code.keys.joinToString("\n") { "* $it" }} - """.trimMargin() - ) + log.warn("""File not found: $filepath""".trimMargin()) "" } else { filepath.readText(Charsets.UTF_8) @@ -218,6 +221,19 @@ private fun load( "" } +private fun save( + filepath: Path?, + code: String +) { + try { + if (null != filepath) { + filepath.toFile().writeText(code, Charsets.UTF_8) + } + } catch (e: Throwable) { + log.error("Error writing file: $filepath", e) + } +} + private fun path(root: Path, filename: String): Path? { val filepath = try { findFile(root, filename) ?: root.resolve(filename) @@ -227,7 +243,7 @@ private fun path(root: Path, filename: String): Path? { root.resolve(filename) } catch (e: Throwable) { log.error("Error resolving file: $filename", e) - File(filename).toPath() + null } } return filepath diff --git a/webui/src/main/kotlin/com/github/simiacryptus/diff/AddSaveLinks.kt b/webui/src/main/kotlin/com/github/simiacryptus/diff/AddSaveLinks.kt index 15df0408..18b4cc2f 100644 --- a/webui/src/main/kotlin/com/github/simiacryptus/diff/AddSaveLinks.kt +++ b/webui/src/main/kotlin/com/github/simiacryptus/diff/AddSaveLinks.kt @@ -4,32 +4,36 @@ import com.simiacryptus.skyenet.set import com.simiacryptus.skyenet.webui.application.ApplicationInterface import com.simiacryptus.skyenet.webui.session.SessionTask import com.simiacryptus.skyenet.webui.session.SocketManagerBase +import java.io.File +import java.nio.file.Path fun SocketManagerBase.addSaveLinks( - response: String, - task: SessionTask, - ui: ApplicationInterface, - handle: (String, String) -> Unit, + response: String, + task: SessionTask, + ui: ApplicationInterface, + handle: (Path, String) -> Unit, ): String { - val diffPattern = - """(?s)(? - val filename = diffBlock.groupValues[1] - val codeValue = diffBlock.groupValues[2] - val commandTask = ui.newTask(false) - lateinit var hrefLink: StringBuilder - hrefLink = commandTask.complete(hrefLink("Save File", classname = "href-link cmd-button") { - try { - handle(filename, codeValue) - hrefLink.set("""
Saved ${filename}
""") - commandTask.complete() - //task.complete("""
Saved ${filename}
""") - } catch (e: Throwable) { - task.error(null, e) - } - })!! - markdown.replace(codeValue + "```", codeValue?.let { /*escapeHtml4*/(it)/*.indent(" ")*/ } + "```\n" + commandTask.placeholder) - } - return withLinks + val diffPattern = + """(?s)(? + val filename = diffBlock.groupValues[1] + val codeValue = diffBlock.groupValues[2] + val commandTask = ui.newTask(false) + lateinit var hrefLink: StringBuilder + hrefLink = commandTask.complete(hrefLink("Save File", classname = "href-link cmd-button") { + try { + handle(File(filename).toPath(), codeValue) + hrefLink.set("""
Saved ${filename}
""") + commandTask.complete() + //task.complete("""
Saved ${filename}
""") + } catch (e: Throwable) { + task.error(null, e) + } + })!! + markdown.replace( + codeValue + "```", + codeValue.let { /*escapeHtml4*/(it)/*.indent(" ")*/ } + "```\n" + commandTask.placeholder) + } + return withLinks } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/github/simiacryptus/diff/ApxPatchUtil.kt b/webui/src/main/kotlin/com/github/simiacryptus/diff/ApxPatchUtil.kt index 15601617..0db758f7 100644 --- a/webui/src/main/kotlin/com/github/simiacryptus/diff/ApxPatchUtil.kt +++ b/webui/src/main/kotlin/com/github/simiacryptus/diff/ApxPatchUtil.kt @@ -1,127 +1,124 @@ package com.github.simiacryptus.diff import org.apache.commons.text.similarity.LevenshteinDistance -import org.slf4j.LoggerFactory object ApxPatchUtil { - fun patch(source: String, patch: String): String { - val sourceLines = source.lines() - val patchLines = patch.lines() + fun patch(source: String, patch: String): String { + val sourceLines = source.lines() + val patchLines = patch.lines() - // This will hold the final result - val result = mutableListOf() + // This will hold the final result + val result = mutableListOf() - // This will keep track of the current line in the source file - var sourceIndex = 0 + // This will keep track of the current line in the source file + var sourceIndex = 0 - // Process each line in the patch - for (patchLine in patchLines.map { it.trim() }) { - when { - // If the line starts with "---" or "+++", it's a file indicator line, skip it - patchLine.startsWith("---") || patchLine.startsWith("+++") -> continue + // Process each line in the patch + for (patchLine in patchLines.map { it.trim() }) { + when { + // If the line starts with "---" or "+++", it's a file indicator line, skip it + patchLine.startsWith("---") || patchLine.startsWith("+++") -> continue - // If the line starts with "@@", it's a hunk header - patchLine.startsWith("@@") -> continue + // If the line starts with "@@", it's a hunk header + patchLine.startsWith("@@") -> continue - // If the line starts with "-", it's a deletion, skip the corresponding source line but otherwise treat it as a context line - patchLine.startsWith("-") -> { - sourceIndex = onDelete(patchLine, sourceIndex, sourceLines, result) - } + // If the line starts with "-", it's a deletion, skip the corresponding source line but otherwise treat it as a context line + patchLine.startsWith("-") -> { + sourceIndex = onDelete(patchLine, sourceIndex, sourceLines, result) + } - // If the line starts with "+", it's an addition, add it to the result - patchLine.startsWith("+") -> { - result.add(patchLine.substring(1)) - } + // If the line starts with "+", it's an addition, add it to the result + patchLine.startsWith("+") -> { + result.add(patchLine.substring(1)) + } + + // \d+\: ___ is a line number, strip it + patchLine.matches(Regex("\\d+:.*")) -> { + sourceIndex = onContextLine(patchLine.substringAfter(":"), sourceIndex, sourceLines, result) + } - // \d+\: ___ is a line number, strip it - patchLine.matches(Regex("\\d+:.*")) -> { - sourceIndex = onContextLine(patchLine.substringAfter(":"), sourceIndex, sourceLines, result) + // it's a context line, advance the source cursor + else -> { + sourceIndex = onContextLine(patchLine, sourceIndex, sourceLines, result) + } + } } - // it's a context line, advance the source cursor - else -> { - sourceIndex = onContextLine(patchLine, sourceIndex, sourceLines, result) + // Append any remaining lines from the source file + while (sourceIndex < sourceLines.size) { + result.add(sourceLines[sourceIndex]) + sourceIndex++ } - } - } - // Append any remaining lines from the source file - while (sourceIndex < sourceLines.size) { - result.add(sourceLines[sourceIndex]) - sourceIndex++ + return result.joinToString("\n") } - return result.joinToString("\n") - } - - private fun onDelete( - patchLine: String, - sourceIndex: Int, - sourceLines: List, - result: MutableList - ): Int { - var sourceIndex1 = sourceIndex - val delLine = patchLine.substring(1) - val sourceIndexSearch = lookAheadFor(sourceIndex1, sourceLines, delLine) - if (sourceIndexSearch > 0 && sourceIndexSearch + 1 < sourceLines.size) { - val contextChunk = sourceLines.subList(sourceIndex1, sourceIndexSearch) - result.addAll(contextChunk) - sourceIndex1 = sourceIndexSearch + 1 - } else { - println("Deletion line not found in source file: $delLine") - // Ignore + private fun onDelete( + patchLine: String, + sourceIndex: Int, + sourceLines: List, + result: MutableList + ): Int { + var sourceIndex1 = sourceIndex + val delLine = patchLine.substring(1) + val sourceIndexSearch = lookAheadFor(sourceIndex1, sourceLines, delLine) + if (sourceIndexSearch > 0 && sourceIndexSearch + 1 < sourceLines.size) { + val contextChunk = sourceLines.subList(sourceIndex1, sourceIndexSearch) + result.addAll(contextChunk) + sourceIndex1 = sourceIndexSearch + 1 + } else { + println("Deletion line not found in source file: $delLine") + // Ignore + } + return sourceIndex1 } - return sourceIndex1 - } - - private fun onContextLine( - patchLine: String, - sourceIndex: Int, - sourceLines: List, - result: MutableList - ): Int { - var sourceIndex1 = sourceIndex - val sourceIndexSearch = lookAheadFor(sourceIndex1, sourceLines, patchLine) - if (sourceIndexSearch > 0 && sourceIndexSearch + 1 < sourceLines.size) { - val contextChunk = sourceLines.subList(sourceIndex1, sourceIndexSearch + 1) - result.addAll(contextChunk) - sourceIndex1 = sourceIndexSearch + 1 - } else { - println("Context line not found in source file: $patchLine") - // Ignore + + private fun onContextLine( + patchLine: String, + sourceIndex: Int, + sourceLines: List, + result: MutableList + ): Int { + var sourceIndex1 = sourceIndex + val sourceIndexSearch = lookAheadFor(sourceIndex1, sourceLines, patchLine) + if (sourceIndexSearch > 0 && sourceIndexSearch + 1 < sourceLines.size) { + val contextChunk = sourceLines.subList(sourceIndex1, sourceIndexSearch + 1) + result.addAll(contextChunk) + sourceIndex1 = sourceIndexSearch + 1 + } else { + println("Context line not found in source file: $patchLine") + // Ignore + } + return sourceIndex1 } - return sourceIndex1 - } - - private fun lookAheadFor( - sourceIndex: Int, - sourceLines: List, - patchLine: String - ): Int { - var sourceIndexSearch = sourceIndex - while (sourceIndexSearch < sourceLines.size) { - if (lineMatches(patchLine, sourceLines[sourceIndexSearch++])) return sourceIndexSearch - 1 + + private fun lookAheadFor( + sourceIndex: Int, + sourceLines: List, + patchLine: String + ): Int { + var sourceIndexSearch = sourceIndex + while (sourceIndexSearch < sourceLines.size) { + if (lineMatches(patchLine, sourceLines[sourceIndexSearch++])) return sourceIndexSearch - 1 + } + return -1 } - return -1 - } - - private fun lineMatches( - a: String, - b: String, - factor: Double = 0.1, - ): Boolean { - val threshold = (Math.max(a.trim().length, b.trim().length) * factor).toInt() - val levenshteinDistance = LevenshteinDistance(threshold+1) - val dist = levenshteinDistance.apply(a.trim(), b.trim()) - return if (dist >= 0) { - dist <= threshold - } else { - false + + private fun lineMatches( + a: String, + b: String, + factor: Double = 0.1, + ): Boolean { + val threshold = (Math.max(a.trim().length, b.trim().length) * factor).toInt() + val levenshteinDistance = LevenshteinDistance(threshold + 1) + val dist = levenshteinDistance.apply(a.trim(), b.trim()) + return if (dist >= 0) { + dist <= threshold + } else { + false + } } - } } -private val log = LoggerFactory.getLogger(ApxPatchUtil::class.java) - diff --git a/webui/src/main/kotlin/com/github/simiacryptus/diff/DiffMatchPatch.kt b/webui/src/main/kotlin/com/github/simiacryptus/diff/DiffMatchPatch.kt index 1c38266a..638e4bd0 100644 --- a/webui/src/main/kotlin/com/github/simiacryptus/diff/DiffMatchPatch.kt +++ b/webui/src/main/kotlin/com/github/simiacryptus/diff/DiffMatchPatch.kt @@ -22,2400 +22,2400 @@ import kotlin.math.min * Also contains the behaviour settings. */ open class DiffMatchPatch { - // Defaults. - // Set these on your diff_match_patch instance to override the defaults. - /** - * Number of seconds to map a diff before giving up (0 for infinity). - */ - var Diff_Timeout: Float = 1.0f - - /** - * Cost of an empty edit operation in terms of edit characters. - */ - private var Diff_EditCost: Short = 4 - - /** - * At what point is no match declared (0.0 = perfection, 1.0 = very loose). - */ - private var Match_Threshold: Float = 0.5f - - /** - * How far to search for a match (0 = exact location, 1000+ = broad match). - * A match this many characters away from the expected location will add - * 1.0 to the score (0.0 is a perfect match). - */ - private var Match_Distance: Int = 1000 - - /** - * When deleting a large block of text (over ~64 characters), how close do - * the contents have to be to match the expected contents. (0.0 = perfection, - * 1.0 = very loose). Note that Match_Threshold controls how closely the - * end points of a delete need to match. - */ - private var Patch_DeleteThreshold: Float = 0.5f - - /** - * Chunk size for context length. - */ - private var Patch_Margin: Short = 4 - - /** - * The number of bits in an int. - */ - private val Match_MaxBits: Short = 32 - - /** - * Internal class for returning results from diff_linesToChars(). - * Other less paranoid languages just use a three-element array. - */ - protected class LinesToCharsResult( - var chars1: String, var chars2: String, - var lineArray: List - ) - - - // DIFF FUNCTIONS - /** - * The data structure representing a diff is a Linked list of Diff objects: - * {Diff(Operation.DELETE, "Hello"), Diff(Operation.INSERT, "Goodbye"), - * Diff(Operation.EQUAL, " world.")} - * which means: delete "Hello", add "Goodbye" and keep " world." - */ - enum class Operation { - DELETE, INSERT, EQUAL - } - - /** - * Find the differences between two texts. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param checklines Speedup flag. If false, then don't run a - * line-level diff first to identify the changed areas. - * If true, then run a faster slightly less optimal diff. - * @return Linked List of Diff objects. - */ - /** - * Find the differences between two texts. - * Run a faster, slightly less optimal diff. - * This method allows the 'checklines' of diff_main() to be optional. - * Most of the time checklines is wanted, so default to true. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @return Linked List of Diff objects. - */ - @JvmOverloads - fun diff_main(text1: String?, text2: String?, checklines: Boolean = true): LinkedList { - // Set a deadline by which time the diff must be complete. - val deadline: Long - if (Diff_Timeout <= 0) { - deadline = Long.MAX_VALUE - } else { - deadline = System.currentTimeMillis() + (Diff_Timeout * 1000).toLong() - } - return diff_main(text1, text2, checklines, deadline) - } - - /** - * Find the differences between two texts. Simplifies the problem by - * stripping any common prefix or suffix off the texts before diffing. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param checklines Speedup flag. If false, then don't run a - * line-level diff first to identify the changed areas. - * If true, then run a faster slightly less optimal diff. - * @param deadline Time when the diff should be complete by. Used - * internally for recursive calls. Users should set DiffTimeout instead. - * @return Linked List of Diff objects. - */ - fun diff_main(text1: String?, text2: String?, checklines: Boolean, deadline: Long): LinkedList { - // Check for null inputs. - var text1 = text1 - var text2 = text2 - if (text1 == null || text2 == null) { - throw IllegalArgumentException("Null inputs. (diff_main)") - } + // Defaults. + // Set these on your diff_match_patch instance to override the defaults. + /** + * Number of seconds to map a diff before giving up (0 for infinity). + */ + var Diff_Timeout: Float = 1.0f - // Check for equality (speedup). - val diffs: LinkedList - if (text1 == text2) { - diffs = LinkedList() - if (text1.length != 0) { - diffs.add(Diff(Operation.EQUAL, text1)) - } - return diffs - } + /** + * Cost of an empty edit operation in terms of edit characters. + */ + private var Diff_EditCost: Short = 4 - // Trim off common prefix (speedup). - var commonlength = diff_commonPrefix(text1, text2) - val commonprefix = text1.substring(0, commonlength) - text1 = text1.substring(commonlength) - text2 = text2.substring(commonlength) + /** + * At what point is no match declared (0.0 = perfection, 1.0 = very loose). + */ + private var Match_Threshold: Float = 0.5f - // Trim off common suffix (speedup). - commonlength = diff_commonSuffix(text1, text2) - val commonsuffix = text1.substring(text1.length - commonlength) - text1 = text1.substring(0, text1.length - commonlength) - text2 = text2.substring(0, text2.length - commonlength) + /** + * How far to search for a match (0 = exact location, 1000+ = broad match). + * A match this many characters away from the expected location will add + * 1.0 to the score (0.0 is a perfect match). + */ + private var Match_Distance: Int = 1000 - // Compute the diff on the middle block. - diffs = diff_compute(text1, text2, checklines, deadline) + /** + * When deleting a large block of text (over ~64 characters), how close do + * the contents have to be to match the expected contents. (0.0 = perfection, + * 1.0 = very loose). Note that Match_Threshold controls how closely the + * end points of a delete need to match. + */ + private var Patch_DeleteThreshold: Float = 0.5f - // Restore the prefix and suffix. - if (commonprefix.length != 0) { - diffs.addFirst(Diff(Operation.EQUAL, commonprefix)) - } - if (commonsuffix.length != 0) { - diffs.addLast(Diff(Operation.EQUAL, commonsuffix)) - } + /** + * Chunk size for context length. + */ + private var Patch_Margin: Short = 4 - diff_cleanupMerge(diffs) - return diffs - } - - /** - * Find the differences between two texts. Assumes that the texts do not - * have any common prefix or suffix. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param checklines Speedup flag. If false, then don't run a - * line-level diff first to identify the changed areas. - * If true, then run a faster slightly less optimal diff. - * @param deadline Time when the diff should be complete by. - * @return Linked List of Diff objects. - */ - private fun diff_compute(text1: String, text2: String, checklines: Boolean, deadline: Long): LinkedList { - var diffs = LinkedList() - - if (text1.length == 0) { - // Just add some text (speedup). - diffs.add(Diff(Operation.INSERT, text2)) - return diffs - } + /** + * The number of bits in an int. + */ + private val Match_MaxBits: Short = 32 - if (text2.length == 0) { - // Just delete some text (speedup). - diffs.add(Diff(Operation.DELETE, text1)) - return diffs - } + /** + * Internal class for returning results from diff_linesToChars(). + * Other less paranoid languages just use a three-element array. + */ + protected class LinesToCharsResult( + var chars1: String, var chars2: String, + var lineArray: List + ) - val longtext = if (text1.length > text2.length) text1 else text2 - val shorttext = if (text1.length > text2.length) text2 else text1 - val i = longtext.indexOf(shorttext) - if (i != -1) { - // Shorter text is inside the longer text (speedup). - val op = if ((text1.length > text2.length)) Operation.DELETE else Operation.INSERT - diffs.add(Diff(op, longtext.substring(0, i))) - diffs.add(Diff(Operation.EQUAL, shorttext)) - diffs.add(Diff(op, longtext.substring(i + shorttext.length))) - return diffs - } - if (shorttext.length == 1) { - // Single character string. - // After the previous speedup, the character can't be an equality. - diffs.add(Diff(Operation.DELETE, text1)) - diffs.add(Diff(Operation.INSERT, text2)) - return diffs + // DIFF FUNCTIONS + /** + * The data structure representing a diff is a Linked list of Diff objects: + * {Diff(Operation.DELETE, "Hello"), Diff(Operation.INSERT, "Goodbye"), + * Diff(Operation.EQUAL, " world.")} + * which means: delete "Hello", add "Goodbye" and keep " world." + */ + enum class Operation { + DELETE, INSERT, EQUAL } - // Check to see if the problem can be split in two. - val hm = diff_halfMatch(text1, text2) - if (hm != null) { - // A half-match was found, sort out the return data. - val text1_a = hm[0] - val text1_b = hm[1] - val text2_a = hm[2] - val text2_b = hm[3] - val mid_common = hm[4] - // Send both pairs off for separate processing. - val diffs_a = diff_main( - text1_a, text2_a, - checklines, deadline - ) - val diffs_b = diff_main( - text1_b, text2_b, - checklines, deadline - ) - // Merge the results. - diffs = diffs_a - diffs.add(Diff(Operation.EQUAL, mid_common)) - diffs.addAll(diffs_b) - return diffs + /** + * Find the differences between two texts. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param checklines Speedup flag. If false, then don't run a + * line-level diff first to identify the changed areas. + * If true, then run a faster slightly less optimal diff. + * @return Linked List of Diff objects. + */ + /** + * Find the differences between two texts. + * Run a faster, slightly less optimal diff. + * This method allows the 'checklines' of diff_main() to be optional. + * Most of the time checklines is wanted, so default to true. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @return Linked List of Diff objects. + */ + @JvmOverloads + fun diff_main(text1: String?, text2: String?, checklines: Boolean = true): LinkedList { + // Set a deadline by which time the diff must be complete. + val deadline: Long + if (Diff_Timeout <= 0) { + deadline = Long.MAX_VALUE + } else { + deadline = System.currentTimeMillis() + (Diff_Timeout * 1000).toLong() + } + return diff_main(text1, text2, checklines, deadline) } - if ((checklines && text1.length > 100) && text2.length > 100) { - return diff_lineMode(text1, text2, deadline) - } + /** + * Find the differences between two texts. Simplifies the problem by + * stripping any common prefix or suffix off the texts before diffing. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param checklines Speedup flag. If false, then don't run a + * line-level diff first to identify the changed areas. + * If true, then run a faster slightly less optimal diff. + * @param deadline Time when the diff should be complete by. Used + * internally for recursive calls. Users should set DiffTimeout instead. + * @return Linked List of Diff objects. + */ + fun diff_main(text1: String?, text2: String?, checklines: Boolean, deadline: Long): LinkedList { + // Check for null inputs. + var text1 = text1 + var text2 = text2 + if (text1 == null || text2 == null) { + throw IllegalArgumentException("Null inputs. (diff_main)") + } - return diff_bisect(text1, text2, deadline) - } - - /** - * Do a quick line-level diff on both strings, then rediff the parts for - * greater accuracy. - * This speedup can produce non-minimal diffs. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param deadline Time when the diff should be complete by. - * @return Linked List of Diff objects. - */ - private fun diff_lineMode( - text1: String, text2: String, - deadline: Long - ): LinkedList { - // Scan the text on a line-by-line basis first. - var text1 = text1 - var text2 = text2 - val a = diff_linesToChars(text1, text2) - text1 = a.chars1 - text2 = a.chars2 - val linearray = a.lineArray - - val diffs = diff_main(text1, text2, false, deadline) - - // Convert the diff back to original text. - diff_charsToLines(diffs, linearray) - // Eliminate freak matches (e.g. blank lines) - diff_cleanupSemantic(diffs) - - // Rediff any replacement blocks, this time character-by-character. - // Add a dummy entry at the end. - diffs.add(Diff(Operation.EQUAL, "")) - var count_delete = 0 - var count_insert = 0 - var text_delete: String = "" - var text_insert: String = "" - val pointer = diffs.listIterator() - var thisDiff: Diff? = pointer.next() - while (thisDiff != null) { - when (thisDiff.operation) { - Operation.INSERT -> { - count_insert++ - text_insert += thisDiff.text - } - - Operation.DELETE -> { - count_delete++ - text_delete += thisDiff.text - } - - Operation.EQUAL -> { - // Upon reaching an equality, check for prior redundancies. - if (count_delete >= 1 && count_insert >= 1) { - // Delete the offending records and add the merged ones. - pointer.previous() - var j = 0 - while (j < count_delete + count_insert) { - pointer.previous() - pointer.remove() - j++ - } - for (subDiff: Diff in diff_main( - text_delete, text_insert, false, - deadline - )) { - pointer.add(subDiff) + // Check for equality (speedup). + val diffs: LinkedList + if (text1 == text2) { + diffs = LinkedList() + if (text1.length != 0) { + diffs.add(Diff(Operation.EQUAL, text1)) } - } - count_insert = 0 - count_delete = 0 - text_delete = "" - text_insert = "" + return diffs } - null -> TODO() - } - thisDiff = if (pointer.hasNext()) pointer.next() else null + // Trim off common prefix (speedup). + var commonlength = diff_commonPrefix(text1, text2) + val commonprefix = text1.substring(0, commonlength) + text1 = text1.substring(commonlength) + text2 = text2.substring(commonlength) + + // Trim off common suffix (speedup). + commonlength = diff_commonSuffix(text1, text2) + val commonsuffix = text1.substring(text1.length - commonlength) + text1 = text1.substring(0, text1.length - commonlength) + text2 = text2.substring(0, text2.length - commonlength) + + // Compute the diff on the middle block. + diffs = diff_compute(text1, text2, checklines, deadline) + + // Restore the prefix and suffix. + if (commonprefix.length != 0) { + diffs.addFirst(Diff(Operation.EQUAL, commonprefix)) + } + if (commonsuffix.length != 0) { + diffs.addLast(Diff(Operation.EQUAL, commonsuffix)) + } + + diff_cleanupMerge(diffs) + return diffs } - diffs.removeLast() // Remove the dummy entry at the end. - - return diffs - } - - /** - * Find the 'middle snake' of a diff, split the problem in two - * and return the recursively constructed diff. - * See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param deadline Time at which to bail if not yet complete. - * @return LinkedList of Diff objects. - */ - private fun diff_bisect( - text1: String, text2: String, - deadline: Long - ): LinkedList { - // Cache the text lengths to prevent multiple calls. - val text1_length = text1.length - val text2_length = text2.length - val max_d = (text1_length + text2_length + 1) / 2 - val v_offset = max_d - val v_length = 2 * max_d - val v1 = IntArray(v_length) - val v2 = IntArray(v_length) - for (x in 0 until v_length) { - v1[x] = -1 - v2[x] = -1 + + /** + * Find the differences between two texts. Assumes that the texts do not + * have any common prefix or suffix. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param checklines Speedup flag. If false, then don't run a + * line-level diff first to identify the changed areas. + * If true, then run a faster slightly less optimal diff. + * @param deadline Time when the diff should be complete by. + * @return Linked List of Diff objects. + */ + private fun diff_compute(text1: String, text2: String, checklines: Boolean, deadline: Long): LinkedList { + var diffs = LinkedList() + + if (text1.length == 0) { + // Just add some text (speedup). + diffs.add(Diff(Operation.INSERT, text2)) + return diffs + } + + if (text2.length == 0) { + // Just delete some text (speedup). + diffs.add(Diff(Operation.DELETE, text1)) + return diffs + } + + val longtext = if (text1.length > text2.length) text1 else text2 + val shorttext = if (text1.length > text2.length) text2 else text1 + val i = longtext.indexOf(shorttext) + if (i != -1) { + // Shorter text is inside the longer text (speedup). + val op = if ((text1.length > text2.length)) Operation.DELETE else Operation.INSERT + diffs.add(Diff(op, longtext.substring(0, i))) + diffs.add(Diff(Operation.EQUAL, shorttext)) + diffs.add(Diff(op, longtext.substring(i + shorttext.length))) + return diffs + } + + if (shorttext.length == 1) { + // Single character string. + // After the previous speedup, the character can't be an equality. + diffs.add(Diff(Operation.DELETE, text1)) + diffs.add(Diff(Operation.INSERT, text2)) + return diffs + } + + // Check to see if the problem can be split in two. + val hm = diff_halfMatch(text1, text2) + if (hm != null) { + // A half-match was found, sort out the return data. + val text1_a = hm[0] + val text1_b = hm[1] + val text2_a = hm[2] + val text2_b = hm[3] + val mid_common = hm[4] + // Send both pairs off for separate processing. + val diffs_a = diff_main( + text1_a, text2_a, + checklines, deadline + ) + val diffs_b = diff_main( + text1_b, text2_b, + checklines, deadline + ) + // Merge the results. + diffs = diffs_a + diffs.add(Diff(Operation.EQUAL, mid_common)) + diffs.addAll(diffs_b) + return diffs + } + + if ((checklines && text1.length > 100) && text2.length > 100) { + return diff_lineMode(text1, text2, deadline) + } + + return diff_bisect(text1, text2, deadline) } - v1[v_offset + 1] = 0 - v2[v_offset + 1] = 0 - val delta = text1_length - text2_length - // If the total number of characters is odd, then the front path will - // collide with the reverse path. - val front = (delta % 2 != 0) - // Offsets for start and end of k loop. - // Prevents mapping of space beyond the grid. - var k1start = 0 - var k1end = 0 - var k2start = 0 - var k2end = 0 - for (d in 0 until max_d) { - // Bail out if deadline is reached. - if (System.currentTimeMillis() > deadline) { - break - } - - // Walk the front path one step. - var k1 = -d + k1start - while (k1 <= d - k1end) { - val k1_offset = v_offset + k1 - var x1: Int - if (k1 == -d || (k1 != d && v1[k1_offset - 1] < v1[k1_offset + 1])) { - x1 = v1[k1_offset + 1] - } else { - x1 = v1[k1_offset - 1] + 1 - } - var y1 = x1 - k1 - while ((x1 < text1_length) && y1 < text2_length && text1[x1] == text2[y1]) { - x1++ - y1++ - } - v1[k1_offset] = x1 - if (x1 > text1_length) { - // Ran off the right of the graph. - k1end += 2 - } else if (y1 > text2_length) { - // Ran off the bottom of the graph. - k1start += 2 - } else if (front) { - val k2_offset = v_offset + delta - k1 - if ((k2_offset >= 0 && k2_offset < v_length) && v2[k2_offset] != -1) { - // Mirror x2 onto top-left coordinate system. - val x2 = text1_length - v2[k2_offset] - if (x1 >= x2) { - // Overlap detected. - return diff_bisectSplit(text1, text2, x1, y1, deadline) + + /** + * Do a quick line-level diff on both strings, then rediff the parts for + * greater accuracy. + * This speedup can produce non-minimal diffs. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param deadline Time when the diff should be complete by. + * @return Linked List of Diff objects. + */ + private fun diff_lineMode( + text1: String, text2: String, + deadline: Long + ): LinkedList { + // Scan the text on a line-by-line basis first. + var text1 = text1 + var text2 = text2 + val a = diff_linesToChars(text1, text2) + text1 = a.chars1 + text2 = a.chars2 + val linearray = a.lineArray + + val diffs = diff_main(text1, text2, false, deadline) + + // Convert the diff back to original text. + diff_charsToLines(diffs, linearray) + // Eliminate freak matches (e.g. blank lines) + diff_cleanupSemantic(diffs) + + // Rediff any replacement blocks, this time character-by-character. + // Add a dummy entry at the end. + diffs.add(Diff(Operation.EQUAL, "")) + var count_delete = 0 + var count_insert = 0 + var text_delete: String = "" + var text_insert: String = "" + val pointer = diffs.listIterator() + var thisDiff: Diff? = pointer.next() + while (thisDiff != null) { + when (thisDiff.operation) { + Operation.INSERT -> { + count_insert++ + text_insert += thisDiff.text + } + + Operation.DELETE -> { + count_delete++ + text_delete += thisDiff.text + } + + Operation.EQUAL -> { + // Upon reaching an equality, check for prior redundancies. + if (count_delete >= 1 && count_insert >= 1) { + // Delete the offending records and add the merged ones. + pointer.previous() + var j = 0 + while (j < count_delete + count_insert) { + pointer.previous() + pointer.remove() + j++ + } + for (subDiff: Diff in diff_main( + text_delete, text_insert, false, + deadline + )) { + pointer.add(subDiff) + } + } + count_insert = 0 + count_delete = 0 + text_delete = "" + text_insert = "" + } + + null -> TODO() } - } - } - k1 += 2 - } - - // Walk the reverse path one step. - var k2 = -d + k2start - while (k2 <= d - k2end) { - val k2_offset = v_offset + k2 - var x2: Int - if (k2 == -d || (k2 != d && v2[k2_offset - 1] < v2[k2_offset + 1])) { - x2 = v2[k2_offset + 1] - } else { - x2 = v2[k2_offset - 1] + 1 + thisDiff = if (pointer.hasNext()) pointer.next() else null } - var y2 = x2 - k2 - while ((x2 < text1_length) && y2 < text2_length && (text1[text1_length - x2 - 1] - == text2[text2_length - y2 - 1]) - ) { - x2++ - y2++ - } - v2[k2_offset] = x2 - if (x2 > text1_length) { - // Ran off the left of the graph. - k2end += 2 - } else if (y2 > text2_length) { - // Ran off the top of the graph. - k2start += 2 - } else if (!front) { - val k1_offset = v_offset + delta - k2 - if (((k1_offset >= 0) && k1_offset < v_length) && v1[k1_offset] != -1) { - val x1 = v1[k1_offset] - val y1 = v_offset + x1 - k1_offset - // Mirror x2 onto top-left coordinate system. - x2 = text1_length - x2 - if (x1 >= x2) { - // Overlap detected. - return diff_bisectSplit(text1, text2, x1, y1, deadline) + diffs.removeLast() // Remove the dummy entry at the end. + + return diffs + } + + /** + * Find the 'middle snake' of a diff, split the problem in two + * and return the recursively constructed diff. + * See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param deadline Time at which to bail if not yet complete. + * @return LinkedList of Diff objects. + */ + private fun diff_bisect( + text1: String, text2: String, + deadline: Long + ): LinkedList { + // Cache the text lengths to prevent multiple calls. + val text1_length = text1.length + val text2_length = text2.length + val max_d = (text1_length + text2_length + 1) / 2 + val v_offset = max_d + val v_length = 2 * max_d + val v1 = IntArray(v_length) + val v2 = IntArray(v_length) + for (x in 0 until v_length) { + v1[x] = -1 + v2[x] = -1 + } + v1[v_offset + 1] = 0 + v2[v_offset + 1] = 0 + val delta = text1_length - text2_length + // If the total number of characters is odd, then the front path will + // collide with the reverse path. + val front = (delta % 2 != 0) + // Offsets for start and end of k loop. + // Prevents mapping of space beyond the grid. + var k1start = 0 + var k1end = 0 + var k2start = 0 + var k2end = 0 + for (d in 0 until max_d) { + // Bail out if deadline is reached. + if (System.currentTimeMillis() > deadline) { + break + } + + // Walk the front path one step. + var k1 = -d + k1start + while (k1 <= d - k1end) { + val k1_offset = v_offset + k1 + var x1: Int + if (k1 == -d || (k1 != d && v1[k1_offset - 1] < v1[k1_offset + 1])) { + x1 = v1[k1_offset + 1] + } else { + x1 = v1[k1_offset - 1] + 1 + } + var y1 = x1 - k1 + while ((x1 < text1_length) && y1 < text2_length && text1[x1] == text2[y1]) { + x1++ + y1++ + } + v1[k1_offset] = x1 + if (x1 > text1_length) { + // Ran off the right of the graph. + k1end += 2 + } else if (y1 > text2_length) { + // Ran off the bottom of the graph. + k1start += 2 + } else if (front) { + val k2_offset = v_offset + delta - k1 + if ((k2_offset >= 0 && k2_offset < v_length) && v2[k2_offset] != -1) { + // Mirror x2 onto top-left coordinate system. + val x2 = text1_length - v2[k2_offset] + if (x1 >= x2) { + // Overlap detected. + return diff_bisectSplit(text1, text2, x1, y1, deadline) + } + } + } + k1 += 2 + } + + // Walk the reverse path one step. + var k2 = -d + k2start + while (k2 <= d - k2end) { + val k2_offset = v_offset + k2 + var x2: Int + if (k2 == -d || (k2 != d && v2[k2_offset - 1] < v2[k2_offset + 1])) { + x2 = v2[k2_offset + 1] + } else { + x2 = v2[k2_offset - 1] + 1 + } + var y2 = x2 - k2 + while ((x2 < text1_length) && y2 < text2_length && (text1[text1_length - x2 - 1] + == text2[text2_length - y2 - 1]) + ) { + x2++ + y2++ + } + v2[k2_offset] = x2 + if (x2 > text1_length) { + // Ran off the left of the graph. + k2end += 2 + } else if (y2 > text2_length) { + // Ran off the top of the graph. + k2start += 2 + } else if (!front) { + val k1_offset = v_offset + delta - k2 + if (((k1_offset >= 0) && k1_offset < v_length) && v1[k1_offset] != -1) { + val x1 = v1[k1_offset] + val y1 = v_offset + x1 - k1_offset + // Mirror x2 onto top-left coordinate system. + x2 = text1_length - x2 + if (x1 >= x2) { + // Overlap detected. + return diff_bisectSplit(text1, text2, x1, y1, deadline) + } + } + } + k2 += 2 } - } } - k2 += 2 - } - } - // Diff took too long and hit the deadline or - // number of diffs equals number of characters, no commonality at all. - val diffs = LinkedList() - diffs.add(Diff(Operation.DELETE, text1)) - diffs.add(Diff(Operation.INSERT, text2)) - return diffs - } - - /** - * Given the location of the 'middle snake', split the diff in two parts - * and recurse. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param x Index of split point in text1. - * @param y Index of split point in text2. - * @param deadline Time at which to bail if not yet complete. - * @return LinkedList of Diff objects. - */ - private fun diff_bisectSplit( - text1: String, text2: String, - x: Int, y: Int, deadline: Long - ): LinkedList { - val text1a = text1.substring(0, x) - val text2a = text2.substring(0, y) - val text1b = text1.substring(x) - val text2b = text2.substring(y) - - // Compute both diffs serially. - val diffs = diff_main(text1a, text2a, false, deadline) - val diffsb = diff_main(text1b, text2b, false, deadline) - - diffs.addAll(diffsb) - return diffs - } - - /** - * Split two texts into a list of strings. Reduce the texts to a string of - * hashes where each Unicode character represents one line. - * @param text1 First string. - * @param text2 Second string. - * @return An object containing the encoded text1, the encoded text2 and - * the List of unique strings. The zeroth element of the List of - * unique strings is intentionally blank. - */ - private fun diff_linesToChars(text1: String, text2: String): LinesToCharsResult { - val lineArray: MutableList = ArrayList() - val lineHash: MutableMap = HashMap() - - // e.g. linearray[4] == "Hello\n" - // e.g. linehash.get("Hello\n") == 4 - - // "\x00" is a valid character, but various debuggers don't like it. - // So we'll insert a junk entry to avoid generating a null character. - lineArray.add("") - - // Allocate 2/3rds of the space for text1, the rest for text2. - val chars1 = diff_linesToCharsMunge(text1, lineArray, lineHash, 40000) - val chars2 = diff_linesToCharsMunge(text2, lineArray, lineHash, 65535) - return LinesToCharsResult(chars1, chars2, lineArray) - } - - /** - * Split a text into a list of strings. Reduce the texts to a string of - * hashes where each Unicode character represents one line. - * @param text String to encode. - * @param lineArray List of unique strings. - * @param lineHash Map of strings to indices. - * @param maxLines Maximum length of lineArray. - * @return Encoded string. - */ - private fun diff_linesToCharsMunge( - text: String, lineArray: MutableList, - lineHash: MutableMap, maxLines: Int - ): String { - var lineStart = 0 - var lineEnd = -1 - var line: String - val chars = StringBuilder() - // Walk the text, pulling out a substring for each line. - // text.split('\n') would would temporarily double our memory footprint. - // Modifying text would create many large strings to garbage collect. - while (lineEnd < text.length - 1) { - lineEnd = text.indexOf('\n', lineStart) - if (lineEnd == -1) { - lineEnd = text.length - 1 - } - line = text.substring(lineStart, lineEnd + 1) - - if (lineHash.containsKey(line)) { - chars.append((lineHash[line] as Int).toChar().toString()) - } else { - if (lineArray.size == maxLines) { - // Bail out at 65535 because - // String.valueOf((char) 65536).equals(String.valueOf(((char) 0))) - line = text.substring(lineStart) - lineEnd = text.length - } - lineArray.add(line) - lineHash[line] = lineArray.size - 1 - chars.append((lineArray.size - 1).toChar().toString()) - } - lineStart = lineEnd + 1 - } - return chars.toString() - } - - /** - * Rehydrate the text in a diff from a string of line hashes to real lines of - * text. - * @param diffs List of Diff objects. - * @param lineArray List of unique strings. - */ - private fun diff_charsToLines( - diffs: List, - lineArray: List - ) { - var text: StringBuilder - for (diff: Diff in diffs) { - text = StringBuilder() - for (j in 0 until diff.text!!.length) { - text.append(lineArray[diff.text!![j].code]) - } - diff.text = text.toString() - } - } - - /** - * Determine the common prefix of two strings - * @param text1 First string. - * @param text2 Second string. - * @return The number of characters common to the start of each string. - */ - fun diff_commonPrefix(text1: String?, text2: String?): Int { - // Performance analysis: https://neil.fraser.name/news/2007/10/09/ - val n = min(text1!!.length.toDouble(), text2!!.length.toDouble()).toInt() - for (i in 0 until n) { - if (text1[i] != text2[i]) { - return i - } - } - return n - } - - /** - * Determine the common suffix of two strings - * @param text1 First string. - * @param text2 Second string. - * @return The number of characters common to the end of each string. - */ - fun diff_commonSuffix(text1: String?, text2: String?): Int { - // Performance analysis: https://neil.fraser.name/news/2007/10/09/ - val text1_length = text1!!.length - val text2_length = text2!!.length - val n = min(text1_length.toDouble(), text2_length.toDouble()).toInt() - for (i in 1..n) { - if (text1[text1_length - i] != text2[text2_length - i]) { - return i - 1 - } - } - return n - } - - /** - * Determine if the suffix of one string is the prefix of another. - * @param text1 First string. - * @param text2 Second string. - * @return The number of characters common to the end of the first - * string and the start of the second string. - */ - private fun diff_commonOverlap(text1: String?, text2: String?): Int { - // Cache the text lengths to prevent multiple calls. - var text1 = text1 - var text2 = text2 - val text1_length = text1!!.length - val text2_length = text2!!.length - // Eliminate the null case. - if (text1_length == 0 || text2_length == 0) { - return 0 - } - // Truncate the longer string. - if (text1_length > text2_length) { - text1 = text1.substring(text1_length - text2_length) - } else if (text1_length < text2_length) { - text2 = text2.substring(0, text1_length) - } - val text_length = min(text1_length.toDouble(), text2_length.toDouble()).toInt() - // Quick check for the worst case. - if (text1 == text2) { - return text_length + // Diff took too long and hit the deadline or + // number of diffs equals number of characters, no commonality at all. + val diffs = LinkedList() + diffs.add(Diff(Operation.DELETE, text1)) + diffs.add(Diff(Operation.INSERT, text2)) + return diffs } - // Start by looking for a single character match - // and increase length until no match is found. - // Performance analysis: https://neil.fraser.name/news/2010/11/04/ - var best = 0 - var length = 1 - while (true) { - val pattern = text1.substring(text_length - length) - val found = text2.indexOf(pattern) - if (found == -1) { - return best - } - length += found - if (found == 0 || text1.substring(text_length - length) == text2.substring(0, length)) { - best = length - length++ - } - } - } - - /** - * Do the two texts share a substring which is at least half the length of - * the longer text? - * This speedup can produce non-minimal diffs. - * @param text1 First string. - * @param text2 Second string. - * @return Five element String array, containing the prefix of text1, the - * suffix of text1, the prefix of text2, the suffix of text2 and the - * common middle. Or null if there was no match. - */ - private fun diff_halfMatch(text1: String, text2: String): Array? { - if (Diff_Timeout <= 0) { - // Don't risk returning a non-optimal diff if we have unlimited time. - return null - } - val longtext = if (text1.length > text2.length) text1 else text2 - val shorttext = if (text1.length > text2.length) text2 else text1 - if (longtext.length < 4 || shorttext.length * 2 < longtext.length) { - return null // Pointless. + /** + * Given the location of the 'middle snake', split the diff in two parts + * and recurse. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param x Index of split point in text1. + * @param y Index of split point in text2. + * @param deadline Time at which to bail if not yet complete. + * @return LinkedList of Diff objects. + */ + private fun diff_bisectSplit( + text1: String, text2: String, + x: Int, y: Int, deadline: Long + ): LinkedList { + val text1a = text1.substring(0, x) + val text2a = text2.substring(0, y) + val text1b = text1.substring(x) + val text2b = text2.substring(y) + + // Compute both diffs serially. + val diffs = diff_main(text1a, text2a, false, deadline) + val diffsb = diff_main(text1b, text2b, false, deadline) + + diffs.addAll(diffsb) + return diffs } - // First check if the second quarter is the seed for a half-match. - val hm1 = diff_halfMatchI( - longtext, shorttext, - (longtext.length + 3) / 4 - ) - // Check again based on the third quarter. - val hm2 = diff_halfMatchI( - longtext, shorttext, - (longtext.length + 1) / 2 - ) - val hm: Array? - if (hm1 == null && hm2 == null) { - return null - } else if (hm2 == null) { - hm = hm1 - } else if (hm1 == null) { - hm = hm2 - } else { - // Both matched. Select the longest. - hm = if (hm1[4].length > hm2[4].length) hm1 else hm2 + /** + * Split two texts into a list of strings. Reduce the texts to a string of + * hashes where each Unicode character represents one line. + * @param text1 First string. + * @param text2 Second string. + * @return An object containing the encoded text1, the encoded text2 and + * the List of unique strings. The zeroth element of the List of + * unique strings is intentionally blank. + */ + private fun diff_linesToChars(text1: String, text2: String): LinesToCharsResult { + val lineArray: MutableList = ArrayList() + val lineHash: MutableMap = HashMap() + + // e.g. linearray[4] == "Hello\n" + // e.g. linehash.get("Hello\n") == 4 + + // "\x00" is a valid character, but various debuggers don't like it. + // So we'll insert a junk entry to avoid generating a null character. + lineArray.add("") + + // Allocate 2/3rds of the space for text1, the rest for text2. + val chars1 = diff_linesToCharsMunge(text1, lineArray, lineHash, 40000) + val chars2 = diff_linesToCharsMunge(text2, lineArray, lineHash, 65535) + return LinesToCharsResult(chars1, chars2, lineArray) } - // A half-match was found, sort out the return data. - if (text1.length > text2.length) { - return hm - //return new String[]{hm[0], hm[1], hm[2], hm[3], hm[4]}; - } else { - return arrayOf(hm!![2], hm[3], hm[0], hm[1], hm[4]) + /** + * Split a text into a list of strings. Reduce the texts to a string of + * hashes where each Unicode character represents one line. + * @param text String to encode. + * @param lineArray List of unique strings. + * @param lineHash Map of strings to indices. + * @param maxLines Maximum length of lineArray. + * @return Encoded string. + */ + private fun diff_linesToCharsMunge( + text: String, lineArray: MutableList, + lineHash: MutableMap, maxLines: Int + ): String { + var lineStart = 0 + var lineEnd = -1 + var line: String + val chars = StringBuilder() + // Walk the text, pulling out a substring for each line. + // text.split('\n') would would temporarily double our memory footprint. + // Modifying text would create many large strings to garbage collect. + while (lineEnd < text.length - 1) { + lineEnd = text.indexOf('\n', lineStart) + if (lineEnd == -1) { + lineEnd = text.length - 1 + } + line = text.substring(lineStart, lineEnd + 1) + + if (lineHash.containsKey(line)) { + chars.append((lineHash[line] as Int).toChar().toString()) + } else { + if (lineArray.size == maxLines) { + // Bail out at 65535 because + // String.valueOf((char) 65536).equals(String.valueOf(((char) 0))) + line = text.substring(lineStart) + lineEnd = text.length + } + lineArray.add(line) + lineHash[line] = lineArray.size - 1 + chars.append((lineArray.size - 1).toChar().toString()) + } + lineStart = lineEnd + 1 + } + return chars.toString() } - } - - /** - * Does a substring of shorttext exist within longtext such that the - * substring is at least half the length of longtext? - * @param longtext Longer string. - * @param shorttext Shorter string. - * @param i Start index of quarter length substring within longtext. - * @return Five element String array, containing the prefix of longtext, the - * suffix of longtext, the prefix of shorttext, the suffix of shorttext - * and the common middle. Or null if there was no match. - */ - private fun diff_halfMatchI(longtext: String, shorttext: String, i: Int): Array? { - // Start with a 1/4 length substring at position i as a seed. - val seed = longtext.substring(i, i + longtext.length / 4) - var j = -1 - var best_common = "" - var best_longtext_a = "" - var best_longtext_b = "" - var best_shorttext_a = "" - var best_shorttext_b = "" - while ((shorttext.indexOf(seed, j + 1).also { j = it }) != -1) { - val prefixLength = diff_commonPrefix( - longtext.substring(i), - shorttext.substring(j) - ) - val suffixLength = diff_commonSuffix( - longtext.substring(0, i), - shorttext.substring(0, j) - ) - if (best_common.length < suffixLength + prefixLength) { - best_common = (shorttext.substring(j - suffixLength, j) - + shorttext.substring(j, j + prefixLength)) - best_longtext_a = longtext.substring(0, i - suffixLength) - best_longtext_b = longtext.substring(i + prefixLength) - best_shorttext_a = shorttext.substring(0, j - suffixLength) - best_shorttext_b = shorttext.substring(j + prefixLength) - } + + /** + * Rehydrate the text in a diff from a string of line hashes to real lines of + * text. + * @param diffs List of Diff objects. + * @param lineArray List of unique strings. + */ + private fun diff_charsToLines( + diffs: List, + lineArray: List + ) { + var text: StringBuilder + for (diff: Diff in diffs) { + text = StringBuilder() + for (j in 0 until diff.text!!.length) { + text.append(lineArray[diff.text!![j].code]) + } + diff.text = text.toString() + } } - if (best_common.length * 2 >= longtext.length) { - return arrayOf( - best_longtext_a, best_longtext_b, - best_shorttext_a, best_shorttext_b, best_common - ) - } else { - return null + + /** + * Determine the common prefix of two strings + * @param text1 First string. + * @param text2 Second string. + * @return The number of characters common to the start of each string. + */ + fun diff_commonPrefix(text1: String?, text2: String?): Int { + // Performance analysis: https://neil.fraser.name/news/2007/10/09/ + val n = min(text1!!.length.toDouble(), text2!!.length.toDouble()).toInt() + for (i in 0 until n) { + if (text1[i] != text2[i]) { + return i + } + } + return n } - } - - /** - * Reduce the number of edits by eliminating semantically trivial equalities. - * @param diffs LinkedList of Diff objects. - */ - fun diff_cleanupSemantic(diffs: LinkedList) { - if (diffs.isEmpty()) { - return + + /** + * Determine the common suffix of two strings + * @param text1 First string. + * @param text2 Second string. + * @return The number of characters common to the end of each string. + */ + fun diff_commonSuffix(text1: String?, text2: String?): Int { + // Performance analysis: https://neil.fraser.name/news/2007/10/09/ + val text1_length = text1!!.length + val text2_length = text2!!.length + val n = min(text1_length.toDouble(), text2_length.toDouble()).toInt() + for (i in 1..n) { + if (text1[text1_length - i] != text2[text2_length - i]) { + return i - 1 + } + } + return n } - var changes = false - val equalities = ArrayDeque() // Double-ended queue of qualities. - var lastEquality: String? = null // Always equal to equalities.peek().text - var pointer = diffs.listIterator() - // Number of characters that changed prior to the equality. - var length_insertions1 = 0 - var length_deletions1 = 0 - // Number of characters that changed after the equality. - var length_insertions2 = 0 - var length_deletions2 = 0 - var thisDiff: Diff? = pointer.next() - while (thisDiff != null) { - if (thisDiff.operation == Operation.EQUAL) { - // Equality found. - equalities.add(thisDiff) - length_insertions1 = length_insertions2 - length_deletions1 = length_deletions2 - length_insertions2 = 0 - length_deletions2 = 0 - lastEquality = thisDiff.text - } else { - // An insertion or deletion. - if (thisDiff.operation == Operation.INSERT) { - length_insertions2 += thisDiff.text!!.length - } else { - length_deletions2 += thisDiff.text!!.length - } - // Eliminate an equality that is smaller or equal to the edits on both - // sides of it. - if (lastEquality != null && (lastEquality.length - <= max(length_insertions1.toDouble(), length_deletions1.toDouble())) - && (lastEquality.length - <= max(length_insertions2.toDouble(), length_deletions2.toDouble())) - ) { - //System.out.println("Splitting: '" + lastEquality + "'"); - // Walk back to offending equality. - while (thisDiff !== equalities.peek()) { - thisDiff = pointer.previous() - } - pointer.next() - - // Replace equality with a delete. - pointer.set(Diff(Operation.DELETE, lastEquality)) - // Insert a corresponding an insert. - pointer.add(Diff(Operation.INSERT, lastEquality)) - - equalities.pop() // Throw away the equality we just deleted. - if (!equalities.isEmpty()) { - // Throw away the previous equality (it needs to be reevaluated). - equalities.pop() - } - if (equalities.isEmpty()) { - // There are no previous equalities, walk back to the start. - while (pointer.hasPrevious()) { - pointer.previous() + + /** + * Determine if the suffix of one string is the prefix of another. + * @param text1 First string. + * @param text2 Second string. + * @return The number of characters common to the end of the first + * string and the start of the second string. + */ + private fun diff_commonOverlap(text1: String?, text2: String?): Int { + // Cache the text lengths to prevent multiple calls. + var text1 = text1 + var text2 = text2 + val text1_length = text1!!.length + val text2_length = text2!!.length + // Eliminate the null case. + if (text1_length == 0 || text2_length == 0) { + return 0 + } + // Truncate the longer string. + if (text1_length > text2_length) { + text1 = text1.substring(text1_length - text2_length) + } else if (text1_length < text2_length) { + text2 = text2.substring(0, text1_length) + } + val text_length = min(text1_length.toDouble(), text2_length.toDouble()).toInt() + // Quick check for the worst case. + if (text1 == text2) { + return text_length + } + + // Start by looking for a single character match + // and increase length until no match is found. + // Performance analysis: https://neil.fraser.name/news/2010/11/04/ + var best = 0 + var length = 1 + while (true) { + val pattern = text1.substring(text_length - length) + val found = text2.indexOf(pattern) + if (found == -1) { + return best } - } else { - // There is a safe equality we can fall back to. - thisDiff = equalities.peek() - while (thisDiff !== pointer.previous()) { - // Intentionally empty loop. + length += found + if (found == 0 || text1.substring(text_length - length) == text2.substring(0, length)) { + best = length + length++ } - } - - length_insertions1 = 0 // Reset the counters. - length_insertions2 = 0 - length_deletions1 = 0 - length_deletions2 = 0 - lastEquality = null - changes = true } - } - thisDiff = if (pointer.hasNext()) pointer.next() else null } - // Normalize the diff. - if (changes) { - diff_cleanupMerge(diffs) - } - diff_cleanupSemanticLossless(diffs) - - // Find any overlaps between deletions and insertions. - // e.g: abcxxxxxxdef - // -> abcxxxdef - // e.g: xxxabcdefxxx - // -> defxxxabc - // Only extract an overlap if it is as big as the edit ahead or behind it. - pointer = diffs.listIterator() - var prevDiff: Diff? = null - thisDiff = null - if (pointer.hasNext()) { - prevDiff = pointer.next() - if (pointer.hasNext()) { - thisDiff = pointer.next() - } + /** + * Do the two texts share a substring which is at least half the length of + * the longer text? + * This speedup can produce non-minimal diffs. + * @param text1 First string. + * @param text2 Second string. + * @return Five element String array, containing the prefix of text1, the + * suffix of text1, the prefix of text2, the suffix of text2 and the + * common middle. Or null if there was no match. + */ + private fun diff_halfMatch(text1: String, text2: String): Array? { + if (Diff_Timeout <= 0) { + // Don't risk returning a non-optimal diff if we have unlimited time. + return null + } + val longtext = if (text1.length > text2.length) text1 else text2 + val shorttext = if (text1.length > text2.length) text2 else text1 + if (longtext.length < 4 || shorttext.length * 2 < longtext.length) { + return null // Pointless. + } + + // First check if the second quarter is the seed for a half-match. + val hm1 = diff_halfMatchI( + longtext, shorttext, + (longtext.length + 3) / 4 + ) + // Check again based on the third quarter. + val hm2 = diff_halfMatchI( + longtext, shorttext, + (longtext.length + 1) / 2 + ) + val hm: Array? + if (hm1 == null && hm2 == null) { + return null + } else if (hm2 == null) { + hm = hm1 + } else if (hm1 == null) { + hm = hm2 + } else { + // Both matched. Select the longest. + hm = if (hm1[4].length > hm2[4].length) hm1 else hm2 + } + + // A half-match was found, sort out the return data. + if (text1.length > text2.length) { + return hm + //return new String[]{hm[0], hm[1], hm[2], hm[3], hm[4]}; + } else { + return arrayOf(hm!![2], hm[3], hm[0], hm[1], hm[4]) + } } - while (thisDiff != null) { - if (prevDiff!!.operation == Operation.DELETE && - thisDiff.operation == Operation.INSERT - ) { - val deletion = prevDiff.text - val insertion = thisDiff.text - val overlap_length1 = this.diff_commonOverlap(deletion, insertion) - val overlap_length2 = this.diff_commonOverlap(insertion, deletion) - if (overlap_length1 >= overlap_length2) { - if (overlap_length1 >= deletion!!.length / 2.0 || - overlap_length1 >= insertion!!.length / 2.0 - ) { - // Overlap found. Insert an equality and trim the surrounding edits. - pointer.previous() - pointer.add( - Diff( - Operation.EQUAL, - insertion!!.substring(0, overlap_length1) - ) + + /** + * Does a substring of shorttext exist within longtext such that the + * substring is at least half the length of longtext? + * @param longtext Longer string. + * @param shorttext Shorter string. + * @param i Start index of quarter length substring within longtext. + * @return Five element String array, containing the prefix of longtext, the + * suffix of longtext, the prefix of shorttext, the suffix of shorttext + * and the common middle. Or null if there was no match. + */ + private fun diff_halfMatchI(longtext: String, shorttext: String, i: Int): Array? { + // Start with a 1/4 length substring at position i as a seed. + val seed = longtext.substring(i, i + longtext.length / 4) + var j = -1 + var best_common = "" + var best_longtext_a = "" + var best_longtext_b = "" + var best_shorttext_a = "" + var best_shorttext_b = "" + while ((shorttext.indexOf(seed, j + 1).also { j = it }) != -1) { + val prefixLength = diff_commonPrefix( + longtext.substring(i), + shorttext.substring(j) ) - prevDiff.text = - deletion.substring(0, deletion.length - overlap_length1) - thisDiff.text = insertion.substring(overlap_length1) - // pointer.add inserts the element before the cursor, so there is - // no need to step past the new element. - } - } else { - if (overlap_length2 >= deletion!!.length / 2.0 || - overlap_length2 >= insertion!!.length / 2.0 - ) { - // Reverse overlap found. - // Insert an equality and swap and trim the surrounding edits. - pointer.previous() - pointer.add( - Diff( - Operation.EQUAL, - deletion.substring(0, overlap_length2) - ) + val suffixLength = diff_commonSuffix( + longtext.substring(0, i), + shorttext.substring(0, j) + ) + if (best_common.length < suffixLength + prefixLength) { + best_common = (shorttext.substring(j - suffixLength, j) + + shorttext.substring(j, j + prefixLength)) + best_longtext_a = longtext.substring(0, i - suffixLength) + best_longtext_b = longtext.substring(i + prefixLength) + best_shorttext_a = shorttext.substring(0, j - suffixLength) + best_shorttext_b = shorttext.substring(j + prefixLength) + } + } + if (best_common.length * 2 >= longtext.length) { + return arrayOf( + best_longtext_a, best_longtext_b, + best_shorttext_a, best_shorttext_b, best_common ) - prevDiff.operation = Operation.INSERT - prevDiff.text = - insertion!!.substring(0, insertion.length - overlap_length2) - thisDiff.operation = Operation.DELETE - thisDiff.text = deletion.substring(overlap_length2) - // pointer.add inserts the element before the cursor, so there is - // no need to step past the new element. - } + } else { + return null } - thisDiff = if (pointer.hasNext()) pointer.next() else null - } - prevDiff = thisDiff - thisDiff = if (pointer.hasNext()) pointer.next() else null - } - } - - /** - * Look for single edits surrounded on both sides by equalities - * which can be shifted sideways to align the edit to a word boundary. - * e.g: The cat came. -> The cat came. - * @param diffs LinkedList of Diff objects. - */ - private fun diff_cleanupSemanticLossless(diffs: LinkedList) { - var equality1: String - var edit: String - var equality2: String - var commonString: String - var commonOffset: Int - var score: Int - var bestScore: Int - var bestEquality1: String? - var bestEdit: String? - var bestEquality2: String? - // Create a new iterator at the start. - val pointer = diffs.listIterator() - var prevDiff = if (pointer.hasNext()) pointer.next() else null - var thisDiff = if (pointer.hasNext()) pointer.next() else null - var nextDiff = if (pointer.hasNext()) pointer.next() else null - // Intentionally ignore the first and last element (don't need checking). - while (nextDiff != null) { - if (prevDiff!!.operation == Operation.EQUAL && - nextDiff.operation == Operation.EQUAL - ) { - // This is a single edit surrounded by equalities. - equality1 = prevDiff.text!! - edit = thisDiff!!.text!! - equality2 = nextDiff.text!! - - // First, shift the edit as far left as possible. - commonOffset = diff_commonSuffix(equality1, edit) - if (commonOffset != 0) { - commonString = edit.substring(edit.length - commonOffset) - equality1 = equality1.substring(0, equality1.length - commonOffset) - edit = commonString + edit.substring(0, edit.length - commonOffset) - equality2 = commonString + equality2 - } - - // Second, step character by character right, looking for the best fit. - bestEquality1 = equality1 - bestEdit = edit - bestEquality2 = equality2 - bestScore = (diff_cleanupSemanticScore(equality1, edit) - + diff_cleanupSemanticScore(edit, equality2)) - while (((edit.length != 0) && equality2.length != 0) && edit[0] == equality2[0]) { - equality1 += edit[0] - edit = edit.substring(1) + equality2[0] - equality2 = equality2.substring(1) - score = (diff_cleanupSemanticScore(equality1, edit) - + diff_cleanupSemanticScore(edit, equality2)) - // The >= encourages trailing rather than leading whitespace on edits. - if (score >= bestScore) { - bestScore = score - bestEquality1 = equality1 - bestEdit = edit - bestEquality2 = equality2 - } - } - - if (prevDiff.text != bestEquality1) { - // We have an improvement, save it back to the diff. - if (bestEquality1!!.length != 0) { - prevDiff.text = bestEquality1 - } else { - pointer.previous() // Walk past nextDiff. - pointer.previous() // Walk past thisDiff. - pointer.previous() // Walk past prevDiff. - pointer.remove() // Delete prevDiff. - pointer.next() // Walk past thisDiff. - pointer.next() // Walk past nextDiff. - } - thisDiff.text = bestEdit - if (bestEquality2!!.length != 0) { - nextDiff.text = bestEquality2 - } else { - pointer.remove() // Delete nextDiff. - nextDiff = thisDiff - thisDiff = prevDiff - } - } - } - prevDiff = thisDiff - thisDiff = nextDiff - nextDiff = if (pointer.hasNext()) pointer.next() else null - } - } - - /** - * Given two strings, compute a score representing whether the internal - * boundary falls on logical boundaries. - * Scores range from 6 (best) to 0 (worst). - * @param one First string. - * @param two Second string. - * @return The score. - */ - private fun diff_cleanupSemanticScore(one: String?, two: String?): Int { - if (one!!.length == 0 || two!!.length == 0) { - // Edges are the best. - return 6 } - // Each port of this function behaves slightly differently due to - // subtle differences in each language's definition of things like - // 'whitespace'. Since this function's purpose is largely cosmetic, - // the choice has been made to use each language's native features - // rather than force total conformity. - val char1 = one[one.length - 1] - val char2 = two!![0] - val nonAlphaNumeric1 = !Character.isLetterOrDigit(char1) - val nonAlphaNumeric2 = !Character.isLetterOrDigit(char2) - val whitespace1 = nonAlphaNumeric1 && Character.isWhitespace(char1) - val whitespace2 = nonAlphaNumeric2 && Character.isWhitespace(char2) - val lineBreak1 = (whitespace1 - && Character.getType(char1) == Character.CONTROL.toInt()) - val lineBreak2 = (whitespace2 - && Character.getType(char2) == Character.CONTROL.toInt()) - val blankLine1 = lineBreak1 && BLANKLINEEND.matcher(one).find() - val blankLine2 = lineBreak2 && BLANKLINESTART.matcher(two).find() - - if (blankLine1 || blankLine2) { - // Five points for blank lines. - return 5 - } else if (lineBreak1 || lineBreak2) { - // Four points for line breaks. - return 4 - } else if (nonAlphaNumeric1 && !whitespace1 && whitespace2) { - // Three points for end of sentences. - return 3 - } else if (whitespace1 || whitespace2) { - // Two points for whitespace. - return 2 - } else if (nonAlphaNumeric1 || nonAlphaNumeric2) { - // One point for non-alphanumeric. - return 1 - } - return 0 - } - - // Define some regex patterns for matching boundaries. - private val BLANKLINEEND - : Pattern = Pattern.compile("\\n\\r?\\n\\Z", Pattern.DOTALL) - private val BLANKLINESTART - : Pattern = Pattern.compile("\\A\\r?\\n\\r?\\n", Pattern.DOTALL) - - /** - * Reduce the number of edits by eliminating operationally trivial equalities. - * @param diffs LinkedList of Diff objects. - */ - fun diff_cleanupEfficiency(diffs: LinkedList) { - if (diffs.isEmpty()) { - return - } - var changes = false - val equalities = ArrayDeque() // Double-ended queue of equalities. - var lastEquality: String? = null // Always equal to equalities.peek().text - val pointer = diffs.listIterator() - // Is there an insertion operation before the last equality. - var pre_ins = false - // Is there a deletion operation before the last equality. - var pre_del = false - // Is there an insertion operation after the last equality. - var post_ins = false - // Is there a deletion operation after the last equality. - var post_del = false - var thisDiff: Diff? = pointer.next() - var safeDiff = thisDiff // The last Diff that is known to be unsplittable. - while (thisDiff != null) { - if (thisDiff.operation == Operation.EQUAL) { - // Equality found. - if (thisDiff.text!!.length < Diff_EditCost && (post_ins || post_del)) { - // Candidate found. - equalities.push(thisDiff) - pre_ins = post_ins - pre_del = post_del - lastEquality = thisDiff.text - } else { - // Not a candidate, and can never become one. - equalities.clear() - lastEquality = null - safeDiff = thisDiff - } - post_del = false - post_ins = post_del - } else { - // An insertion or deletion. - if (thisDiff.operation == Operation.DELETE) { - post_del = true - } else { - post_ins = true + /** + * Reduce the number of edits by eliminating semantically trivial equalities. + * @param diffs LinkedList of Diff objects. + */ + fun diff_cleanupSemantic(diffs: LinkedList) { + if (diffs.isEmpty()) { + return } - /* - * Five types to be split: - * ABXYCD - * AXCD - * ABXC - * AXCD - * ABXC - */ - if (lastEquality != null - && ((pre_ins && pre_del && post_ins && post_del) - || ((lastEquality.length < Diff_EditCost / 2) - && ((if (pre_ins) 1 else 0) + (if (pre_del) 1 else 0) - + (if (post_ins) 1 else 0) + (if (post_del) 1 else 0)) == 3)) - ) { - //System.out.println("Splitting: '" + lastEquality + "'"); - // Walk back to offending equality. - while (thisDiff !== equalities.peek()) { - thisDiff = pointer.previous() - } - pointer.next() - - // Replace equality with a delete. - pointer.set(Diff(Operation.DELETE, lastEquality)) - // Insert a corresponding an insert. - pointer.add(Diff(Operation.INSERT, lastEquality).also { - thisDiff = it - }) - - equalities.pop() // Throw away the equality we just deleted. - lastEquality = null - if (pre_ins && pre_del) { - // No changes made which could affect previous entry, keep going. - post_del = true - post_ins = post_del - equalities.clear() - safeDiff = thisDiff - } else { - if (!equalities.isEmpty()) { - // Throw away the previous equality (it needs to be reevaluated). - equalities.pop() - } - if (equalities.isEmpty()) { - // There are no previous questionable equalities, - // walk back to the last known safe diff. - thisDiff = safeDiff + var changes = false + val equalities = ArrayDeque() // Double-ended queue of qualities. + var lastEquality: String? = null // Always equal to equalities.peek().text + var pointer = diffs.listIterator() + // Number of characters that changed prior to the equality. + var length_insertions1 = 0 + var length_deletions1 = 0 + // Number of characters that changed after the equality. + var length_insertions2 = 0 + var length_deletions2 = 0 + var thisDiff: Diff? = pointer.next() + while (thisDiff != null) { + if (thisDiff.operation == Operation.EQUAL) { + // Equality found. + equalities.add(thisDiff) + length_insertions1 = length_insertions2 + length_deletions1 = length_deletions2 + length_insertions2 = 0 + length_deletions2 = 0 + lastEquality = thisDiff.text } else { - // There is an equality we can fall back to. - thisDiff = equalities.peek() - } - while (thisDiff !== pointer.previous()) { - // Intentionally empty loop. + // An insertion or deletion. + if (thisDiff.operation == Operation.INSERT) { + length_insertions2 += thisDiff.text!!.length + } else { + length_deletions2 += thisDiff.text!!.length + } + // Eliminate an equality that is smaller or equal to the edits on both + // sides of it. + if (lastEquality != null && (lastEquality.length + <= max(length_insertions1.toDouble(), length_deletions1.toDouble())) + && (lastEquality.length + <= max(length_insertions2.toDouble(), length_deletions2.toDouble())) + ) { + //System.out.println("Splitting: '" + lastEquality + "'"); + // Walk back to offending equality. + while (thisDiff !== equalities.peek()) { + thisDiff = pointer.previous() + } + pointer.next() + + // Replace equality with a delete. + pointer.set(Diff(Operation.DELETE, lastEquality)) + // Insert a corresponding an insert. + pointer.add(Diff(Operation.INSERT, lastEquality)) + + equalities.pop() // Throw away the equality we just deleted. + if (!equalities.isEmpty()) { + // Throw away the previous equality (it needs to be reevaluated). + equalities.pop() + } + if (equalities.isEmpty()) { + // There are no previous equalities, walk back to the start. + while (pointer.hasPrevious()) { + pointer.previous() + } + } else { + // There is a safe equality we can fall back to. + thisDiff = equalities.peek() + while (thisDiff !== pointer.previous()) { + // Intentionally empty loop. + } + } + + length_insertions1 = 0 // Reset the counters. + length_insertions2 = 0 + length_deletions1 = 0 + length_deletions2 = 0 + lastEquality = null + changes = true + } } - post_del = false - post_ins = post_del - } - - changes = true + thisDiff = if (pointer.hasNext()) pointer.next() else null } - } - thisDiff = if (pointer.hasNext()) pointer.next() else null - } - if (changes) { - diff_cleanupMerge(diffs) - } - } - - /** - * Reorder and merge like edit sections. Merge equalities. - * Any edit section can move as long as it doesn't cross an equality. - * @param diffs LinkedList of Diff objects. - */ - private fun diff_cleanupMerge(diffs: LinkedList) { - diffs.add(Diff(Operation.EQUAL, "")) // Add a dummy entry at the end. - var pointer = diffs.listIterator() - var count_delete = 0 - var count_insert = 0 - var text_delete: String? = "" - var text_insert: String? = "" - var thisDiff: Diff? = pointer.next() - var prevEqual: Diff? = null - var commonlength: Int - while (thisDiff != null) { - when (thisDiff.operation) { - Operation.INSERT -> { - count_insert++ - text_insert += thisDiff.text - prevEqual = null - } - - Operation.DELETE -> { - count_delete++ - text_delete += thisDiff.text - prevEqual = null - } - - Operation.EQUAL -> { - if (count_delete + count_insert > 1) { - val both_types = count_delete != 0 && count_insert != 0 - // Delete the offending records. - pointer.previous() // Reverse direction. - while (count_delete-- > 0) { - pointer.previous() - pointer.remove() - } - while (count_insert-- > 0) { - pointer.previous() - pointer.remove() + // Normalize the diff. + if (changes) { + diff_cleanupMerge(diffs) + } + diff_cleanupSemanticLossless(diffs) + + // Find any overlaps between deletions and insertions. + // e.g: abcxxxxxxdef + // -> abcxxxdef + // e.g: xxxabcdefxxx + // -> defxxxabc + // Only extract an overlap if it is as big as the edit ahead or behind it. + pointer = diffs.listIterator() + var prevDiff: Diff? = null + thisDiff = null + if (pointer.hasNext()) { + prevDiff = pointer.next() + if (pointer.hasNext()) { + thisDiff = pointer.next() } - if (both_types) { - // Factor out any common prefixies. - commonlength = diff_commonPrefix(text_insert, text_delete) - if (commonlength != 0) { - if (pointer.hasPrevious()) { - thisDiff = pointer.previous() - assert( - thisDiff.operation == Operation.EQUAL - ) { "Previous diff should have been an equality." } - thisDiff.text += text_insert!!.substring(0, commonlength) - pointer.next() + } + while (thisDiff != null) { + if (prevDiff!!.operation == Operation.DELETE && + thisDiff.operation == Operation.INSERT + ) { + val deletion = prevDiff.text + val insertion = thisDiff.text + val overlap_length1 = this.diff_commonOverlap(deletion, insertion) + val overlap_length2 = this.diff_commonOverlap(insertion, deletion) + if (overlap_length1 >= overlap_length2) { + if (overlap_length1 >= deletion!!.length / 2.0 || + overlap_length1 >= insertion!!.length / 2.0 + ) { + // Overlap found. Insert an equality and trim the surrounding edits. + pointer.previous() + pointer.add( + Diff( + Operation.EQUAL, + insertion!!.substring(0, overlap_length1) + ) + ) + prevDiff.text = + deletion.substring(0, deletion.length - overlap_length1) + thisDiff.text = insertion.substring(overlap_length1) + // pointer.add inserts the element before the cursor, so there is + // no need to step past the new element. + } } else { - pointer.add( - Diff( - Operation.EQUAL, - text_insert!!.substring(0, commonlength) - ) - ) + if (overlap_length2 >= deletion!!.length / 2.0 || + overlap_length2 >= insertion!!.length / 2.0 + ) { + // Reverse overlap found. + // Insert an equality and swap and trim the surrounding edits. + pointer.previous() + pointer.add( + Diff( + Operation.EQUAL, + deletion.substring(0, overlap_length2) + ) + ) + prevDiff.operation = Operation.INSERT + prevDiff.text = + insertion!!.substring(0, insertion.length - overlap_length2) + thisDiff.operation = Operation.DELETE + thisDiff.text = deletion.substring(overlap_length2) + // pointer.add inserts the element before the cursor, so there is + // no need to step past the new element. + } } - text_insert = text_insert.substring(commonlength) - text_delete = text_delete!!.substring(commonlength) - } - // Factor out any common suffixies. - commonlength = diff_commonSuffix(text_insert, text_delete) - if (commonlength != 0) { - thisDiff = pointer.next() - thisDiff.text = text_insert!!.substring( - text_insert.length - - commonlength - ) + thisDiff.text - text_insert = text_insert.substring( - 0, text_insert.length - - commonlength - ) - text_delete = text_delete!!.substring( - 0, text_delete.length - - commonlength - ) - pointer.previous() - } - } - // Insert the merged records. - if (text_delete!!.length != 0) { - pointer.add(Diff(Operation.DELETE, text_delete)) + thisDiff = if (pointer.hasNext()) pointer.next() else null } - if (text_insert!!.length != 0) { - pointer.add(Diff(Operation.INSERT, text_insert)) - } - // Step forward to the equality. + prevDiff = thisDiff thisDiff = if (pointer.hasNext()) pointer.next() else null - } else if (prevEqual != null) { - // Merge this equality with the previous one. - prevEqual.text += thisDiff.text - pointer.remove() - thisDiff = pointer.previous() - pointer.next() // Forward direction - } - count_insert = 0 - count_delete = 0 - text_delete = "" - text_insert = "" - prevEqual = thisDiff - } - - null -> TODO() - } - thisDiff = if (pointer.hasNext()) pointer.next() else null - } - if (diffs.last.text!!.length == 0) { - diffs.removeLast() // Remove the dummy entry at the end. + } } - /* - * Second pass: look for single edits surrounded on both sides by equalities - * which can be shifted sideways to eliminate an equality. - * e.g: ABAC -> ABAC - */ - var changes = false - // Create a new iterator at the start. - // (As opposed to walking the current one back.) - pointer = diffs.listIterator() - var prevDiff = if (pointer.hasNext()) pointer.next() else null - thisDiff = if (pointer.hasNext()) pointer.next() else null - var nextDiff = if (pointer.hasNext()) pointer.next() else null - // Intentionally ignore the first and last element (don't need checking). - while (nextDiff != null) { - if (prevDiff!!.operation == Operation.EQUAL && - nextDiff.operation == Operation.EQUAL - ) { - // This is a single edit surrounded by equalities. - if (thisDiff!!.text!!.endsWith(prevDiff.text!!)) { - // Shift the edit over the previous equality. - thisDiff.text = (prevDiff.text - + thisDiff.text!!.substring( - 0, thisDiff.text!!.length - - prevDiff.text!!.length - )) - nextDiff.text = prevDiff.text + nextDiff.text - pointer.previous() // Walk past nextDiff. - pointer.previous() // Walk past thisDiff. - pointer.previous() // Walk past prevDiff. - pointer.remove() // Delete prevDiff. - pointer.next() // Walk past thisDiff. - thisDiff = pointer.next() // Walk past nextDiff. - nextDiff = if (pointer.hasNext()) pointer.next() else null - changes = true - } else if (thisDiff.text!!.startsWith(nextDiff.text!!)) { - // Shift the edit over the next equality. - prevDiff.text += nextDiff.text - thisDiff.text = (thisDiff.text!!.substring(nextDiff.text!!.length) - + nextDiff.text) - pointer.remove() // Delete nextDiff. - nextDiff = if (pointer.hasNext()) pointer.next() else null - changes = true - } - } - prevDiff = thisDiff - thisDiff = nextDiff - nextDiff = if (pointer.hasNext()) pointer.next() else null - } - // If shifts were made, the diff needs reordering and another shift sweep. - if (changes) { - diff_cleanupMerge(diffs) - } - } - - /** - * loc is a location in text1, compute and return the equivalent location in - * text2. - * e.g. "The cat" vs "The big cat", 1->1, 5->8 - * @param diffs List of Diff objects. - * @param loc Location within text1. - * @return Location within text2. - */ - private fun diff_xIndex(diffs: List, loc: Int): Int { - var chars1 = 0 - var chars2 = 0 - var last_chars1 = 0 - var last_chars2 = 0 - var lastDiff: Diff? = null - for (aDiff: Diff in diffs) { - if (aDiff.operation != Operation.INSERT) { - // Equality or deletion. - chars1 += aDiff.text!!.length - } - if (aDiff.operation != Operation.DELETE) { - // Equality or insertion. - chars2 += aDiff.text!!.length - } - if (chars1 > loc) { - // Overshot the location. - lastDiff = aDiff - break - } - last_chars1 = chars1 - last_chars2 = chars2 + /** + * Look for single edits surrounded on both sides by equalities + * which can be shifted sideways to align the edit to a word boundary. + * e.g: The cat came. -> The cat came. + * @param diffs LinkedList of Diff objects. + */ + private fun diff_cleanupSemanticLossless(diffs: LinkedList) { + var equality1: String + var edit: String + var equality2: String + var commonString: String + var commonOffset: Int + var score: Int + var bestScore: Int + var bestEquality1: String? + var bestEdit: String? + var bestEquality2: String? + // Create a new iterator at the start. + val pointer = diffs.listIterator() + var prevDiff = if (pointer.hasNext()) pointer.next() else null + var thisDiff = if (pointer.hasNext()) pointer.next() else null + var nextDiff = if (pointer.hasNext()) pointer.next() else null + // Intentionally ignore the first and last element (don't need checking). + while (nextDiff != null) { + if (prevDiff!!.operation == Operation.EQUAL && + nextDiff.operation == Operation.EQUAL + ) { + // This is a single edit surrounded by equalities. + equality1 = prevDiff.text!! + edit = thisDiff!!.text!! + equality2 = nextDiff.text!! + + // First, shift the edit as far left as possible. + commonOffset = diff_commonSuffix(equality1, edit) + if (commonOffset != 0) { + commonString = edit.substring(edit.length - commonOffset) + equality1 = equality1.substring(0, equality1.length - commonOffset) + edit = commonString + edit.substring(0, edit.length - commonOffset) + equality2 = commonString + equality2 + } + + // Second, step character by character right, looking for the best fit. + bestEquality1 = equality1 + bestEdit = edit + bestEquality2 = equality2 + bestScore = (diff_cleanupSemanticScore(equality1, edit) + + diff_cleanupSemanticScore(edit, equality2)) + while (((edit.length != 0) && equality2.length != 0) && edit[0] == equality2[0]) { + equality1 += edit[0] + edit = edit.substring(1) + equality2[0] + equality2 = equality2.substring(1) + score = (diff_cleanupSemanticScore(equality1, edit) + + diff_cleanupSemanticScore(edit, equality2)) + // The >= encourages trailing rather than leading whitespace on edits. + if (score >= bestScore) { + bestScore = score + bestEquality1 = equality1 + bestEdit = edit + bestEquality2 = equality2 + } + } + + if (prevDiff.text != bestEquality1) { + // We have an improvement, save it back to the diff. + if (bestEquality1!!.length != 0) { + prevDiff.text = bestEquality1 + } else { + pointer.previous() // Walk past nextDiff. + pointer.previous() // Walk past thisDiff. + pointer.previous() // Walk past prevDiff. + pointer.remove() // Delete prevDiff. + pointer.next() // Walk past thisDiff. + pointer.next() // Walk past nextDiff. + } + thisDiff.text = bestEdit + if (bestEquality2!!.length != 0) { + nextDiff.text = bestEquality2 + } else { + pointer.remove() // Delete nextDiff. + nextDiff = thisDiff + thisDiff = prevDiff + } + } + } + prevDiff = thisDiff + thisDiff = nextDiff + nextDiff = if (pointer.hasNext()) pointer.next() else null + } } - if (lastDiff != null && lastDiff.operation == Operation.DELETE) { - // The location was deleted. - return last_chars2 + + /** + * Given two strings, compute a score representing whether the internal + * boundary falls on logical boundaries. + * Scores range from 6 (best) to 0 (worst). + * @param one First string. + * @param two Second string. + * @return The score. + */ + private fun diff_cleanupSemanticScore(one: String?, two: String?): Int { + if (one!!.length == 0 || two!!.length == 0) { + // Edges are the best. + return 6 + } + + // Each port of this function behaves slightly differently due to + // subtle differences in each language's definition of things like + // 'whitespace'. Since this function's purpose is largely cosmetic, + // the choice has been made to use each language's native features + // rather than force total conformity. + val char1 = one[one.length - 1] + val char2 = two[0] + val nonAlphaNumeric1 = !Character.isLetterOrDigit(char1) + val nonAlphaNumeric2 = !Character.isLetterOrDigit(char2) + val whitespace1 = nonAlphaNumeric1 && Character.isWhitespace(char1) + val whitespace2 = nonAlphaNumeric2 && Character.isWhitespace(char2) + val lineBreak1 = (whitespace1 + && Character.getType(char1) == Character.CONTROL.toInt()) + val lineBreak2 = (whitespace2 + && Character.getType(char2) == Character.CONTROL.toInt()) + val blankLine1 = lineBreak1 && BLANKLINEEND.matcher(one).find() + val blankLine2 = lineBreak2 && BLANKLINESTART.matcher(two).find() + + if (blankLine1 || blankLine2) { + // Five points for blank lines. + return 5 + } else if (lineBreak1 || lineBreak2) { + // Four points for line breaks. + return 4 + } else if (nonAlphaNumeric1 && !whitespace1 && whitespace2) { + // Three points for end of sentences. + return 3 + } else if (whitespace1 || whitespace2) { + // Two points for whitespace. + return 2 + } else if (nonAlphaNumeric1 || nonAlphaNumeric2) { + // One point for non-alphanumeric. + return 1 + } + return 0 } - // Add the remaining character length. - return last_chars2 + (loc - last_chars1) - } - - /** - * Compute and return the source text (all equalities and deletions). - * @param diffs List of Diff objects. - * @return Source text. - */ - private fun diff_text1(diffs: List): String { - val text = StringBuilder() - for (aDiff: Diff in diffs) { - if (aDiff.operation != Operation.INSERT) { - text.append(aDiff.text) - } + + // Define some regex patterns for matching boundaries. + private val BLANKLINEEND + : Pattern = Pattern.compile("\\n\\r?\\n\\Z", Pattern.DOTALL) + private val BLANKLINESTART + : Pattern = Pattern.compile("\\A\\r?\\n\\r?\\n", Pattern.DOTALL) + + /** + * Reduce the number of edits by eliminating operationally trivial equalities. + * @param diffs LinkedList of Diff objects. + */ + fun diff_cleanupEfficiency(diffs: LinkedList) { + if (diffs.isEmpty()) { + return + } + var changes = false + val equalities = ArrayDeque() // Double-ended queue of equalities. + var lastEquality: String? = null // Always equal to equalities.peek().text + val pointer = diffs.listIterator() + // Is there an insertion operation before the last equality. + var pre_ins = false + // Is there a deletion operation before the last equality. + var pre_del = false + // Is there an insertion operation after the last equality. + var post_ins = false + // Is there a deletion operation after the last equality. + var post_del = false + var thisDiff: Diff? = pointer.next() + var safeDiff = thisDiff // The last Diff that is known to be unsplittable. + while (thisDiff != null) { + if (thisDiff.operation == Operation.EQUAL) { + // Equality found. + if (thisDiff.text!!.length < Diff_EditCost && (post_ins || post_del)) { + // Candidate found. + equalities.push(thisDiff) + pre_ins = post_ins + pre_del = post_del + lastEquality = thisDiff.text + } else { + // Not a candidate, and can never become one. + equalities.clear() + lastEquality = null + safeDiff = thisDiff + } + post_del = false + post_ins = post_del + } else { + // An insertion or deletion. + if (thisDiff.operation == Operation.DELETE) { + post_del = true + } else { + post_ins = true + } + /* + * Five types to be split: + * ABXYCD + * AXCD + * ABXC + * AXCD + * ABXC + */ + if (lastEquality != null + && ((pre_ins && pre_del && post_ins && post_del) + || ((lastEquality.length < Diff_EditCost / 2) + && ((if (pre_ins) 1 else 0) + (if (pre_del) 1 else 0) + + (if (post_ins) 1 else 0) + (if (post_del) 1 else 0)) == 3)) + ) { + //System.out.println("Splitting: '" + lastEquality + "'"); + // Walk back to offending equality. + while (thisDiff !== equalities.peek()) { + thisDiff = pointer.previous() + } + pointer.next() + + // Replace equality with a delete. + pointer.set(Diff(Operation.DELETE, lastEquality)) + // Insert a corresponding an insert. + pointer.add(Diff(Operation.INSERT, lastEquality).also { + thisDiff = it + }) + + equalities.pop() // Throw away the equality we just deleted. + lastEquality = null + if (pre_ins && pre_del) { + // No changes made which could affect previous entry, keep going. + post_del = true + post_ins = post_del + equalities.clear() + safeDiff = thisDiff + } else { + if (!equalities.isEmpty()) { + // Throw away the previous equality (it needs to be reevaluated). + equalities.pop() + } + if (equalities.isEmpty()) { + // There are no previous questionable equalities, + // walk back to the last known safe diff. + thisDiff = safeDiff + } else { + // There is an equality we can fall back to. + thisDiff = equalities.peek() + } + while (thisDiff !== pointer.previous()) { + // Intentionally empty loop. + } + post_del = false + post_ins = post_del + } + + changes = true + } + } + thisDiff = if (pointer.hasNext()) pointer.next() else null + } + + if (changes) { + diff_cleanupMerge(diffs) + } } - return text.toString() - } - - /** - * Compute and return the destination text (all equalities and insertions). - * @param diffs List of Diff objects. - * @return Destination text. - */ - private fun diff_text2(diffs: List): String { - val text = StringBuilder() - for (aDiff: Diff in diffs) { - if (aDiff.operation != Operation.DELETE) { - text.append(aDiff.text) - } + + /** + * Reorder and merge like edit sections. Merge equalities. + * Any edit section can move as long as it doesn't cross an equality. + * @param diffs LinkedList of Diff objects. + */ + private fun diff_cleanupMerge(diffs: LinkedList) { + diffs.add(Diff(Operation.EQUAL, "")) // Add a dummy entry at the end. + var pointer = diffs.listIterator() + var count_delete = 0 + var count_insert = 0 + var text_delete: String? = "" + var text_insert: String? = "" + var thisDiff: Diff? = pointer.next() + var prevEqual: Diff? = null + var commonlength: Int + while (thisDiff != null) { + when (thisDiff.operation) { + Operation.INSERT -> { + count_insert++ + text_insert += thisDiff.text + prevEqual = null + } + + Operation.DELETE -> { + count_delete++ + text_delete += thisDiff.text + prevEqual = null + } + + Operation.EQUAL -> { + if (count_delete + count_insert > 1) { + val both_types = count_delete != 0 && count_insert != 0 + // Delete the offending records. + pointer.previous() // Reverse direction. + while (count_delete-- > 0) { + pointer.previous() + pointer.remove() + } + while (count_insert-- > 0) { + pointer.previous() + pointer.remove() + } + if (both_types) { + // Factor out any common prefixies. + commonlength = diff_commonPrefix(text_insert, text_delete) + if (commonlength != 0) { + if (pointer.hasPrevious()) { + thisDiff = pointer.previous() + assert( + thisDiff.operation == Operation.EQUAL + ) { "Previous diff should have been an equality." } + thisDiff.text += text_insert!!.substring(0, commonlength) + pointer.next() + } else { + pointer.add( + Diff( + Operation.EQUAL, + text_insert!!.substring(0, commonlength) + ) + ) + } + text_insert = text_insert.substring(commonlength) + text_delete = text_delete!!.substring(commonlength) + } + // Factor out any common suffixies. + commonlength = diff_commonSuffix(text_insert, text_delete) + if (commonlength != 0) { + thisDiff = pointer.next() + thisDiff.text = text_insert!!.substring( + text_insert.length + - commonlength + ) + thisDiff.text + text_insert = text_insert.substring( + 0, text_insert.length + - commonlength + ) + text_delete = text_delete!!.substring( + 0, text_delete.length + - commonlength + ) + pointer.previous() + } + } + // Insert the merged records. + if (text_delete!!.length != 0) { + pointer.add(Diff(Operation.DELETE, text_delete)) + } + if (text_insert!!.length != 0) { + pointer.add(Diff(Operation.INSERT, text_insert)) + } + // Step forward to the equality. + thisDiff = if (pointer.hasNext()) pointer.next() else null + } else if (prevEqual != null) { + // Merge this equality with the previous one. + prevEqual.text += thisDiff.text + pointer.remove() + thisDiff = pointer.previous() + pointer.next() // Forward direction + } + count_insert = 0 + count_delete = 0 + text_delete = "" + text_insert = "" + prevEqual = thisDiff + } + + null -> TODO() + } + thisDiff = if (pointer.hasNext()) pointer.next() else null + } + if (diffs.last.text!!.length == 0) { + diffs.removeLast() // Remove the dummy entry at the end. + } + + /* + * Second pass: look for single edits surrounded on both sides by equalities + * which can be shifted sideways to eliminate an equality. + * e.g: ABAC -> ABAC + */ + var changes = false + // Create a new iterator at the start. + // (As opposed to walking the current one back.) + pointer = diffs.listIterator() + var prevDiff = if (pointer.hasNext()) pointer.next() else null + thisDiff = if (pointer.hasNext()) pointer.next() else null + var nextDiff = if (pointer.hasNext()) pointer.next() else null + // Intentionally ignore the first and last element (don't need checking). + while (nextDiff != null) { + if (prevDiff!!.operation == Operation.EQUAL && + nextDiff.operation == Operation.EQUAL + ) { + // This is a single edit surrounded by equalities. + if (thisDiff!!.text!!.endsWith(prevDiff.text!!)) { + // Shift the edit over the previous equality. + thisDiff.text = (prevDiff.text + + thisDiff.text!!.substring( + 0, thisDiff.text!!.length + - prevDiff.text!!.length + )) + nextDiff.text = prevDiff.text + nextDiff.text + pointer.previous() // Walk past nextDiff. + pointer.previous() // Walk past thisDiff. + pointer.previous() // Walk past prevDiff. + pointer.remove() // Delete prevDiff. + pointer.next() // Walk past thisDiff. + thisDiff = pointer.next() // Walk past nextDiff. + nextDiff = if (pointer.hasNext()) pointer.next() else null + changes = true + } else if (thisDiff.text!!.startsWith(nextDiff.text!!)) { + // Shift the edit over the next equality. + prevDiff.text += nextDiff.text + thisDiff.text = (thisDiff.text!!.substring(nextDiff.text!!.length) + + nextDiff.text) + pointer.remove() // Delete nextDiff. + nextDiff = if (pointer.hasNext()) pointer.next() else null + changes = true + } + } + prevDiff = thisDiff + thisDiff = nextDiff + nextDiff = if (pointer.hasNext()) pointer.next() else null + } + // If shifts were made, the diff needs reordering and another shift sweep. + if (changes) { + diff_cleanupMerge(diffs) + } } - return text.toString() - } - - /** - * Compute the Levenshtein distance; the number of inserted, deleted or - * substituted characters. - * @param diffs List of Diff objects. - * @return Number of changes. - */ - private fun diff_levenshtein(diffs: List): Int { - var levenshtein = 0 - var insertions = 0 - var deletions = 0 - for (aDiff: Diff in diffs) { - when (aDiff.operation) { - Operation.INSERT -> insertions += aDiff.text!!.length - Operation.DELETE -> deletions += aDiff.text!!.length - Operation.EQUAL -> { - // A deletion and an insertion is one substitution. - levenshtein += (max(insertions.toDouble(), deletions.toDouble())).toInt() - insertions = 0 - deletions = 0 - } - - null -> TODO() - } + + /** + * loc is a location in text1, compute and return the equivalent location in + * text2. + * e.g. "The cat" vs "The big cat", 1->1, 5->8 + * @param diffs List of Diff objects. + * @param loc Location within text1. + * @return Location within text2. + */ + private fun diff_xIndex(diffs: List, loc: Int): Int { + var chars1 = 0 + var chars2 = 0 + var last_chars1 = 0 + var last_chars2 = 0 + var lastDiff: Diff? = null + for (aDiff: Diff in diffs) { + if (aDiff.operation != Operation.INSERT) { + // Equality or deletion. + chars1 += aDiff.text!!.length + } + if (aDiff.operation != Operation.DELETE) { + // Equality or insertion. + chars2 += aDiff.text!!.length + } + if (chars1 > loc) { + // Overshot the location. + lastDiff = aDiff + break + } + last_chars1 = chars1 + last_chars2 = chars2 + } + if (lastDiff != null && lastDiff.operation == Operation.DELETE) { + // The location was deleted. + return last_chars2 + } + // Add the remaining character length. + return last_chars2 + (loc - last_chars1) } - levenshtein += (max(insertions.toDouble(), deletions.toDouble())).toInt() - return levenshtein - } - - - // MATCH FUNCTIONS - /** - * Locate the best instance of 'pattern' in 'text' near 'loc'. - * Returns -1 if no match found. - * @param text The text to search. - * @param pattern The pattern to search for. - * @param loc The location to search around. - * @return Best match index or -1. - */ - private fun match_main(text: String?, pattern: String?, loc: Int): Int { - // Check for null inputs. - var loc = loc - if (text == null || pattern == null) { - throw IllegalArgumentException("Null inputs. (match_main)") + + /** + * Compute and return the source text (all equalities and deletions). + * @param diffs List of Diff objects. + * @return Source text. + */ + private fun diff_text1(diffs: List): String { + val text = StringBuilder() + for (aDiff: Diff in diffs) { + if (aDiff.operation != Operation.INSERT) { + text.append(aDiff.text) + } + } + return text.toString() } - loc = max(0.0, min(loc.toDouble(), text.length.toDouble())).toInt() - if ((text == pattern)) { - // Shortcut (potentially not guaranteed by the algorithm) - return 0 - } else if (text.length == 0) { - // Nothing to match. - return -1 - } else if ((loc + pattern.length <= text.length - && (text.substring(loc, loc + pattern.length) == pattern)) - ) { - // Perfect match at the perfect spot! (Includes case of null pattern) - return loc - } else { - // Do a fuzzy compare. - return match_bitap(text, pattern, loc) + /** + * Compute and return the destination text (all equalities and insertions). + * @param diffs List of Diff objects. + * @return Destination text. + */ + private fun diff_text2(diffs: List): String { + val text = StringBuilder() + for (aDiff: Diff in diffs) { + if (aDiff.operation != Operation.DELETE) { + text.append(aDiff.text) + } + } + return text.toString() } - } - - /** - * Locate the best instance of 'pattern' in 'text' near 'loc' using the - * Bitap algorithm. Returns -1 if no match found. - * @param text The text to search. - * @param pattern The pattern to search for. - * @param loc The location to search around. - * @return Best match index or -1. - */ - private fun match_bitap(text: String, pattern: String, loc: Int): Int { - assert(Match_MaxBits.toInt() == 0 || pattern.length <= Match_MaxBits) { "Pattern too long for this application." } - // Initialise the alphabet. - val s = match_alphabet(pattern) - - // Highest score beyond which we give up. - var score_threshold = Match_Threshold.toDouble() - // Is there a nearby exact match? (speedup) - var best_loc = text.indexOf(pattern, loc) - if (best_loc != -1) { - score_threshold = min( - match_bitapScore(0, best_loc, loc, pattern), - score_threshold - ) - // What about in the other direction? (speedup) - best_loc = text.lastIndexOf(pattern, loc + pattern.length) - if (best_loc != -1) { - score_threshold = min( - match_bitapScore(0, best_loc, loc, pattern), - score_threshold - ) - } + + /** + * Compute the Levenshtein distance; the number of inserted, deleted or + * substituted characters. + * @param diffs List of Diff objects. + * @return Number of changes. + */ + private fun diff_levenshtein(diffs: List): Int { + var levenshtein = 0 + var insertions = 0 + var deletions = 0 + for (aDiff: Diff in diffs) { + when (aDiff.operation) { + Operation.INSERT -> insertions += aDiff.text!!.length + Operation.DELETE -> deletions += aDiff.text!!.length + Operation.EQUAL -> { + // A deletion and an insertion is one substitution. + levenshtein += (max(insertions.toDouble(), deletions.toDouble())).toInt() + insertions = 0 + deletions = 0 + } + + null -> TODO() + } + } + levenshtein += (max(insertions.toDouble(), deletions.toDouble())).toInt() + return levenshtein } - // Initialise the bit arrays. - val matchmask = 1 shl (pattern.length - 1) - best_loc = -1 - - var bin_min: Int - var bin_mid: Int - var bin_max = pattern.length + text.length - // Empty initialization added to appease Java compiler. - var last_rd = IntArray(0) - for (d in 0 until pattern.length) { - // Scan for the best match; each iteration allows for one more error. - // Run a binary search to determine how far from 'loc' we can stray at - // this error level. - bin_min = 0 - bin_mid = bin_max - while (bin_min < bin_mid) { - if ((match_bitapScore(d, loc + bin_mid, loc, pattern) - <= score_threshold) + + // MATCH FUNCTIONS + /** + * Locate the best instance of 'pattern' in 'text' near 'loc'. + * Returns -1 if no match found. + * @param text The text to search. + * @param pattern The pattern to search for. + * @param loc The location to search around. + * @return Best match index or -1. + */ + private fun match_main(text: String?, pattern: String?, loc: Int): Int { + // Check for null inputs. + var loc = loc + if (text == null || pattern == null) { + throw IllegalArgumentException("Null inputs. (match_main)") + } + + loc = max(0.0, min(loc.toDouble(), text.length.toDouble())).toInt() + if ((text == pattern)) { + // Shortcut (potentially not guaranteed by the algorithm) + return 0 + } else if (text.length == 0) { + // Nothing to match. + return -1 + } else if ((loc + pattern.length <= text.length + && (text.substring(loc, loc + pattern.length) == pattern)) ) { - bin_min = bin_mid + // Perfect match at the perfect spot! (Includes case of null pattern) + return loc } else { - bin_max = bin_mid - } - bin_mid = (bin_max - bin_min) / 2 + bin_min - } - // Use the result from this iteration as the maximum for the next. - bin_max = bin_mid - var start = max(1.0, (loc - bin_mid + 1).toDouble()).toInt() - val finish = (min((loc + bin_mid).toDouble(), text.length.toDouble()) + pattern.length).toInt() - - val rd = IntArray(finish + 2) - rd[finish + 1] = (1 shl d) - 1 - var j = finish - while (j >= start) { - var charMatch: Int - if (text.length <= j - 1 || !s.containsKey(text[j - 1])) { - // Out of range. - charMatch = 0 - } else { - charMatch = (s[text[j - 1]])!! + // Do a fuzzy compare. + return match_bitap(text, pattern, loc) } - if (d == 0) { - // First pass: exact match. - rd[j] = ((rd[j + 1] shl 1) or 1) and charMatch - } else { - // Subsequent passes: fuzzy match. - rd[j] = ((((rd[j + 1] shl 1) or 1) and charMatch) - or (((last_rd[j + 1] or last_rd[j]) shl 1) or 1) or last_rd[j + 1]) - } - if ((rd[j] and matchmask) != 0) { - val score = match_bitapScore(d, j - 1, loc, pattern) - // This match will almost certainly be better than any existing - // match. But check anyway. - if (score <= score_threshold) { - // Told you so. - score_threshold = score - best_loc = j - 1 - if (best_loc > loc) { - // When passing loc, don't exceed our current distance from loc. - start = max(1.0, (2 * loc - best_loc).toDouble()).toInt() - } else { - // Already passed loc, downhill from here on in. - break - } - } - } - j-- - } - if (match_bitapScore(d + 1, loc, loc, pattern) > score_threshold) { - // No hope for a (better) match at greater error levels. - break - } - last_rd = rd - } - return best_loc - } - - /** - * Compute and return the score for a match with e errors and x location. - * @param e Number of errors in match. - * @param x Location of match. - * @param loc Expected location of match. - * @param pattern Pattern being sought. - * @return Overall score for match (0.0 = good, 1.0 = bad). - */ - private fun match_bitapScore(e: Int, x: Int, loc: Int, pattern: String): Double { - val accuracy = e.toFloat() / pattern.length - val proximity = abs((loc - x).toDouble()).toInt() - if (Match_Distance == 0) { - // Dodge divide by zero error. - return if (proximity == 0) accuracy.toDouble() else 1.0 - } - return (accuracy + (proximity / Match_Distance.toFloat())).toDouble() - } - - /** - * Initialise the alphabet for the Bitap algorithm. - * @param pattern The text to encode. - * @return Hash of character locations. - */ - private fun match_alphabet(pattern: String): Map { - val s: MutableMap = HashMap() - val char_pattern = pattern.toCharArray() - for (c: Char in char_pattern) { - s[c] = 0 } - var i = 0 - for (c: Char in char_pattern) { - s[c] = s.get(c)!! or (1 shl (pattern.length - i - 1)) - i++ - } - return s - } - - - // PATCH FUNCTIONS - /** - * Increase the context until it is unique, - * but don't let the pattern expand beyond Match_MaxBits. - * @param patch The patch to grow. - * @param text Source text. - */ - private fun patch_addContext(patch: Patch, text: String) { - if (text.length == 0) { - return - } - var pattern = text.substring(patch.start2, patch.start2 + patch.length1) - var padding = 0 - // Look for the first and last matches of pattern in text. If two different - // matches are found, increase the pattern length. - while ((text.indexOf(pattern) != text.lastIndexOf(pattern) - && pattern.length < Match_MaxBits - Patch_Margin - Patch_Margin) - ) { - padding += Patch_Margin.toInt() - pattern = text.substring( - max(0.0, (patch.start2 - padding).toDouble()).toInt(), - min(text.length.toDouble(), (patch.start2 + patch.length1 + padding).toDouble()).toInt() - ) - } - // Add one chunk for good luck. - padding += Patch_Margin.toInt() + /** + * Locate the best instance of 'pattern' in 'text' near 'loc' using the + * Bitap algorithm. Returns -1 if no match found. + * @param text The text to search. + * @param pattern The pattern to search for. + * @param loc The location to search around. + * @return Best match index or -1. + */ + private fun match_bitap(text: String, pattern: String, loc: Int): Int { + assert(Match_MaxBits.toInt() == 0 || pattern.length <= Match_MaxBits) { "Pattern too long for this application." } + // Initialise the alphabet. + val s = match_alphabet(pattern) + + // Highest score beyond which we give up. + var score_threshold = Match_Threshold.toDouble() + // Is there a nearby exact match? (speedup) + var best_loc = text.indexOf(pattern, loc) + if (best_loc != -1) { + score_threshold = min( + match_bitapScore(0, best_loc, loc, pattern), + score_threshold + ) + // What about in the other direction? (speedup) + best_loc = text.lastIndexOf(pattern, loc + pattern.length) + if (best_loc != -1) { + score_threshold = min( + match_bitapScore(0, best_loc, loc, pattern), + score_threshold + ) + } + } - // Add the prefix. - val prefix = text.substring( - max(0.0, (patch.start2 - padding).toDouble()).toInt(), - patch.start2 - ) - if (prefix.length != 0) { - patch.diffs.addFirst(Diff(Operation.EQUAL, prefix)) + // Initialise the bit arrays. + val matchmask = 1 shl (pattern.length - 1) + best_loc = -1 + + var bin_min: Int + var bin_mid: Int + var bin_max = pattern.length + text.length + // Empty initialization added to appease Java compiler. + var last_rd = IntArray(0) + for (d in 0 until pattern.length) { + // Scan for the best match; each iteration allows for one more error. + // Run a binary search to determine how far from 'loc' we can stray at + // this error level. + bin_min = 0 + bin_mid = bin_max + while (bin_min < bin_mid) { + if ((match_bitapScore(d, loc + bin_mid, loc, pattern) + <= score_threshold) + ) { + bin_min = bin_mid + } else { + bin_max = bin_mid + } + bin_mid = (bin_max - bin_min) / 2 + bin_min + } + // Use the result from this iteration as the maximum for the next. + bin_max = bin_mid + var start = max(1.0, (loc - bin_mid + 1).toDouble()).toInt() + val finish = (min((loc + bin_mid).toDouble(), text.length.toDouble()) + pattern.length).toInt() + + val rd = IntArray(finish + 2) + rd[finish + 1] = (1 shl d) - 1 + var j = finish + while (j >= start) { + var charMatch: Int + if (text.length <= j - 1 || !s.containsKey(text[j - 1])) { + // Out of range. + charMatch = 0 + } else { + charMatch = (s[text[j - 1]])!! + } + if (d == 0) { + // First pass: exact match. + rd[j] = ((rd[j + 1] shl 1) or 1) and charMatch + } else { + // Subsequent passes: fuzzy match. + rd[j] = ((((rd[j + 1] shl 1) or 1) and charMatch) + or (((last_rd[j + 1] or last_rd[j]) shl 1) or 1) or last_rd[j + 1]) + } + if ((rd[j] and matchmask) != 0) { + val score = match_bitapScore(d, j - 1, loc, pattern) + // This match will almost certainly be better than any existing + // match. But check anyway. + if (score <= score_threshold) { + // Told you so. + score_threshold = score + best_loc = j - 1 + if (best_loc > loc) { + // When passing loc, don't exceed our current distance from loc. + start = max(1.0, (2 * loc - best_loc).toDouble()).toInt() + } else { + // Already passed loc, downhill from here on in. + break + } + } + } + j-- + } + if (match_bitapScore(d + 1, loc, loc, pattern) > score_threshold) { + // No hope for a (better) match at greater error levels. + break + } + last_rd = rd + } + return best_loc } - // Add the suffix. - val suffix = text.substring( - patch.start2 + patch.length1, - min(text.length.toDouble(), (patch.start2 + patch.length1 + padding).toDouble()).toInt() - ) - if (suffix.length != 0) { - patch.diffs.addLast(Diff(Operation.EQUAL, suffix)) + + /** + * Compute and return the score for a match with e errors and x location. + * @param e Number of errors in match. + * @param x Location of match. + * @param loc Expected location of match. + * @param pattern Pattern being sought. + * @return Overall score for match (0.0 = good, 1.0 = bad). + */ + private fun match_bitapScore(e: Int, x: Int, loc: Int, pattern: String): Double { + val accuracy = e.toFloat() / pattern.length + val proximity = abs((loc - x).toDouble()).toInt() + if (Match_Distance == 0) { + // Dodge divide by zero error. + return if (proximity == 0) accuracy.toDouble() else 1.0 + } + return (accuracy + (proximity / Match_Distance.toFloat())).toDouble() } - // Roll back the start points. - patch.start1 -= prefix.length - patch.start2 -= prefix.length - // Extend the lengths. - patch.length1 += prefix.length + suffix.length - patch.length2 += prefix.length + suffix.length - } - - /** - * Compute a list of patches to turn text1 into text2. - * A set of diffs will be computed. - * @param text1 Old text. - * @param text2 New text. - * @return LinkedList of Patch objects. - */ - fun patch_make(text1: String?, text2: String?): LinkedList { - if (text1 == null || text2 == null) { - throw IllegalArgumentException("Null inputs. (patch_make)") + /** + * Initialise the alphabet for the Bitap algorithm. + * @param pattern The text to encode. + * @return Hash of character locations. + */ + private fun match_alphabet(pattern: String): Map { + val s: MutableMap = HashMap() + val char_pattern = pattern.toCharArray() + for (c: Char in char_pattern) { + s[c] = 0 + } + var i = 0 + for (c: Char in char_pattern) { + s[c] = s.get(c)!! or (1 shl (pattern.length - i - 1)) + i++ + } + return s } - // No diffs provided, compute our own. - val diffs = diff_main(text1, text2, true) - if (diffs.size > 2) { - diff_cleanupSemantic(diffs) - diff_cleanupEfficiency(diffs) + + + // PATCH FUNCTIONS + /** + * Increase the context until it is unique, + * but don't let the pattern expand beyond Match_MaxBits. + * @param patch The patch to grow. + * @param text Source text. + */ + private fun patch_addContext(patch: Patch, text: String) { + if (text.length == 0) { + return + } + var pattern = text.substring(patch.start2, patch.start2 + patch.length1) + var padding = 0 + + // Look for the first and last matches of pattern in text. If two different + // matches are found, increase the pattern length. + while ((text.indexOf(pattern) != text.lastIndexOf(pattern) + && pattern.length < Match_MaxBits - Patch_Margin - Patch_Margin) + ) { + padding += Patch_Margin.toInt() + pattern = text.substring( + max(0.0, (patch.start2 - padding).toDouble()).toInt(), + min(text.length.toDouble(), (patch.start2 + patch.length1 + padding).toDouble()).toInt() + ) + } + // Add one chunk for good luck. + padding += Patch_Margin.toInt() + + // Add the prefix. + val prefix = text.substring( + max(0.0, (patch.start2 - padding).toDouble()).toInt(), + patch.start2 + ) + if (prefix.length != 0) { + patch.diffs.addFirst(Diff(Operation.EQUAL, prefix)) + } + // Add the suffix. + val suffix = text.substring( + patch.start2 + patch.length1, + min(text.length.toDouble(), (patch.start2 + patch.length1 + padding).toDouble()).toInt() + ) + if (suffix.length != 0) { + patch.diffs.addLast(Diff(Operation.EQUAL, suffix)) + } + + // Roll back the start points. + patch.start1 -= prefix.length + patch.start2 -= prefix.length + // Extend the lengths. + patch.length1 += prefix.length + suffix.length + patch.length2 += prefix.length + suffix.length } - return patch_make(text1, diffs) - } - - /** - * Compute a list of patches to turn text1 into text2. - * text1 will be derived from the provided diffs. - * @param diffs Array of Diff objects for text1 to text2. - * @return LinkedList of Patch objects. - */ - fun patch_make(diffs: LinkedList?): LinkedList { - if (diffs == null) { - throw IllegalArgumentException("Null inputs. (patch_make)") + + /** + * Compute a list of patches to turn text1 into text2. + * A set of diffs will be computed. + * @param text1 Old text. + * @param text2 New text. + * @return LinkedList of Patch objects. + */ + fun patch_make(text1: String?, text2: String?): LinkedList { + if (text1 == null || text2 == null) { + throw IllegalArgumentException("Null inputs. (patch_make)") + } + // No diffs provided, compute our own. + val diffs = diff_main(text1, text2, true) + if (diffs.size > 2) { + diff_cleanupSemantic(diffs) + diff_cleanupEfficiency(diffs) + } + return patch_make(text1, diffs) } - // No origin string provided, compute our own. - val text1 = diff_text1(diffs) - return patch_make(text1, diffs) - } - - /** - * Compute a list of patches to turn text1 into text2. - * text2 is ignored, diffs are the delta between text1 and text2. - * @param text1 Old text - * @param text2 Ignored. - * @param diffs Array of Diff objects for text1 to text2. - * @return LinkedList of Patch objects. - */ - @Deprecated("Prefer patch_make(String text1, LinkedList diffs).") - fun patch_make( - text1: String?, text2: String?, - diffs: LinkedList? - ): LinkedList { - return patch_make(text1, diffs) - } - - /** - * Compute a list of patches to turn text1 into text2. - * text2 is not provided, diffs are the delta between text1 and text2. - * @param text1 Old text. - * @param diffs Array of Diff objects for text1 to text2. - * @return LinkedList of Patch objects. - */ - fun patch_make(text1: String?, diffs: LinkedList?): LinkedList { - if (text1 == null || diffs == null) { - throw IllegalArgumentException("Null inputs. (patch_make)") + + /** + * Compute a list of patches to turn text1 into text2. + * text1 will be derived from the provided diffs. + * @param diffs Array of Diff objects for text1 to text2. + * @return LinkedList of Patch objects. + */ + fun patch_make(diffs: LinkedList?): LinkedList { + if (diffs == null) { + throw IllegalArgumentException("Null inputs. (patch_make)") + } + // No origin string provided, compute our own. + val text1 = diff_text1(diffs) + return patch_make(text1, diffs) } - val patches = LinkedList() - if (diffs.isEmpty()) { - return patches // Get rid of the null case. + /** + * Compute a list of patches to turn text1 into text2. + * text2 is ignored, diffs are the delta between text1 and text2. + * @param text1 Old text + * @param text2 Ignored. + * @param diffs Array of Diff objects for text1 to text2. + * @return LinkedList of Patch objects. + */ + @Deprecated("Prefer patch_make(String text1, LinkedList diffs).") + fun patch_make( + text1: String?, text2: String?, + diffs: LinkedList? + ): LinkedList { + return patch_make(text1, diffs) } - var patch = Patch() - var char_count1 = 0 // Number of characters into the text1 string. - var char_count2 = 0 // Number of characters into the text2 string. - // Start with text1 (prepatch_text) and apply the diffs until we arrive at - // text2 (postpatch_text). We recreate the patches one by one to determine - // context info. - var prepatch_text: String = text1 - var postpatch_text: String = text1 - for (aDiff: Diff in diffs) { - if (patch.diffs.isEmpty() && aDiff.operation != Operation.EQUAL) { - // A new patch starts here. - patch.start1 = char_count1 - patch.start2 = char_count2 - } - - when (aDiff.operation) { - Operation.INSERT -> { - patch.diffs.add(aDiff) - patch.length2 += aDiff.text!!.length - postpatch_text = (postpatch_text.substring(0, char_count2) - + aDiff.text + postpatch_text.substring(char_count2)) - } - - Operation.DELETE -> { - patch.length1 += aDiff.text!!.length - patch.diffs.add(aDiff) - postpatch_text = (postpatch_text.substring(0, char_count2) - + postpatch_text.substring(char_count2 + aDiff.text!!.length)) - } - - Operation.EQUAL -> { - if ((aDiff.text!!.length <= 2 * Patch_Margin - ) && !patch.diffs.isEmpty() && (aDiff !== diffs.last) - ) { - // Small equality inside a patch. - patch.diffs.add(aDiff) - patch.length1 += aDiff.text!!.length - patch.length2 += aDiff.text!!.length - } - - if (aDiff.text!!.length >= 2 * Patch_Margin && !patch.diffs.isEmpty()) { - // Time for a new patch. - if (!patch.diffs.isEmpty()) { - patch_addContext(patch, prepatch_text) - patches.add(patch) - patch = Patch() - // Unlike Unidiff, our patch lists have a rolling context. - // https://github.com/google/diff-match-patch/wiki/Unidiff - // Update prepatch text & pos to reflect the application of the - // just completed patch. - prepatch_text = postpatch_text - char_count1 = char_count2 + + /** + * Compute a list of patches to turn text1 into text2. + * text2 is not provided, diffs are the delta between text1 and text2. + * @param text1 Old text. + * @param diffs Array of Diff objects for text1 to text2. + * @return LinkedList of Patch objects. + */ + fun patch_make(text1: String?, diffs: LinkedList?): LinkedList { + if (text1 == null || diffs == null) { + throw IllegalArgumentException("Null inputs. (patch_make)") + } + + val patches = LinkedList() + if (diffs.isEmpty()) { + return patches // Get rid of the null case. + } + var patch = Patch() + var char_count1 = 0 // Number of characters into the text1 string. + var char_count2 = 0 // Number of characters into the text2 string. + // Start with text1 (prepatch_text) and apply the diffs until we arrive at + // text2 (postpatch_text). We recreate the patches one by one to determine + // context info. + var prepatch_text: String = text1 + var postpatch_text: String = text1 + for (aDiff: Diff in diffs) { + if (patch.diffs.isEmpty() && aDiff.operation != Operation.EQUAL) { + // A new patch starts here. + patch.start1 = char_count1 + patch.start2 = char_count2 } - } - } - - null -> TODO() - } - // Update the current character count. - if (aDiff.operation != Operation.INSERT) { - char_count1 += aDiff.text!!.length - } - if (aDiff.operation != Operation.DELETE) { - char_count2 += aDiff.text!!.length - } - } - // Pick up the leftover patch if not empty. - if (!patch.diffs.isEmpty()) { - patch_addContext(patch, prepatch_text) - patches.add(patch) - } - return patches - } - - /** - * Given an array of patches, return another array that is identical. - * @param patches Array of Patch objects. - * @return Array of Patch objects. - */ - private fun patch_deepCopy(patches: LinkedList): LinkedList { - val patchesCopy = LinkedList() - for (aPatch: Patch in patches) { - val patchCopy = Patch() - for (aDiff: Diff in aPatch.diffs) { - val diffCopy = Diff(aDiff.operation, aDiff.text) - patchCopy.diffs.add(diffCopy) - } - patchCopy.start1 = aPatch.start1 - patchCopy.start2 = aPatch.start2 - patchCopy.length1 = aPatch.length1 - patchCopy.length2 = aPatch.length2 - patchesCopy.add(patchCopy) - } - return patchesCopy - } - - /** - * Merge a set of patches onto the text. Return a patched text, as well - * as an array of true/false values indicating which patches were applied. - * @param patches Array of Patch objects - * @param text Old text. - * @return Two element Object array, containing the new text and an array of - * boolean values. - */ - fun patch_apply(patches: LinkedList, text: String): Array { - var patches = patches - var text = text - if (patches.isEmpty()) { - return arrayOf(text, BooleanArray(0)) - } + when (aDiff.operation) { + Operation.INSERT -> { + patch.diffs.add(aDiff) + patch.length2 += aDiff.text!!.length + postpatch_text = (postpatch_text.substring(0, char_count2) + + aDiff.text + postpatch_text.substring(char_count2)) + } - // Deep copy the patches so that no changes are made to originals. - patches = patch_deepCopy(patches) - - val nullPadding = patch_addPadding(patches) - text = nullPadding + text + nullPadding - patch_splitMax(patches) - - var x = 0 - // delta keeps track of the offset between the expected and actual location - // of the previous patch. If there are patches expected at positions 10 and - // 20, but the first patch was found at 12, delta is 2 and the second patch - // has an effective expected position of 22. - var delta = 0 - val results = BooleanArray(patches.size) - for (aPatch: Patch in patches) { - val expected_loc = aPatch.start2 + delta - val text1 = diff_text1(aPatch.diffs) - var start_loc: Int - var end_loc = -1 - if (text1.length > this.Match_MaxBits) { - // patch_splitMax will only provide an oversized pattern in the case of - // a monster delete. - start_loc = match_main( - text, - text1.substring(0, Match_MaxBits.toInt()), expected_loc - ) - if (start_loc != -1) { - end_loc = match_main( - text, - text1.substring(text1.length - this.Match_MaxBits), - expected_loc + text1.length - this.Match_MaxBits - ) - if (end_loc == -1 || start_loc >= end_loc) { - // Can't find valid trailing context. Drop this patch. - start_loc = -1 - } - } - } else { - start_loc = match_main(text, text1, expected_loc) - } - if (start_loc == -1) { - // No match found. :( - results[x] = false - // Subtract the delta for this failed patch from subsequent patches. - delta -= aPatch.length2 - aPatch.length1 - } else { - // Found a match. :) - results[x] = true - delta = start_loc - expected_loc - var text2: String - if (end_loc == -1) { - text2 = text.substring( - start_loc, - min((start_loc + text1.length).toDouble(), text.length.toDouble()).toInt() - ) - } else { - text2 = text.substring( - start_loc, - min((end_loc + this.Match_MaxBits).toDouble(), text.length.toDouble()).toInt() - ) - } - if ((text1 == text2)) { - // Perfect match, just shove the replacement text in. - text = (text.substring(0, start_loc) + diff_text2(aPatch.diffs) - + text.substring(start_loc + text1.length)) - } else { - // Imperfect match. Run a diff to get a framework of equivalent - // indices. - val diffs = diff_main(text1, text2, false) - if ((text1.length > this.Match_MaxBits - && diff_levenshtein(diffs) / text1.length.toFloat() - > this.Patch_DeleteThreshold) - ) { - // The end points match, but the content is unacceptably bad. - results[x] = false - } else { - diff_cleanupSemanticLossless(diffs) - var index1 = 0 - for (aDiff: Diff in aPatch.diffs) { - if (aDiff.operation != Operation.EQUAL) { - val index2 = diff_xIndex(diffs, index1) - if (aDiff.operation == Operation.INSERT) { - // Insertion - text = (text.substring(0, start_loc + index2) + aDiff.text - + text.substring(start_loc + index2)) - } else if (aDiff.operation == Operation.DELETE) { - // Deletion - text = (text.substring(0, start_loc + index2) - + text.substring( - start_loc + diff_xIndex( - diffs, - index1 + aDiff.text!!.length - ) - )) + Operation.DELETE -> { + patch.length1 += aDiff.text!!.length + patch.diffs.add(aDiff) + postpatch_text = (postpatch_text.substring(0, char_count2) + + postpatch_text.substring(char_count2 + aDiff.text!!.length)) } - } - if (aDiff.operation != Operation.DELETE) { - index1 += aDiff.text!!.length - } + + Operation.EQUAL -> { + if ((aDiff.text!!.length <= 2 * Patch_Margin + ) && !patch.diffs.isEmpty() && (aDiff !== diffs.last) + ) { + // Small equality inside a patch. + patch.diffs.add(aDiff) + patch.length1 += aDiff.text!!.length + patch.length2 += aDiff.text!!.length + } + + if (aDiff.text!!.length >= 2 * Patch_Margin && !patch.diffs.isEmpty()) { + // Time for a new patch. + if (!patch.diffs.isEmpty()) { + patch_addContext(patch, prepatch_text) + patches.add(patch) + patch = Patch() + // Unlike Unidiff, our patch lists have a rolling context. + // https://github.com/google/diff-match-patch/wiki/Unidiff + // Update prepatch text & pos to reflect the application of the + // just completed patch. + prepatch_text = postpatch_text + char_count1 = char_count2 + } + } + } + + null -> TODO() + } + // Update the current character count. + if (aDiff.operation != Operation.INSERT) { + char_count1 += aDiff.text!!.length + } + if (aDiff.operation != Operation.DELETE) { + char_count2 += aDiff.text!!.length } - } } - } - x++ - } - // Strip the padding off. - text = text.substring( - nullPadding.length, (text.length - - nullPadding.length) - ) - return arrayOf(text, results) - } - - /** - * Add some padding on text start and end so that edges can match something. - * Intended to be called only from within patch_apply. - * @param patches Array of Patch objects. - * @return The padding string added to each side. - */ - private fun patch_addPadding(patches: LinkedList): String { - val paddingLength = this.Patch_Margin - var nullPadding = "" - for (x in 1..paddingLength) { - nullPadding += (Char(x.toUShort())).toString() - } + // Pick up the leftover patch if not empty. + if (!patch.diffs.isEmpty()) { + patch_addContext(patch, prepatch_text) + patches.add(patch) + } - // Bump all the patches forward. - for (aPatch: Patch in patches) { - aPatch.start1 += paddingLength.toInt() - aPatch.start2 += paddingLength.toInt() + return patches } - // Add some padding on start of first diff. - var patch = patches.first - var diffs = patch.diffs - if (diffs.isEmpty() || diffs.first.operation != Operation.EQUAL) { - // Add nullPadding equality. - diffs.addFirst(Diff(Operation.EQUAL, nullPadding)) - patch.start1 -= paddingLength.toInt() // Should be 0. - patch.start2 -= paddingLength.toInt() // Should be 0. - patch.length1 += paddingLength.toInt() - patch.length2 += paddingLength.toInt() - } else if (paddingLength > diffs.first.text!!.length) { - // Grow first equality. - val firstDiff = diffs.first - val extraLength = paddingLength - firstDiff.text!!.length - firstDiff.text = (nullPadding.substring(firstDiff.text!!.length) - + firstDiff.text) - patch.start1 -= extraLength - patch.start2 -= extraLength - patch.length1 += extraLength - patch.length2 += extraLength + /** + * Given an array of patches, return another array that is identical. + * @param patches Array of Patch objects. + * @return Array of Patch objects. + */ + private fun patch_deepCopy(patches: LinkedList): LinkedList { + val patchesCopy = LinkedList() + for (aPatch: Patch in patches) { + val patchCopy = Patch() + for (aDiff: Diff in aPatch.diffs) { + val diffCopy = Diff(aDiff.operation, aDiff.text) + patchCopy.diffs.add(diffCopy) + } + patchCopy.start1 = aPatch.start1 + patchCopy.start2 = aPatch.start2 + patchCopy.length1 = aPatch.length1 + patchCopy.length2 = aPatch.length2 + patchesCopy.add(patchCopy) + } + return patchesCopy } - // Add some padding on end of last diff. - patch = patches.last - diffs = patch.diffs - if (diffs.isEmpty() || diffs.last.operation != Operation.EQUAL) { - // Add nullPadding equality. - diffs.addLast(Diff(Operation.EQUAL, nullPadding)) - patch.length1 += paddingLength.toInt() - patch.length2 += paddingLength.toInt() - } else if (paddingLength > diffs.last.text!!.length) { - // Grow last equality. - val lastDiff = diffs.last - val extraLength = paddingLength - lastDiff.text!!.length - lastDiff.text += nullPadding.substring(0, extraLength) - patch.length1 += extraLength - patch.length2 += extraLength - } + /** + * Merge a set of patches onto the text. Return a patched text, as well + * as an array of true/false values indicating which patches were applied. + * @param patches Array of Patch objects + * @param text Old text. + * @return Two element Object array, containing the new text and an array of + * boolean values. + */ + fun patch_apply(patches: LinkedList, text: String): Array { + var patches = patches + var text = text + if (patches.isEmpty()) { + return arrayOf(text, BooleanArray(0)) + } - return nullPadding - } - - /** - * Look through the patches and break up any which are longer than the - * maximum limit of the match algorithm. - * Intended to be called only from within patch_apply. - * @param patches LinkedList of Patch objects. - */ - private fun patch_splitMax(patches: LinkedList) { - val patch_size = Match_MaxBits - var precontext: String - var postcontext: String - var patch: Patch - var start1: Int - var start2: Int - var empty: Boolean - var diff_type: Operation - var diff_text: String - val pointer = patches.listIterator() - var bigpatch = if (pointer.hasNext()) pointer.next() else null - while (bigpatch != null) { - if (bigpatch.length1 <= Match_MaxBits) { - bigpatch = if (pointer.hasNext()) pointer.next() else null - continue - } - // Remove the big old patch. - pointer.remove() - start1 = bigpatch.start1 - start2 = bigpatch.start2 - precontext = "" - while (!bigpatch.diffs.isEmpty()) { - // Create one of several smaller patches. - patch = Patch() - empty = true - patch.start1 = start1 - precontext.length - patch.start2 = start2 - precontext.length - if (precontext.length != 0) { - patch.length2 = precontext.length - patch.length1 = patch.length2 - patch.diffs.add(Diff(Operation.EQUAL, precontext)) - } - while ((!bigpatch.diffs.isEmpty() - && patch.length1 < patch_size - Patch_Margin) - ) { - diff_type = bigpatch.diffs.first.operation!! - diff_text = bigpatch.diffs.first.text!! - if (diff_type == Operation.INSERT) { - // Insertions are harmless. - patch.length2 += diff_text.length - start2 += diff_text.length - patch.diffs.addLast(bigpatch.diffs.removeFirst()) - empty = false - } else if ((diff_type == Operation.DELETE) && (patch.diffs.size == 1 - ) && (patch.diffs.first.operation == Operation.EQUAL - ) && (diff_text.length > 2 * patch_size) - ) { - // This is a large deletion. Let it pass in one chunk. - patch.length1 += diff_text.length - start1 += diff_text.length - empty = false - patch.diffs.add(Diff(diff_type, diff_text)) - bigpatch.diffs.removeFirst() - } else { - // Deletion or equality. Only take as much as we can stomach. - diff_text = diff_text.substring( - 0, min( - diff_text.length.toDouble(), - (patch_size - patch.length1 - Patch_Margin).toDouble() - ).toInt() - ) - patch.length1 += diff_text.length - start1 += diff_text.length - if (diff_type == Operation.EQUAL) { - patch.length2 += diff_text.length - start2 += diff_text.length + // Deep copy the patches so that no changes are made to originals. + patches = patch_deepCopy(patches) + + val nullPadding = patch_addPadding(patches) + text = nullPadding + text + nullPadding + patch_splitMax(patches) + + var x = 0 + // delta keeps track of the offset between the expected and actual location + // of the previous patch. If there are patches expected at positions 10 and + // 20, but the first patch was found at 12, delta is 2 and the second patch + // has an effective expected position of 22. + var delta = 0 + val results = BooleanArray(patches.size) + for (aPatch: Patch in patches) { + val expected_loc = aPatch.start2 + delta + val text1 = diff_text1(aPatch.diffs) + var start_loc: Int + var end_loc = -1 + if (text1.length > this.Match_MaxBits) { + // patch_splitMax will only provide an oversized pattern in the case of + // a monster delete. + start_loc = match_main( + text, + text1.substring(0, Match_MaxBits.toInt()), expected_loc + ) + if (start_loc != -1) { + end_loc = match_main( + text, + text1.substring(text1.length - this.Match_MaxBits), + expected_loc + text1.length - this.Match_MaxBits + ) + if (end_loc == -1 || start_loc >= end_loc) { + // Can't find valid trailing context. Drop this patch. + start_loc = -1 + } + } } else { - empty = false + start_loc = match_main(text, text1, expected_loc) } - patch.diffs.add(Diff(diff_type, diff_text)) - if ((diff_text == bigpatch.diffs.first.text)) { - bigpatch.diffs.removeFirst() + if (start_loc == -1) { + // No match found. :( + results[x] = false + // Subtract the delta for this failed patch from subsequent patches. + delta -= aPatch.length2 - aPatch.length1 } else { - bigpatch.diffs.first.text = bigpatch.diffs.first.text!! - .substring(diff_text.length) + // Found a match. :) + results[x] = true + delta = start_loc - expected_loc + var text2: String + if (end_loc == -1) { + text2 = text.substring( + start_loc, + min((start_loc + text1.length).toDouble(), text.length.toDouble()).toInt() + ) + } else { + text2 = text.substring( + start_loc, + min((end_loc + this.Match_MaxBits).toDouble(), text.length.toDouble()).toInt() + ) + } + if ((text1 == text2)) { + // Perfect match, just shove the replacement text in. + text = (text.substring(0, start_loc) + diff_text2(aPatch.diffs) + + text.substring(start_loc + text1.length)) + } else { + // Imperfect match. Run a diff to get a framework of equivalent + // indices. + val diffs = diff_main(text1, text2, false) + if ((text1.length > this.Match_MaxBits + && diff_levenshtein(diffs) / text1.length.toFloat() + > this.Patch_DeleteThreshold) + ) { + // The end points match, but the content is unacceptably bad. + results[x] = false + } else { + diff_cleanupSemanticLossless(diffs) + var index1 = 0 + for (aDiff: Diff in aPatch.diffs) { + if (aDiff.operation != Operation.EQUAL) { + val index2 = diff_xIndex(diffs, index1) + if (aDiff.operation == Operation.INSERT) { + // Insertion + text = (text.substring(0, start_loc + index2) + aDiff.text + + text.substring(start_loc + index2)) + } else if (aDiff.operation == Operation.DELETE) { + // Deletion + text = (text.substring(0, start_loc + index2) + + text.substring( + start_loc + diff_xIndex( + diffs, + index1 + aDiff.text!!.length + ) + )) + } + } + if (aDiff.operation != Operation.DELETE) { + index1 += aDiff.text!!.length + } + } + } + } } - } - } - // Compute the head context for the next patch. - precontext = diff_text2(patch.diffs) - precontext = precontext.substring( - max( - 0.0, (precontext.length - - Patch_Margin).toDouble() - ).toInt() + x++ + } + // Strip the padding off. + text = text.substring( + nullPadding.length, (text.length + - nullPadding.length) ) - // Append the end context for this patch. - if (diff_text1(bigpatch.diffs).length > Patch_Margin) { - postcontext = diff_text1(bigpatch.diffs).substring(0, Patch_Margin.toInt()) - } else { - postcontext = diff_text1(bigpatch.diffs) - } - if (postcontext.length != 0) { - patch.length1 += postcontext.length - patch.length2 += postcontext.length - if ((!patch.diffs.isEmpty() - && patch.diffs.last.operation == Operation.EQUAL) - ) { - patch.diffs.last.text += postcontext - } else { - patch.diffs.add(Diff(Operation.EQUAL, postcontext)) - } - } - if (!empty) { - pointer.add(patch) - } - } - bigpatch = if (pointer.hasNext()) pointer.next() else null + return arrayOf(text, results) } - } - - /** - * Take a list of patches and return a textual representation. - * @param patches List of Patch objects. - * @return Text representation of patches. - */ - fun patch_toText(patches: List): String { - val text = StringBuilder() - for (aPatch: Patch? in patches) { - text.append(aPatch) - } - return text.toString() - } - - /** - * Parse a textual representation of patches and return a List of Patch - * objects. - * @param textline Text representation of patches. - * @return List of Patch objects. - * @throws IllegalArgumentException If invalid input. - */ - @Throws(IllegalArgumentException::class) - fun patch_fromText(textline: String): List { - val patches: MutableList = LinkedList() - if (textline.length == 0) { - return patches - } - val textList = Arrays.asList(*textline.split("\n".toRegex()).dropLastWhile { it.isEmpty() } - .toTypedArray()) - val text = LinkedList(textList) - var patch: Patch - val patchHeader = Pattern.compile("^@@ -(\\d+),?(\\d*) \\+(\\d+),?(\\d*) @@$") - var m: Matcher - var sign: Char - var line: String - while (!text.isEmpty()) { - m = patchHeader.matcher(text.first) - if (!m.matches()) { - throw IllegalArgumentException( - "Invalid patch string: " + text.first - ) - } - patch = Patch() - patches.add(patch) - patch.start1 = m.group(1).toInt() - if (m.group(2).length == 0) { - patch.start1-- - patch.length1 = 1 - } else if ((m.group(2) == "0")) { - patch.length1 = 0 - } else { - patch.start1-- - patch.length1 = m.group(2).toInt() - } - - patch.start2 = m.group(3).toInt() - if (m.group(4).length == 0) { - patch.start2-- - patch.length2 = 1 - } else if ((m.group(4) == "0")) { - patch.length2 = 0 - } else { - patch.start2-- - patch.length2 = m.group(4).toInt() - } - text.removeFirst() - - while (!text.isEmpty()) { - try { - sign = text.first[0] - } catch (e: IndexOutOfBoundsException) { - // Blank line? Whatever. - text.removeFirst() - continue - } - line = text.first.substring(1) - line = line.replace("+", "%2B") // decode would change all "+" to " " - try { - line = URLDecoder.decode(line, "UTF-8") - } catch (e: UnsupportedEncodingException) { - // Not likely on modern system. - throw Error("This system does not support UTF-8.", e) - } catch (e: IllegalArgumentException) { - // Malformed URI sequence. - throw IllegalArgumentException( - "Illegal escape in patch_fromText: $line", e - ) - } - if (sign == '-') { - // Deletion. - patch.diffs.add(Diff(Operation.DELETE, line)) - } else if (sign == '+') { - // Insertion. - patch.diffs.add(Diff(Operation.INSERT, line)) - } else if (sign == ' ') { - // Minor equality. - patch.diffs.add(Diff(Operation.EQUAL, line)) - } else if (sign == '@') { - // Start of next patch. - break - } else { - // WTF? - throw IllegalArgumentException( - "Invalid patch mode '$sign' in: $line" - ) + + /** + * Add some padding on text start and end so that edges can match something. + * Intended to be called only from within patch_apply. + * @param patches Array of Patch objects. + * @return The padding string added to each side. + */ + private fun patch_addPadding(patches: LinkedList): String { + val paddingLength = this.Patch_Margin + var nullPadding = "" + for (x in 1..paddingLength) { + nullPadding += (Char(x.toUShort())).toString() } - text.removeFirst() - } + + // Bump all the patches forward. + for (aPatch: Patch in patches) { + aPatch.start1 += paddingLength.toInt() + aPatch.start2 += paddingLength.toInt() + } + + // Add some padding on start of first diff. + var patch = patches.first + var diffs = patch.diffs + if (diffs.isEmpty() || diffs.first.operation != Operation.EQUAL) { + // Add nullPadding equality. + diffs.addFirst(Diff(Operation.EQUAL, nullPadding)) + patch.start1 -= paddingLength.toInt() // Should be 0. + patch.start2 -= paddingLength.toInt() // Should be 0. + patch.length1 += paddingLength.toInt() + patch.length2 += paddingLength.toInt() + } else if (paddingLength > diffs.first.text!!.length) { + // Grow first equality. + val firstDiff = diffs.first + val extraLength = paddingLength - firstDiff.text!!.length + firstDiff.text = (nullPadding.substring(firstDiff.text!!.length) + + firstDiff.text) + patch.start1 -= extraLength + patch.start2 -= extraLength + patch.length1 += extraLength + patch.length2 += extraLength + } + + // Add some padding on end of last diff. + patch = patches.last + diffs = patch.diffs + if (diffs.isEmpty() || diffs.last.operation != Operation.EQUAL) { + // Add nullPadding equality. + diffs.addLast(Diff(Operation.EQUAL, nullPadding)) + patch.length1 += paddingLength.toInt() + patch.length2 += paddingLength.toInt() + } else if (paddingLength > diffs.last.text!!.length) { + // Grow last equality. + val lastDiff = diffs.last + val extraLength = paddingLength - lastDiff.text!!.length + lastDiff.text += nullPadding.substring(0, extraLength) + patch.length1 += extraLength + patch.length2 += extraLength + } + + return nullPadding } - return patches - } - - - /** - * Class representing one diff operation. - */ - class Diff// Construct a diff with the specified operation and text. - /** - * Constructor. Initializes the diff with the provided values. - * @param operation One of INSERT, DELETE or EQUAL. - * @param text The text being applied. - */( - /** - * One of: INSERT, DELETE or EQUAL. - */ - var operation: Operation?, - /** - * The text associated with this diff operation. - */ - var text: String? - ) { + /** - * Display a human-readable version of this Diff. - * @return text version. + * Look through the patches and break up any which are longer than the + * maximum limit of the match algorithm. + * Intended to be called only from within patch_apply. + * @param patches LinkedList of Patch objects. */ - override fun toString(): String { - val prettyText = text!!.replace('\n', '\u00b6') - return "Diff(" + this.operation + ",\"" + prettyText + "\")" + private fun patch_splitMax(patches: LinkedList) { + val patch_size = Match_MaxBits + var precontext: String + var postcontext: String + var patch: Patch + var start1: Int + var start2: Int + var empty: Boolean + var diff_type: Operation + var diff_text: String + val pointer = patches.listIterator() + var bigpatch = if (pointer.hasNext()) pointer.next() else null + while (bigpatch != null) { + if (bigpatch.length1 <= Match_MaxBits) { + bigpatch = if (pointer.hasNext()) pointer.next() else null + continue + } + // Remove the big old patch. + pointer.remove() + start1 = bigpatch.start1 + start2 = bigpatch.start2 + precontext = "" + while (!bigpatch.diffs.isEmpty()) { + // Create one of several smaller patches. + patch = Patch() + empty = true + patch.start1 = start1 - precontext.length + patch.start2 = start2 - precontext.length + if (precontext.length != 0) { + patch.length2 = precontext.length + patch.length1 = patch.length2 + patch.diffs.add(Diff(Operation.EQUAL, precontext)) + } + while ((!bigpatch.diffs.isEmpty() + && patch.length1 < patch_size - Patch_Margin) + ) { + diff_type = bigpatch.diffs.first.operation!! + diff_text = bigpatch.diffs.first.text!! + if (diff_type == Operation.INSERT) { + // Insertions are harmless. + patch.length2 += diff_text.length + start2 += diff_text.length + patch.diffs.addLast(bigpatch.diffs.removeFirst()) + empty = false + } else if ((diff_type == Operation.DELETE) && (patch.diffs.size == 1 + ) && (patch.diffs.first.operation == Operation.EQUAL + ) && (diff_text.length > 2 * patch_size) + ) { + // This is a large deletion. Let it pass in one chunk. + patch.length1 += diff_text.length + start1 += diff_text.length + empty = false + patch.diffs.add(Diff(diff_type, diff_text)) + bigpatch.diffs.removeFirst() + } else { + // Deletion or equality. Only take as much as we can stomach. + diff_text = diff_text.substring( + 0, min( + diff_text.length.toDouble(), + (patch_size - patch.length1 - Patch_Margin).toDouble() + ).toInt() + ) + patch.length1 += diff_text.length + start1 += diff_text.length + if (diff_type == Operation.EQUAL) { + patch.length2 += diff_text.length + start2 += diff_text.length + } else { + empty = false + } + patch.diffs.add(Diff(diff_type, diff_text)) + if ((diff_text == bigpatch.diffs.first.text)) { + bigpatch.diffs.removeFirst() + } else { + bigpatch.diffs.first.text = bigpatch.diffs.first.text!! + .substring(diff_text.length) + } + } + } + // Compute the head context for the next patch. + precontext = diff_text2(patch.diffs) + precontext = precontext.substring( + max( + 0.0, (precontext.length + - Patch_Margin).toDouble() + ).toInt() + ) + // Append the end context for this patch. + if (diff_text1(bigpatch.diffs).length > Patch_Margin) { + postcontext = diff_text1(bigpatch.diffs).substring(0, Patch_Margin.toInt()) + } else { + postcontext = diff_text1(bigpatch.diffs) + } + if (postcontext.length != 0) { + patch.length1 += postcontext.length + patch.length2 += postcontext.length + if ((!patch.diffs.isEmpty() + && patch.diffs.last.operation == Operation.EQUAL) + ) { + patch.diffs.last.text += postcontext + } else { + patch.diffs.add(Diff(Operation.EQUAL, postcontext)) + } + } + if (!empty) { + pointer.add(patch) + } + } + bigpatch = if (pointer.hasNext()) pointer.next() else null + } } /** - * Create a numeric hash value for a Diff. - * This function is not used by DMP. - * @return Hash value. + * Take a list of patches and return a textual representation. + * @param patches List of Patch objects. + * @return Text representation of patches. */ - override fun hashCode(): Int { - val prime = 31 - var result = if ((operation == null)) 0 else operation.hashCode() - result += prime * (if ((text == null)) 0 else text.hashCode()) - return result + fun patch_toText(patches: List): String { + val text = StringBuilder() + for (aPatch: Patch? in patches) { + text.append(aPatch) + } + return text.toString() } /** - * Is this Diff equivalent to another Diff? - * @param obj Another Diff to compare against. - * @return true or false. + * Parse a textual representation of patches and return a List of Patch + * objects. + * @param textline Text representation of patches. + * @return List of Patch objects. + * @throws IllegalArgumentException If invalid input. */ - override fun equals(obj: Any?): Boolean { - if (this === obj) { - return true - } - if (obj == null) { - return false - } - if (javaClass != obj.javaClass) { - return false - } - val other = obj as Diff - if (operation != other.operation) { - return false - } - if (text == null) { - if (other.text != null) { - return false - } - } else if (text != other.text) { - return false - } - return true - } - } + @Throws(IllegalArgumentException::class) + fun patch_fromText(textline: String): List { + val patches: MutableList = LinkedList() + if (textline.length == 0) { + return patches + } + val textList = Arrays.asList(*textline.split("\n".toRegex()).dropLastWhile { it.isEmpty() } + .toTypedArray()) + val text = LinkedList(textList) + var patch: Patch + val patchHeader = Pattern.compile("^@@ -(\\d+),?(\\d*) \\+(\\d+),?(\\d*) @@$") + var m: Matcher + var sign: Char + var line: String + while (!text.isEmpty()) { + m = patchHeader.matcher(text.first) + if (!m.matches()) { + throw IllegalArgumentException( + "Invalid patch string: " + text.first + ) + } + patch = Patch() + patches.add(patch) + patch.start1 = m.group(1).toInt() + if (m.group(2).length == 0) { + patch.start1-- + patch.length1 = 1 + } else if ((m.group(2) == "0")) { + patch.length1 = 0 + } else { + patch.start1-- + patch.length1 = m.group(2).toInt() + } + patch.start2 = m.group(3).toInt() + if (m.group(4).length == 0) { + patch.start2-- + patch.length2 = 1 + } else if ((m.group(4) == "0")) { + patch.length2 = 0 + } else { + patch.start2-- + patch.length2 = m.group(4).toInt() + } + text.removeFirst() + + while (!text.isEmpty()) { + try { + sign = text.first[0] + } catch (e: IndexOutOfBoundsException) { + // Blank line? Whatever. + text.removeFirst() + continue + } + line = text.first.substring(1) + line = line.replace("+", "%2B") // decode would change all "+" to " " + try { + line = URLDecoder.decode(line, "UTF-8") + } catch (e: UnsupportedEncodingException) { + // Not likely on modern system. + throw Error("This system does not support UTF-8.", e) + } catch (e: IllegalArgumentException) { + // Malformed URI sequence. + throw IllegalArgumentException( + "Illegal escape in patch_fromText: $line", e + ) + } + if (sign == '-') { + // Deletion. + patch.diffs.add(Diff(Operation.DELETE, line)) + } else if (sign == '+') { + // Insertion. + patch.diffs.add(Diff(Operation.INSERT, line)) + } else if (sign == ' ') { + // Minor equality. + patch.diffs.add(Diff(Operation.EQUAL, line)) + } else if (sign == '@') { + // Start of next patch. + break + } else { + // WTF? + throw IllegalArgumentException( + "Invalid patch mode '$sign' in: $line" + ) + } + text.removeFirst() + } + } + return patches + } - /** - * Class representing one patch operation. - */ - class Patch() { - var diffs: LinkedList - var start1: Int = 0 - var start2: Int = 0 - var length1: Int = 0 - var length2: Int = 0 /** - * Constructor. Initializes with an empty list of diffs. + * Class representing one diff operation. */ - init { - this.diffs = LinkedList() + class Diff// Construct a diff with the specified operation and text. + /** + * Constructor. Initializes the diff with the provided values. + * @param operation One of INSERT, DELETE or EQUAL. + * @param text The text being applied. + */( + /** + * One of: INSERT, DELETE or EQUAL. + */ + var operation: Operation?, + /** + * The text associated with this diff operation. + */ + var text: String? + ) { + /** + * Display a human-readable version of this Diff. + * @return text version. + */ + override fun toString(): String { + val prettyText = text!!.replace('\n', '\u00b6') + return "Diff(" + this.operation + ",\"" + prettyText + "\")" + } + + /** + * Create a numeric hash value for a Diff. + * This function is not used by DMP. + * @return Hash value. + */ + override fun hashCode(): Int { + val prime = 31 + var result = if ((operation == null)) 0 else operation.hashCode() + result += prime * (if ((text == null)) 0 else text.hashCode()) + return result + } + + /** + * Is this Diff equivalent to another Diff? + * @param obj Another Diff to compare against. + * @return true or false. + */ + override fun equals(obj: Any?): Boolean { + if (this === obj) { + return true + } + if (obj == null) { + return false + } + if (javaClass != obj.javaClass) { + return false + } + val other = obj as Diff + if (operation != other.operation) { + return false + } + if (text == null) { + if (other.text != null) { + return false + } + } else if (text != other.text) { + return false + } + return true + } } + /** - * Emulate GNU diff's format. - * Header: @@ -382,8 +481,9 @@ - * Indices are printed as 1-based, not 0-based. - * @return The GNU diff string. + * Class representing one patch operation. */ - override fun toString(): String { - val coords1: String - val coords2: String - if (this.length1 == 0) { - coords1 = start1.toString() + ",0" - } else if (this.length1 == 1) { - coords1 = (this.start1 + 1).toString() - } else { - coords1 = (this.start1 + 1).toString() + "," + this.length1 - } - if (this.length2 == 0) { - coords2 = start2.toString() + ",0" - } else if (this.length2 == 1) { - coords2 = (this.start2 + 1).toString() - } else { - coords2 = (this.start2 + 1).toString() + "," + this.length2 - } - val text = StringBuilder() - text.append("@@ -").append(coords1).append(" +").append(coords2) - .append(" @@\n") - // Escape the body of the patch with %xx notation. - for (aDiff: Diff in this.diffs) { - when (aDiff.operation) { - Operation.INSERT -> text.append('+') - Operation.DELETE -> text.append('-') - Operation.EQUAL -> text.append(' ') - null -> TODO() - } - try { - text.append(URLEncoder.encode(aDiff.text, "UTF-8").replace('+', ' ')) - .append("\n") - } catch (e: UnsupportedEncodingException) { - // Not likely on modern system. - throw Error("This system does not support UTF-8.", e) - } - } - return unescapeForEncodeUriCompatability(text.toString()) + class Patch { + var diffs: LinkedList + var start1: Int = 0 + var start2: Int = 0 + var length1: Int = 0 + var length2: Int = 0 + + /** + * Constructor. Initializes with an empty list of diffs. + */ + init { + this.diffs = LinkedList() + } + + /** + * Emulate GNU diff's format. + * Header: @@ -382,8 +481,9 @@ + * Indices are printed as 1-based, not 0-based. + * @return The GNU diff string. + */ + override fun toString(): String { + val coords1: String + val coords2: String + if (this.length1 == 0) { + coords1 = start1.toString() + ",0" + } else if (this.length1 == 1) { + coords1 = (this.start1 + 1).toString() + } else { + coords1 = (this.start1 + 1).toString() + "," + this.length1 + } + if (this.length2 == 0) { + coords2 = start2.toString() + ",0" + } else if (this.length2 == 1) { + coords2 = (this.start2 + 1).toString() + } else { + coords2 = (this.start2 + 1).toString() + "," + this.length2 + } + val text = StringBuilder() + text.append("@@ -").append(coords1).append(" +").append(coords2) + .append(" @@\n") + // Escape the body of the patch with %xx notation. + for (aDiff: Diff in this.diffs) { + when (aDiff.operation) { + Operation.INSERT -> text.append('+') + Operation.DELETE -> text.append('-') + Operation.EQUAL -> text.append(' ') + null -> TODO() + } + try { + text.append(URLEncoder.encode(aDiff.text, "UTF-8").replace('+', ' ')) + .append("\n") + } catch (e: UnsupportedEncodingException) { + // Not likely on modern system. + throw Error("This system does not support UTF-8.", e) + } + } + return unescapeForEncodeUriCompatability(text.toString()) + } } - } - companion object : DiffMatchPatch() { - /** - * Unescape selected chars for compatability with JavaScript's encodeURI. - * In speed critical applications this could be dropped since the - * receiving application will certainly decode these fine. - * Note that this function is case-sensitive. Thus "%3f" would not be - * unescaped. But this is ok because it is only called with the output of - * URLEncoder.encode which returns uppercase hex. - * - * Example: "%3F" -> "?", "%24" -> "$", etc. - * - * @param str The string to escape. - * @return The escaped string. - */ - private fun unescapeForEncodeUriCompatability(str: String): String { - return str.replace("%21", "!").replace("%7E", "~") - .replace("%27", "'").replace("%28", "(").replace("%29", ")") - .replace("%3B", ";").replace("%2F", "/").replace("%3F", "?") - .replace("%3A", ":").replace("%40", "@").replace("%26", "&") - .replace("%3D", "=").replace("%2B", "+").replace("%24", "$") - .replace("%2C", ",").replace("%23", "#") + companion object : DiffMatchPatch() { + /** + * Unescape selected chars for compatability with JavaScript's encodeURI. + * In speed critical applications this could be dropped since the + * receiving application will certainly decode these fine. + * Note that this function is case-sensitive. Thus "%3f" would not be + * unescaped. But this is ok because it is only called with the output of + * URLEncoder.encode which returns uppercase hex. + * + * Example: "%3F" -> "?", "%24" -> "$", etc. + * + * @param str The string to escape. + * @return The escaped string. + */ + private fun unescapeForEncodeUriCompatability(str: String): String { + return str.replace("%21", "!").replace("%7E", "~") + .replace("%27", "'").replace("%28", "(").replace("%29", ")") + .replace("%3B", ";").replace("%2F", "/").replace("%3F", "?") + .replace("%3A", ":").replace("%40", "@").replace("%26", "&") + .replace("%3D", "=").replace("%2B", "+").replace("%24", "$") + .replace("%2C", ",").replace("%23", "#") + } } - } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/github/simiacryptus/diff/DiffUtil.kt b/webui/src/main/kotlin/com/github/simiacryptus/diff/DiffUtil.kt index 97791bae..031a7abb 100644 --- a/webui/src/main/kotlin/com/github/simiacryptus/diff/DiffUtil.kt +++ b/webui/src/main/kotlin/com/github/simiacryptus/diff/DiffUtil.kt @@ -3,127 +3,128 @@ package com.github.simiacryptus.diff import com.github.simiacryptus.diff.PatchLineType.* enum class PatchLineType { - Added, Deleted, Unchanged + Added, Deleted, Unchanged } data class PatchLine( - val type: PatchLineType, - val lineNumber: Int, - val line: String, - val compareText: String = line.trim(), + val type: PatchLineType, + val lineNumber: Int, + val line: String, + val compareText: String = line.trim(), ) { - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (javaClass != other?.javaClass) return false + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false - other as PatchLine + other as PatchLine - return compareText == other.compareText - } + return compareText == other.compareText + } - override fun hashCode(): Int { - return compareText.hashCode() - } + override fun hashCode(): Int { + return compareText.hashCode() + } } object DiffUtil { - /** - * Generates a list of DiffResult representing the differences between two lists of strings. - * This function compares the original and modified texts line by line and categorizes each line as added, deleted, or unchanged. - * - * @param original The original list of strings. - * @param modified The modified list of strings. - * @return A list of DiffResult indicating the differences. - */ - fun generateDiff(original: List, modified: List): List { - if (original == modified) return modified.withIndex().map { (i, v) -> PatchLine(Unchanged, i, v) } - val original = original.withIndex().map { (i, v) -> PatchLine(Unchanged, i, v) }.toMutableList() - val modified = modified.withIndex().map { (i, v) -> PatchLine(Unchanged, i, v) }.toMutableList() - val patchLines = mutableListOf() + /** + * Generates a list of DiffResult representing the differences between two lists of strings. + * This function compares the original and modified texts line by line and categorizes each line as added, deleted, or unchanged. + * + * @param original The original list of strings. + * @param modified The modified list of strings. + * @return A list of DiffResult indicating the differences. + */ + fun generateDiff(original: List, modified: List): List { + if (original == modified) return modified.withIndex().map { (i, v) -> PatchLine(Unchanged, i, v) } + val original = original.withIndex().map { (i, v) -> PatchLine(Unchanged, i, v) }.toMutableList() + val modified = modified.withIndex().map { (i, v) -> PatchLine(Unchanged, i, v) }.toMutableList() + val patchLines = mutableListOf() - while (original.isNotEmpty() && modified.isNotEmpty()) { - val originalLine = original.first() - val modifiedLine = modified.first() + while (original.isNotEmpty() && modified.isNotEmpty()) { + val originalLine = original.first() + val modifiedLine = modified.first() - if (originalLine == modifiedLine) { - patchLines.add(PatchLine(Unchanged, originalLine.lineNumber, originalLine.line)) - original.removeAt(0) - modified.removeAt(0) - } else { - val originalIndex = original.indexOf(modifiedLine).let { if (it == -1) null else it } - val modifiedIndex = modified.indexOf(originalLine).let { if (it == -1) null else it } + if (originalLine == modifiedLine) { + patchLines.add(PatchLine(Unchanged, originalLine.lineNumber, originalLine.line)) + original.removeAt(0) + modified.removeAt(0) + } else { + val originalIndex = original.indexOf(modifiedLine).let { if (it == -1) null else it } + val modifiedIndex = modified.indexOf(originalLine).let { if (it == -1) null else it } - if (originalIndex != null && modifiedIndex != null) { - if (originalIndex < modifiedIndex) { - while(original.first() != modifiedLine) { - patchLines.add(PatchLine(Deleted, original.first().lineNumber, original.first().line)) - original.removeAt(0) + if (originalIndex != null && modifiedIndex != null) { + if (originalIndex < modifiedIndex) { + while (original.first() != modifiedLine) { + patchLines.add(PatchLine(Deleted, original.first().lineNumber, original.first().line)) + original.removeAt(0) + } + } else { + while (modified.first() != originalLine) { + patchLines.add(PatchLine(Added, modified.first().lineNumber, modified.first().line)) + modified.removeAt(0) + } + } + } else if (originalIndex != null) { + while (original.first() != modifiedLine) { + patchLines.add(PatchLine(Deleted, original.first().lineNumber, original.first().line)) + original.removeAt(0) + } + } else if (modifiedIndex != null) { + while (modified.first() != originalLine) { + patchLines.add(PatchLine(Added, modified.first().lineNumber, modified.first().line)) + modified.removeAt(0) + } + } else { + patchLines.add(PatchLine(Deleted, originalLine.lineNumber, originalLine.line)) + original.removeAt(0) + patchLines.add(PatchLine(Added, modifiedLine.lineNumber, modifiedLine.line)) + modified.removeAt(0) + } } - } else { - while(modified.first() != originalLine) { - patchLines.add(PatchLine(Added, modified.first().lineNumber, modified.first().line)) - modified.removeAt(0) - } - } - } else if (originalIndex != null) { - while(original.first() != modifiedLine) { - patchLines.add(PatchLine(Deleted, original.first().lineNumber, original.first().line)) - original.removeAt(0) - } - } else if (modifiedIndex != null) { - while(modified.first() != originalLine) { - patchLines.add(PatchLine(Added, modified.first().lineNumber, modified.first().line)) - modified.removeAt(0) - } - } else { - patchLines.add(PatchLine(Deleted, originalLine.lineNumber, originalLine.line)) - original.removeAt(0) - patchLines.add(PatchLine(Added, modifiedLine.lineNumber, modifiedLine.line)) - modified.removeAt(0) } - } + patchLines.addAll(original.map { PatchLine(Deleted, it.lineNumber, it.line) }) + patchLines.addAll(modified.map { PatchLine(Added, it.lineNumber, it.line) }) + return patchLines } - patchLines.addAll(original.map { PatchLine(Deleted, it.lineNumber, it.line) }) - patchLines.addAll(modified.map { PatchLine(Added, it.lineNumber, it.line) }) - return patchLines - } - /** - * Formats the list of DiffResult into a human-readable string representation. - * This function processes each diff result to format added, deleted, and unchanged lines appropriately, - * including context lines and markers for easier reading. - * - * @param patchLines The list of DiffResult to format. - * @param contextLines The number of context lines to include around changes. - * @return A formatted string representing the diff. - */ - fun formatDiff(patchLines: List, contextLines: Int = 3): String { - val patchList = patchLines.withIndex().filter { (idx, lineDiff) -> - when (lineDiff.type) { - Added -> true - Deleted -> true - Unchanged -> { - val distBackwards = - patchLines.subList(0, idx).indexOfLast { it.type != Unchanged }.let { if (it == -1) null else idx - it } - val distForwards = patchLines.subList(idx, patchLines.size).indexOfFirst { it.type != Unchanged } - .let { if (it == -1) null else it } - (null != distBackwards && distBackwards <= contextLines) || (null != distForwards && distForwards <= contextLines) - } - } - }.map { it.value }.toTypedArray() + /** + * Formats the list of DiffResult into a human-readable string representation. + * This function processes each diff result to format added, deleted, and unchanged lines appropriately, + * including context lines and markers for easier reading. + * + * @param patchLines The list of DiffResult to format. + * @param contextLines The number of context lines to include around changes. + * @return A formatted string representing the diff. + */ + fun formatDiff(patchLines: List, contextLines: Int = 3): String { + val patchList = patchLines.withIndex().filter { (idx, lineDiff) -> + when (lineDiff.type) { + Added -> true + Deleted -> true + Unchanged -> { + val distBackwards = + patchLines.subList(0, idx).indexOfLast { it.type != Unchanged } + .let { if (it == -1) null else idx - it } + val distForwards = patchLines.subList(idx, patchLines.size).indexOfFirst { it.type != Unchanged } + .let { if (it == -1) null else it } + (null != distBackwards && distBackwards <= contextLines) || (null != distForwards && distForwards <= contextLines) + } + } + }.map { it.value }.toTypedArray() - return patchList.withIndex().joinToString("\n") { (idx, lineDiff) -> - when { - idx == 0 -> "" - lineDiff.type != Unchanged || patchList[idx - 1].type != Unchanged -> "" - patchList[idx - 1].lineNumber + 1 < lineDiff.lineNumber -> "...\n" - else -> "" - } + when (lineDiff.type) { - Added -> "+ ${lineDiff.line}" - Deleted -> "- ${lineDiff.line}" - Unchanged -> " ${lineDiff.line}" - } + return patchList.withIndex().joinToString("\n") { (idx, lineDiff) -> + when { + idx == 0 -> "" + lineDiff.type != Unchanged || patchList[idx - 1].type != Unchanged -> "" + patchList[idx - 1].lineNumber + 1 < lineDiff.lineNumber -> "...\n" + else -> "" + } + when (lineDiff.type) { + Added -> "+ ${lineDiff.line}" + Deleted -> "- ${lineDiff.line}" + Unchanged -> " ${lineDiff.line}" + } + } } - } } diff --git a/webui/src/main/kotlin/com/github/simiacryptus/diff/IterativePatchUtil.kt b/webui/src/main/kotlin/com/github/simiacryptus/diff/IterativePatchUtil.kt index 386255b7..ca3fe58c 100644 --- a/webui/src/main/kotlin/com/github/simiacryptus/diff/IterativePatchUtil.kt +++ b/webui/src/main/kotlin/com/github/simiacryptus/diff/IterativePatchUtil.kt @@ -1,7 +1,6 @@ package com.github.simiacryptus.diff import org.apache.commons.text.similarity.LevenshteinDistance -import org.slf4j.LoggerFactory object IterativePatchUtil { @@ -16,7 +15,7 @@ object IterativePatchUtil { ) { override fun toString(): String { val sb = StringBuilder() - when(type) { + when (type) { LineType.CONTEXT -> sb.append(" ") LineType.ADD -> sb.append("+") LineType.DELETE -> sb.append("-") @@ -55,7 +54,7 @@ object IterativePatchUtil { sourceLineBuffer.remove(it) patchedTextBuilder.appendLine(it.line) } - if(sourceLineBuffer.isEmpty()) break + if (sourceLineBuffer.isEmpty()) break val codeLine = sourceLineBuffer.removeFirst() var patchLine = codeLine.matchingLine!! when (patchLine.type) { @@ -68,8 +67,8 @@ object IterativePatchUtil { } } while (patchLine.nextLine?.type == LineType.ADD) { - patchedTextBuilder.appendLine(patchLine?.nextLine?.line) - patchLine = patchLine?.nextLine!! + patchedTextBuilder.appendLine(patchLine.nextLine?.line) + patchLine = patchLine.nextLine!! } } @@ -78,10 +77,12 @@ object IterativePatchUtil { private fun linkUniqueMatchingLines(sourceLines: List, patchLines: List) { val sourceLineMap = sourceLines.groupBy { it.line.trim() } - val patchLineMap = patchLines.filter { when(it.type) { - LineType.ADD -> false // ADD lines are not matched to source lines - else -> true - }}.groupBy { it.line.trim() } + val patchLineMap = patchLines.filter { + when (it.type) { + LineType.ADD -> false // ADD lines are not matched to source lines + else -> true + } + }.groupBy { it.line.trim() } sourceLineMap.keys.intersect(patchLineMap.keys).forEach { key -> val sourceLine = sourceLineMap[key]?.singleOrNull() @@ -137,10 +138,12 @@ object IterativePatchUtil { var bestDistance = Int.MAX_VALUE var bestCombinedDistance = Int.MAX_VALUE - for (patchLine in patchLines.filter { when(it.type) { - LineType.ADD -> false // ADD lines are not matched to source lines - else -> true - }}) { + for (patchLine in patchLines.filter { + when (it.type) { + LineType.ADD -> false // ADD lines are not matched to source lines + else -> true + } + }) { if (patchLine.matchingLine != null) continue // Skip lines that already have matches val distance = levenshteinDistance.apply(sourceLine.line.trim(), patchLine.line.trim()) @@ -205,6 +208,5 @@ object IterativePatchUtil { ) }) - private val log = LoggerFactory.getLogger(ApxPatchUtil::class.java) } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/Acceptable.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/Acceptable.kt index 039de070..4ff7111a 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/Acceptable.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/Acceptable.kt @@ -10,178 +10,178 @@ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicReference class Acceptable( - private val task: SessionTask, - private val userMessage: String, - private val initialResponse: (String) -> T, - private val outputFn: (T) -> String, - private val ui: ApplicationInterface, - private val reviseResponse: (List>) -> T, - private val atomicRef: AtomicReference = AtomicReference(), - private val semaphore: Semaphore = Semaphore(0), - private val heading: String + private val task: SessionTask, + private val userMessage: String, + private val initialResponse: (String) -> T, + private val outputFn: (T) -> String, + private val ui: ApplicationInterface, + private val reviseResponse: (List>) -> T, + private val atomicRef: AtomicReference = AtomicReference(), + private val semaphore: Semaphore = Semaphore(0), + private val heading: String ) : Callable { - val tabs = object : TabbedDisplay(task) { - override fun renderTabButtons(): String { - return """ + val tabs = object : TabbedDisplay(task) { + override fun renderTabButtons(): String { + return """
${ - tabs.withIndex().joinToString("\n") - { (index: Int, t: Pair) -> - """""" - } - } + tabs.withIndex().joinToString("\n") + { (index: Int, t: Pair) -> + """""" + } + } ${ - ui.hrefLink("♻") { - val idx: Int = size - set(label(idx), "Retrying...") - task.add("") - main(idx, this@Acceptable.task) - } - } + ui.hrefLink("♻") { + val idx: Int = size + set(label(idx), "Retrying...") + task.add("") + main(idx, this@Acceptable.task) + } + }
""".trimIndent() + } } - } - private val acceptGuard = AtomicBoolean(false) + private val acceptGuard = AtomicBoolean(false) - fun main(tabIndex: Int = tabs.size, task: SessionTask = this.task) { - try { - val history = mutableListOf>() - history.add(userMessage to Role.user) - val design = initialResponse(userMessage) - history.add(outputFn(design) to Role.assistant) - val tabLabel = tabs.label(tabIndex) - val tabContent = tabs[tabLabel] ?: tabs.set(tabLabel, "") + fun main(tabIndex: Int = tabs.size, task: SessionTask = this.task) { + try { + val history = mutableListOf>() + history.add(userMessage to Role.user) + val design = initialResponse(userMessage) + history.add(outputFn(design) to Role.assistant) + val tabLabel = tabs.label(tabIndex) + val tabContent = tabs[tabLabel] ?: tabs.set(tabLabel, "") - if (tabs.size > tabIndex) { - tabContent.append(outputFn(design) + "\n" + feedbackForm(tabIndex, tabContent, design, history, task)) - } else { - tabContent.set(outputFn(design) + "\n" + feedbackForm(tabIndex, tabContent, design, history, task)) - } - tabs.update() - } catch (e: Throwable) { - task.error(ui, e) - task.complete(ui.hrefLink("🔄 Retry") { - main(task = task) - }) + if (tabs.size > tabIndex) { + tabContent.append(outputFn(design) + "\n" + feedbackForm(tabIndex, tabContent, design, history, task)) + } else { + tabContent.set(outputFn(design) + "\n" + feedbackForm(tabIndex, tabContent, design, history, task)) + } + tabs.update() + } catch (e: Throwable) { + task.error(ui, e) + task.complete(ui.hrefLink("🔄 Retry") { + main(task = task) + }) + } } - } - private fun feedbackForm( - tabIndex: Int?, - tabContent: StringBuilder, - design: T, - history: List>, - task: SessionTask, - ): String = """ + private fun feedbackForm( + tabIndex: Int?, + tabContent: StringBuilder, + design: T, + history: List>, + task: SessionTask, + ): String = """ | |
- |${acceptLink(tabIndex, tabContent, design)!!} + |${acceptLink(tabIndex, tabContent, design)} |
- |${textInput(design, tabContent, history, task)!!} + |${textInput(design, tabContent, history, task)} | """.trimMargin() - private fun acceptLink( - tabIndex: Int?, - tabContent: StringBuilder, - design: T, - ) = ui.hrefLink("Accept", classname = "href-link cmd-button") { - accept(tabIndex, tabContent, design) - } - - private fun textInput( - design: T, - tabContent: StringBuilder, - history: List>, - task: SessionTask, - ): String { - val feedbackGuard = AtomicBoolean(false) - return ui.textInput { userResponse -> - if (feedbackGuard.getAndSet(true)) return@textInput - try { - feedback(tabContent, userResponse, history, design, task) - } catch (e: Exception) { - task.error(ui, e) - throw e - } finally { - feedbackGuard.set(false) - } + private fun acceptLink( + tabIndex: Int?, + tabContent: StringBuilder, + design: T, + ) = ui.hrefLink("Accept", classname = "href-link cmd-button") { + accept(tabIndex, tabContent, design) } - } - private fun feedback( - tabContent: StringBuilder, - userResponse: String, - history: List>, - design: T, - task: SessionTask, - ) { - var history = history - history = history + (userResponse to Role.user) - val prevValue = tabContent.toString() - val newValue = (prevValue.substringBefore("") - + "" - + prevValue.substringAfter("") - + "
" - + renderMarkdown(userResponse, ui=ui) - + "
") - tabContent.set(newValue) - task.add("") // Show spinner - tabs.update() - val newDesign = reviseResponse(history) - val newTask = ui.newTask(root = false) - tabContent.set(newValue + "\n" + newTask.placeholder) - tabs.update() - task.complete() - Retryable(ui, newTask) { - outputFn(newDesign) + "\n" + feedbackForm( - tabIndex = null, - tabContent = it, - design = newDesign, - history = history, - task = newTask - ) - }.apply { - set(label(size), process(container)) + private fun textInput( + design: T, + tabContent: StringBuilder, + history: List>, + task: SessionTask, + ): String { + val feedbackGuard = AtomicBoolean(false) + return ui.textInput { userResponse -> + if (feedbackGuard.getAndSet(true)) return@textInput + try { + feedback(tabContent, userResponse, history, design, task) + } catch (e: Exception) { + task.error(ui, e) + throw e + } finally { + feedbackGuard.set(false) + } + } } - } - private fun accept(tabIndex: Int?, tabContent: StringBuilder, design: T) { - if (acceptGuard.getAndSet(true)) { - return - } - try { - if(null != tabIndex) tabs.selectedTab = tabIndex - tabContent?.apply { - val prevTab = toString() - val newValue = - prevTab.substringBefore("") + "" + prevTab.substringAfter( - "" - ) - set(newValue) + private fun feedback( + tabContent: StringBuilder, + userResponse: String, + history: List>, + design: T, + task: SessionTask, + ) { + var history = history + history = history + (userResponse to Role.user) + val prevValue = tabContent.toString() + val newValue = (prevValue.substringBefore("") + + "" + + prevValue.substringAfter("") + + "
" + + renderMarkdown(userResponse, ui = ui) + + "
") + tabContent.set(newValue) + task.add("") // Show spinner tabs.update() - } - } catch (e: Exception) { - task.error(ui, e) - acceptGuard.set(false) - throw e + val newDesign = reviseResponse(history) + val newTask = ui.newTask(root = false) + tabContent.set(newValue + "\n" + newTask.placeholder) + tabs.update() + task.complete() + Retryable(ui, newTask) { + outputFn(newDesign) + "\n" + feedbackForm( + tabIndex = null, + tabContent = it, + design = newDesign, + history = history, + task = newTask + ) + }.apply { + set(label(size), process(container)) + } } - atomicRef.set(design) - semaphore.release() - } - override fun call(): T { - task.echo(heading) - main() - semaphore.acquire() - return atomicRef.get() - } + private fun accept(tabIndex: Int?, tabContent: StringBuilder, design: T) { + if (acceptGuard.getAndSet(true)) { + return + } + try { + if (null != tabIndex) tabs.selectedTab = tabIndex + tabContent.apply { + val prevTab = toString() + val newValue = + prevTab.substringBefore("") + "" + prevTab.substringAfter( + "" + ) + set(newValue) + tabs.update() + } + } catch (e: Exception) { + task.error(ui, e) + acceptGuard.set(false) + throw e + } + atomicRef.set(design) + semaphore.release() + } + + override fun call(): T { + task.echo(heading) + main() + semaphore.acquire() + return atomicRef.get() + } } fun java.lang.StringBuilder.set(newValue: String) { - clear() - append(newValue) + clear() + append(newValue) } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/AgentPatterns.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/AgentPatterns.kt index 98bf195e..54bac055 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/AgentPatterns.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/AgentPatterns.kt @@ -3,48 +3,48 @@ package com.simiacryptus.skyenet import com.simiacryptus.skyenet.webui.application.ApplicationInterface object AgentPatterns { - private val scheduledThreadPoolExecutor = java.util.concurrent.Executors.newScheduledThreadPool(1) + private val scheduledThreadPoolExecutor = java.util.concurrent.Executors.newScheduledThreadPool(1) - fun displayMapInTabs( - map: Map, - ui: ApplicationInterface? = null, - split: Boolean = map.entries.map { it.value.length + it.key.length }.sum() > 10000 - ) : String = if(split && ui != null) { - val tasks = map.entries.map { (key, value) -> - key to ui.newTask(root = false) - }.toMap() - scheduledThreadPoolExecutor.schedule({ - tasks.forEach { (key, task) -> - task.complete(map[key]!!) - } - }, 200, java.util.concurrent.TimeUnit.MILLISECONDS) - displayMapInTabs(tasks.mapValues { it.value.placeholder }, ui=ui, split = false) - } else { - """ + fun displayMapInTabs( + map: Map, + ui: ApplicationInterface? = null, + split: Boolean = map.entries.map { it.value.length + it.key.length }.sum() > 10000 + ): String = if (split && ui != null) { + val tasks = map.entries.map { (key, value) -> + key to ui.newTask(root = false) + }.toMap() + scheduledThreadPoolExecutor.schedule({ + tasks.forEach { (key, task) -> + task.complete(map[key]!!) + } + }, 200, java.util.concurrent.TimeUnit.MILLISECONDS) + displayMapInTabs(tasks.mapValues { it.value.placeholder }, ui = ui, split = false) + } else { + """ |
|
|${ - map.keys.joinToString("\n") { key -> - """""" - }/*.indent(" ")*/ - } + map.keys.joinToString("\n") { key -> + """""" + }/*.indent(" ")*/ + } |
|${ - map.entries.withIndex().joinToString("\n") { (idx, t) -> - val (key, value) = t - """ + map.entries.withIndex().joinToString("\n") { (idx, t) -> + val (key, value) = t + """ |
"" - } - }" data-tab="$key"> + when { + idx == 0 -> " active" + else -> "" + } + }" data-tab="$key"> |${value/*.indent(" ")*/} |
""".trimMargin() - }/*.indent(" ")*/ - } + }/*.indent(" ")*/ + } |
""".trimMargin() - } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/Retryable.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/Retryable.kt index f66ed436..af81c1ac 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/Retryable.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/Retryable.kt @@ -4,30 +4,30 @@ import com.simiacryptus.skyenet.webui.application.ApplicationInterface import com.simiacryptus.skyenet.webui.session.SessionTask open class Retryable( - val ui: ApplicationInterface, - task: SessionTask, - val process: (StringBuilder) -> String + val ui: ApplicationInterface, + task: SessionTask, + val process: (StringBuilder) -> String ) : TabbedDisplay(task) { - override fun renderTabButtons(): String = """ + override fun renderTabButtons(): String = """
${ - tabs.withIndex().joinToString("\n") { (index, _) -> - val tabId = "$index" - """""" + tabs.withIndex().joinToString("\n") { (index, _) -> + val tabId = "$index" + """""" + } } - } ${ - ui.hrefLink("♻") { - val idx = tabs.size - val label = label(idx) - val content = StringBuilder("Retrying..." + SessionTask.spinner) - tabs.add(label to content) - update() - val newResult = process(content) - content.clear() - set(label, newResult) + ui.hrefLink("♻") { + val idx = tabs.size + val label = label(idx) + val content = StringBuilder("Retrying..." + SessionTask.spinner) + tabs.add(label to content) + update() + val newResult = process(content) + content.clear() + set(label, newResult) + } } - }
""".trimIndent() diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/TabbedDisplay.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/TabbedDisplay.kt index 98a43148..7f54a181 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/TabbedDisplay.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/TabbedDisplay.kt @@ -3,78 +3,80 @@ package com.simiacryptus.skyenet import com.simiacryptus.skyenet.webui.session.SessionTask open class TabbedDisplay( - val task: SessionTask, - val tabs: MutableList> = mutableListOf(), + val task: SessionTask, + val tabs: MutableList> = mutableListOf(), ) { - var selectedTab: Int = 0 - companion object { - val log = org.slf4j.LoggerFactory.getLogger(TabbedDisplay::class.java) - } - val size: Int get() = tabs.size - open fun render() = """ + var selectedTab: Int = 0 + + companion object { + val log = org.slf4j.LoggerFactory.getLogger(TabbedDisplay::class.java) + } + + val size: Int get() = tabs.size + open fun render() = """
${renderTabButtons()} ${ tabs.withIndex().joinToString("\n") { (idx, t) -> renderContentTab(t, idx) } - } + }
""".trimIndent() - val container : StringBuilder by lazy { task.add(render())!! } + val container: StringBuilder by lazy { task.add(render())!! } - open fun renderTabButtons() = """ + open fun renderTabButtons() = """
${ - tabs.toMap().keys.withIndex().joinToString("\n") { (idx, key: String) -> - """""" - } + tabs.toMap().keys.withIndex().joinToString("\n") { (idx, key: String) -> + """""" + } }
""".trimIndent() - open fun renderContentTab(t: Pair, idx: Int) = """ + open fun renderContentTab(t: Pair, idx: Int) = """
"" - } + when { + idx == selectedTab -> "active" + else -> "" + } }" data-tab="$idx">${t.second}
""".trimIndent() - operator fun get(i: String) = tabs.toMap()[i] - operator fun set(name: String, content: String) = - when (val index = find(name)) { - null -> { - val stringBuilder = StringBuilder(content) - tabs.add(name to stringBuilder) - update() - stringBuilder - } + operator fun get(i: String) = tabs.toMap()[i] + operator fun set(name: String, content: String) = + when (val index = find(name)) { + null -> { + val stringBuilder = StringBuilder(content) + tabs.add(name to stringBuilder) + update() + stringBuilder + } - else -> { - val stringBuilder = tabs[index].second - stringBuilder.clear() - stringBuilder.append(content) - update() - stringBuilder - } - } + else -> { + val stringBuilder = tabs[index].second + stringBuilder.clear() + stringBuilder.append(content) + update() + stringBuilder + } + } - fun find(name: String) = tabs.withIndex().firstOrNull { it.value.first == name }?.index + fun find(name: String) = tabs.withIndex().firstOrNull { it.value.first == name }?.index - open fun label(i: Int): String { - return "${tabs.size + 1}" - } + open fun label(i: Int): String { + return "${tabs.size + 1}" + } - fun clear() { - tabs.clear() - update() - } + fun clear() { + tabs.clear() + update() + } - open fun update() { - if(container != null) synchronized(container) { - container.clear() - container.append(render()) + open fun update() { + if (container != null) synchronized(container) { + container.clear() + container.append(render()) + } + task.complete() } - task.complete() - } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/CodingAgent.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/CodingAgent.kt index 04f73830..ade6a6ec 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/CodingAgent.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/CodingAgent.kt @@ -25,310 +25,318 @@ import kotlin.reflect.KClass open class CodingAgent( - val api: API, - dataStorage: StorageInterface, - session: Session, - user: User?, - val ui: ApplicationInterface, - interpreter: KClass, - val symbols: Map, - temperature: Double = 0.1, - val details: String? = null, - val model: ChatModels, - private val mainTask: SessionTask = ui.newTask(), - val actorMap: Map = mapOf( - ActorTypes.CodingActor to CodingActor(interpreter, symbols = symbols, temperature = temperature, details = details, model = model) - ), + val api: API, + dataStorage: StorageInterface, + session: Session, + user: User?, + val ui: ApplicationInterface, + interpreter: KClass, + val symbols: Map, + temperature: Double = 0.1, + val details: String? = null, + val model: ChatModels, + private val mainTask: SessionTask = ui.newTask(), + val actorMap: Map = mapOf( + ActorTypes.CodingActor to CodingActor( + interpreter, + symbols = symbols, + temperature = temperature, + details = details, + model = model + ) + ), ) : ActorSystem(actorMap.map { it.key.name to it.value }.toMap(), dataStorage, user, session) { - enum class ActorTypes { - CodingActor - } + enum class ActorTypes { + CodingActor + } - open val actor by lazy { - getActor(ActorTypes.CodingActor) as CodingActor - } + open val actor by lazy { + getActor(ActorTypes.CodingActor) as CodingActor + } - open val canPlay by lazy { - ApplicationServices.authorizationManager.isAuthorized( - this::class.java, - user, - OperationType.Execute - ) - } - val scheduledThreadPoolExecutor = java.util.concurrent.Executors.newScheduledThreadPool(1) - val cachedThreadPoolExecutor = java.util.concurrent.Executors.newCachedThreadPool() + open val canPlay by lazy { + ApplicationServices.authorizationManager.isAuthorized( + this::class.java, + user, + OperationType.Execute + ) + } + val scheduledThreadPoolExecutor = java.util.concurrent.Executors.newScheduledThreadPool(1) + val cachedThreadPoolExecutor = java.util.concurrent.Executors.newCachedThreadPool() - fun start( - userMessage: String, - ) { - try { - mainTask.echo(renderMarkdown(userMessage, ui=ui)) - val codeRequest = codeRequest(listOf(userMessage to ApiModel.Role.user)) - start(codeRequest, mainTask) - } catch (e: Throwable) { - log.warn("Error", e) - mainTask.error(ui, e) + fun start( + userMessage: String, + ) { + try { + mainTask.echo(renderMarkdown(userMessage, ui = ui)) + val codeRequest = codeRequest(listOf(userMessage to ApiModel.Role.user)) + start(codeRequest, mainTask) + } catch (e: Throwable) { + log.warn("Error", e) + mainTask.error(ui, e) + } } - } - fun start( - codeRequest: CodingActor.CodeRequest, - task: SessionTask = mainTask, - ) { - val newTask = ui.newTask(root = false) - task.complete(newTask.placeholder) - Retryable(ui, newTask) { - val newTask = ui.newTask(root = false) - scheduledThreadPoolExecutor.schedule({ - cachedThreadPoolExecutor.submit { - val statusSB = newTask.add("Running...") - displayCode(newTask, codeRequest) - statusSB?.clear() - newTask.complete() + fun start( + codeRequest: CodingActor.CodeRequest, + task: SessionTask = mainTask, + ) { + val newTask = ui.newTask(root = false) + task.complete(newTask.placeholder) + Retryable(ui, newTask) { + val newTask = ui.newTask(root = false) + scheduledThreadPoolExecutor.schedule({ + cachedThreadPoolExecutor.submit { + val statusSB = newTask.add("Running...") + displayCode(newTask, codeRequest) + statusSB?.clear() + newTask.complete() + } + }, 100, TimeUnit.MILLISECONDS) + newTask.placeholder + }.apply { + set(label(size), process(container)) } - }, 100, TimeUnit.MILLISECONDS) - newTask.placeholder - }.apply { - set(label(size), process(container)) } - } - open fun codeRequest(messages: List>) = - CodingActor.CodeRequest(messages) + open fun codeRequest(messages: List>) = + CodingActor.CodeRequest(messages) - fun displayCode( - task: SessionTask, - codeRequest: CodingActor.CodeRequest, - ) { - try { - val lastUserMessage = codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() - val codeResponse: CodeResult = if(lastUserMessage.startsWith("```")) { - actor.CodeResultImpl( - messages = actor.chatMessages(codeRequest), - input=codeRequest, - api = api as OpenAIClient, - givenCode=lastUserMessage.removePrefix("```").removeSuffix("```") - ) - } else { - actor.answer(codeRequest, api = api) - } - displayCodeAndFeedback(task, codeRequest, codeResponse) - } catch (e: Throwable) { - log.warn("Error", e) - } - } - protected fun displayCodeAndFeedback( - task: SessionTask, - codeRequest: CodingActor.CodeRequest, - response: CodeResult, - ) { - try { - displayCode(task, response) - displayFeedback(task, append(codeRequest, response), response) - } catch (e: Throwable) { - task.error(ui, e) - log.warn("Error", e) + fun displayCode( + task: SessionTask, + codeRequest: CodingActor.CodeRequest, + ) { + try { + val lastUserMessage = codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() + val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { + actor.CodeResultImpl( + messages = actor.chatMessages(codeRequest), + input = codeRequest, + api = api as OpenAIClient, + givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") + ) + } else { + actor.answer(codeRequest, api = api) + } + displayCodeAndFeedback(task, codeRequest, codeResponse) + } catch (e: Throwable) { + log.warn("Error", e) + } } - } - fun append( - codeRequest: CodingActor.CodeRequest, - response: CodeResult - ) = codeRequest( - messages = codeRequest.messages + - listOf( - response.code to ApiModel.Role.assistant, - ).filter { it.first.isNotBlank() } - ) + protected fun displayCodeAndFeedback( + task: SessionTask, + codeRequest: CodingActor.CodeRequest, + response: CodeResult, + ) { + try { + displayCode(task, response) + displayFeedback(task, append(codeRequest, response), response) + } catch (e: Throwable) { + task.error(ui, e) + log.warn("Error", e) + } + } - fun displayCode( - task: SessionTask, - response: CodeResult - ) { - task.hideable(ui, - renderMarkdown( - response.renderedResponse ?: - //language=Markdown - "```${actor.language.lowercase(Locale.getDefault())}\n${response.code.trim()}\n```", ui=ui - ) + fun append( + codeRequest: CodingActor.CodeRequest, + response: CodeResult + ) = codeRequest( + messages = codeRequest.messages + + listOf( + response.code to ApiModel.Role.assistant, + ).filter { it.first.isNotBlank() } ) - } - open fun displayFeedback( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult - ) { - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ + fun displayCode( + task: SessionTask, + response: CodeResult + ) { + task.hideable( + ui, + renderMarkdown( + response.renderedResponse ?: + //language=Markdown + "```${actor.language.lowercase(Locale.getDefault())}\n${response.code.trim()}\n```", ui = ui + ) + ) + } + + open fun displayFeedback( + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult + ) { + val formText = StringBuilder() + var formHandle: StringBuilder? = null + formHandle = task.add( + """ |
|${if (!canPlay) "" else playButton(task, request, response, formText) { formHandle!! }} |
|${reviseMsg(task, request, response, formText) { formHandle!! }} """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } + ) + formText.append(formHandle.toString()) + formHandle.toString() + task.complete() + } - protected fun reviseMsg( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult, - formText: StringBuilder, - formHandle: () -> StringBuilder - ) = ui.textInput { feedback -> - responseAction(task, "Revising...", formHandle(), formText) { - feedback(task, feedback, request, response) + protected fun reviseMsg( + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult, + formText: StringBuilder, + formHandle: () -> StringBuilder + ) = ui.textInput { feedback -> + responseAction(task, "Revising...", formHandle(), formText) { + feedback(task, feedback, request, response) + } } - } - protected fun regenButton( - task: SessionTask, - request: CodingActor.CodeRequest, - formText: StringBuilder, - formHandle: () -> StringBuilder - ) = "" + protected fun regenButton( + task: SessionTask, + request: CodingActor.CodeRequest, + formText: StringBuilder, + formHandle: () -> StringBuilder + ) = "" - protected fun playButton( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult, - formText: StringBuilder, - formHandle: () -> StringBuilder - ) = if (!canPlay) "" else - ui.hrefLink("▶", "href-link play-button"){ - responseAction(task, "Running...", formHandle(), formText) { - execute(task, response, request) - } - } + protected fun playButton( + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult, + formText: StringBuilder, + formHandle: () -> StringBuilder + ) = if (!canPlay) "" else + ui.hrefLink("▶", "href-link play-button") { + responseAction(task, "Running...", formHandle(), formText) { + execute(task, response, request) + } + } - protected open fun responseAction( - task: SessionTask, - message: String, - formHandle: StringBuilder?, - formText: StringBuilder, - fn: () -> Unit = {} - ) { - formHandle?.clear() - val header = task.header(message) - try { - fn() - } finally { - header?.clear() - revertButton(task, formHandle, formText) + protected open fun responseAction( + task: SessionTask, + message: String, + formHandle: StringBuilder?, + formText: StringBuilder, + fn: () -> Unit = {} + ) { + formHandle?.clear() + val header = task.header(message) + try { + fn() + } finally { + header?.clear() + revertButton(task, formHandle, formText) + } } - } - protected open fun revertButton( - task: SessionTask, - formHandle: StringBuilder?, - formText: StringBuilder - ): StringBuilder? { - var revertButton: StringBuilder? = null - revertButton = task.complete(ui.hrefLink("↩", "href-link regen-button"){ - revertButton?.clear() - formHandle?.append(formText) - task.complete() - }) - return revertButton - } + protected open fun revertButton( + task: SessionTask, + formHandle: StringBuilder?, + formText: StringBuilder + ): StringBuilder? { + var revertButton: StringBuilder? = null + revertButton = task.complete(ui.hrefLink("↩", "href-link regen-button") { + revertButton?.clear() + formHandle?.append(formText) + task.complete() + }) + return revertButton + } - protected open fun feedback( - task: SessionTask, - feedback: String, - request: CodingActor.CodeRequest, - response: CodeResult - ) { - try { - task.echo(renderMarkdown(feedback, ui=ui)) - start(codeRequest = codeRequest( - messages = request.messages + - listOf( - response.code to ApiModel.Role.assistant, - feedback to ApiModel.Role.user, - ).filter { it.first.isNotBlank() }.map { it.first to it.second } - ), task = task) - } catch (e: Throwable) { - log.warn("Error", e) - task.error(ui, e) + protected open fun feedback( + task: SessionTask, + feedback: String, + request: CodingActor.CodeRequest, + response: CodeResult + ) { + try { + task.echo(renderMarkdown(feedback, ui = ui)) + start(codeRequest = codeRequest( + messages = request.messages + + listOf( + response.code to ApiModel.Role.assistant, + feedback to ApiModel.Role.user, + ).filter { it.first.isNotBlank() }.map { it.first to it.second } + ), task = task) + } catch (e: Throwable) { + log.warn("Error", e) + task.error(ui, e) + } } - } - protected open fun execute( - task: SessionTask, - response: CodeResult, - request: CodingActor.CodeRequest, - ) { - try { - val result = execute(task, response) - displayFeedback(task, codeRequest( - messages = request.messages + - listOf( - "Running...\n\n$result" to ApiModel.Role.assistant, - ).filter { it.first.isNotBlank() } - ), response) - } catch (e: Throwable) { - handleExecutionError(e, task, request, response) + protected open fun execute( + task: SessionTask, + response: CodeResult, + request: CodingActor.CodeRequest, + ) { + try { + val result = execute(task, response) + displayFeedback(task, codeRequest( + messages = request.messages + + listOf( + "Running...\n\n$result" to ApiModel.Role.assistant, + ).filter { it.first.isNotBlank() } + ), response) + } catch (e: Throwable) { + handleExecutionError(e, task, request, response) + } } - } - protected open fun handleExecutionError( - e: Throwable, - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult - ) { - val message = when { - e is ValidatedObject.ValidationError -> renderMarkdown(e.message ?: "", ui=ui) - e is CodingActor.FailedToImplementException -> renderMarkdown( - """ + protected open fun handleExecutionError( + e: Throwable, + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult + ) { + val message = when { + e is ValidatedObject.ValidationError -> renderMarkdown(e.message ?: "", ui = ui) + e is CodingActor.FailedToImplementException -> renderMarkdown( + """ |**Failed to Implement** | |${e.message} | - |""".trimMargin(), ui=ui - ) + |""".trimMargin(), ui = ui + ) - else -> renderMarkdown( - """ + else -> renderMarkdown( + """ |**Error `${e.javaClass.name}`** | |```text |${e.stackTraceToString()/*.indent(" ")*/} |``` - |""".trimMargin(), ui=ui - ) + |""".trimMargin(), ui = ui + ) + } + task.add(message, true, "div", "error") + displayCode(task, CodingActor.CodeRequest( + messages = request.messages + + listOf( + response.code to ApiModel.Role.assistant, + message to ApiModel.Role.system, + ).filter { it.first.isNotBlank() } + )) } - task.add(message, true, "div", "error") - displayCode(task, CodingActor.CodeRequest( - messages = request.messages + - listOf( - response.code to ApiModel.Role.assistant, - message to ApiModel.Role.system, - ).filter { it.first.isNotBlank() } - )) - } - fun execute( - task: SessionTask, - response: CodeResult - ): String { - val resultValue = response.result.resultValue - val resultOutput = response.result.resultOutput - val result = when { - resultValue.isBlank() || resultValue.trim().lowercase() == "null" -> """ + fun execute( + task: SessionTask, + response: CodeResult + ): String { + val resultValue = response.result.resultValue + val resultOutput = response.result.resultOutput + val result = when { + resultValue.isBlank() || resultValue.trim().lowercase() == "null" -> """ |# Output |```text |${resultOutput.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} |``` """.trimMargin() - else -> """ + else -> """ |# Result |``` |${resultValue.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} @@ -339,12 +347,12 @@ open class CodingAgent( |${resultOutput.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} |``` """.trimMargin() + } + task.add(renderMarkdown(result, ui = ui)) + return result } - task.add(renderMarkdown(result, ui=ui)) - return result - } - companion object { - private val log = LoggerFactory.getLogger(CodingAgent::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(CodingAgent::class.java) + } } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/CodingSubAgent.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/CodingSubAgent.kt deleted file mode 100644 index 0bc1b3fc..00000000 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/CodingSubAgent.kt +++ /dev/null @@ -1,69 +0,0 @@ -package com.simiacryptus.skyenet.apps.coding - -import com.simiacryptus.jopenai.API -import com.simiacryptus.jopenai.models.ChatModels -import com.simiacryptus.skyenet.core.actors.CodingActor -import com.simiacryptus.skyenet.core.platform.Session -import com.simiacryptus.skyenet.core.platform.StorageInterface -import com.simiacryptus.skyenet.core.platform.User -import com.simiacryptus.skyenet.interpreter.Interpreter -import com.simiacryptus.skyenet.webui.application.ApplicationInterface -import com.simiacryptus.skyenet.webui.session.SessionTask -import java.util.concurrent.Semaphore -import kotlin.reflect.KClass - -open class CodingSubAgent( - api: API, - dataStorage: StorageInterface, - session: Session, - user: User?, - ui: ApplicationInterface, - interpreter: KClass, - symbols: Map, - temperature: Double = 0.1, - details: String? = null, - model: ChatModels, - mainTask: SessionTask = ui.newTask(), - val semaphore: Semaphore = Semaphore(0), -) : CodingAgent( - api = api, - dataStorage = dataStorage, - session = session, - user = user, - ui = ui, - interpreter = interpreter, - symbols = symbols, - temperature = temperature, - details = details, - model = model, - mainTask = mainTask -) { - override fun displayFeedback(task: SessionTask, request: CodingActor.CodeRequest, response: CodingActor.CodeResult) { - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ - |
- |${if (!super.canPlay) "" else super.playButton(task, request, response, formText) { formHandle!! }} - |${acceptButton(task, request, response, formText) { formHandle!! }} - |
- |${super.reviseMsg(task, request, response, formText) { formHandle!! }} - """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } - - protected fun acceptButton( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodingActor.CodeResult, - formText: StringBuilder, - formHandle: () -> StringBuilder - ) = if (!canPlay) "" else - ui.hrefLink("\uD83D\uDE80", "href-link play-button"){ - semaphore.release() - } - -} \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/ShellToolAgent.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/ShellToolAgent.kt index b7032298..4b172238 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/ShellToolAgent.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/ShellToolAgent.kt @@ -34,41 +34,54 @@ import java.io.File import kotlin.reflect.KClass private val String.escapeQuotedString: String - get() = replace("\\", "\\\\") - .replace("\"", "\\\"") - .replace("\n", "\\n") - .replace("\r", "\\r") - .replace("$", "\\$") + get() = replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("$", "\\$") abstract class ShellToolAgent( - api: API, - dataStorage: StorageInterface, - session: Session, - user: User?, - ui: ApplicationInterface, - interpreter: KClass, - symbols: Map, - temperature: Double = 0.1, - details: String? = null, - model: ChatModels, - actorMap: Map = mapOf( - ActorTypes.CodingActor to CodingActor( - interpreter, - symbols = symbols, - temperature = temperature, - details = details, - model = model - ) - ), - mainTask: SessionTask = ui.newTask(), -) : CodingAgent(api, dataStorage, session, user, ui, interpreter, symbols, temperature, details, model, mainTask, actorMap) { + api: API, + dataStorage: StorageInterface, + session: Session, + user: User?, + ui: ApplicationInterface, + interpreter: KClass, + symbols: Map, + temperature: Double = 0.1, + details: String? = null, + model: ChatModels, + actorMap: Map = mapOf( + ActorTypes.CodingActor to CodingActor( + interpreter, + symbols = symbols, + temperature = temperature, + details = details, + model = model + ) + ), + mainTask: SessionTask = ui.newTask(), +) : CodingAgent( + api, + dataStorage, + session, + user, + ui, + interpreter, + symbols, + temperature, + details, + model, + mainTask, + actorMap +) { - override fun displayFeedback(task: SessionTask, request: CodingActor.CodeRequest, response: CodeResult) { - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ + override fun displayFeedback(task: SessionTask, request: CodingActor.CodeRequest, response: CodeResult) { + val formText = StringBuilder() + var formHandle: StringBuilder? = null + formHandle = task.add( + """ |
|${if (!canPlay) "" else playButton(task, request, response, formText) { formHandle!! }} |${super.regenButton(task, request, formText) { formHandle!! }} @@ -76,41 +89,41 @@ abstract class ShellToolAgent( |
|${super.reviseMsg(task, request, response, formText) { formHandle!! }} """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } + ) + formText.append(formHandle.toString()) + formHandle.toString() + task.complete() + } - private var lastResult: String? = null + private var lastResult: String? = null - private fun createToolButton( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult, - formText: StringBuilder, - formHandle: () -> StringBuilder - ) = ui.hrefLink("\uD83D\uDCE4", "href-link regen-button"){ - val task = ui.newTask() - responseAction(task, "Exporting...", formHandle(), formText) { - displayCodeFeedback( - task, schemaActor(), request.copy( - messages = listOf( - response.code to ApiModel.Role.assistant, - "From the given code prototype, identify input out output data structures and generate Kotlin data classes to define this schema" to ApiModel.Role.user - ) - ) - ) { schemaCode -> - val command = actor.symbols.get("command")?.let { command -> - when (command) { - is String -> command.split(" ") - is List<*> -> command.map { it.toString() } - else -> throw IllegalArgumentException("Invalid command: $command") - } - } ?: listOf("bash") - val cwd = actor.symbols.get("workingDir")?.toString()?.let { java.io.File(it) } ?: java.io.File(".") - val env = actor.symbols.get("env")?.let { env -> (env as Map) } ?: mapOf() - val codePrefix = """ + private fun createToolButton( + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult, + formText: StringBuilder, + formHandle: () -> StringBuilder + ) = ui.hrefLink("\uD83D\uDCE4", "href-link regen-button") { + val task = ui.newTask() + responseAction(task, "Exporting...", formHandle(), formText) { + displayCodeFeedback( + task, schemaActor(), request.copy( + messages = listOf( + response.code to ApiModel.Role.assistant, + "From the given code prototype, identify input out output data structures and generate Kotlin data classes to define this schema" to ApiModel.Role.user + ) + ) + ) { schemaCode -> + val command = actor.symbols.get("command")?.let { command -> + when (command) { + is String -> command.split(" ") + is List<*> -> command.map { it.toString() } + else -> throw IllegalArgumentException("Invalid command: $command") + } + } ?: listOf("bash") + val cwd = actor.symbols.get("workingDir")?.toString()?.let { File(it) } ?: File(".") + val env = actor.symbols.get("env")?.let { env -> (env as Map) } ?: mapOf() + val codePrefix = """ fun execute() : Pair { val command = "${command.joinToString(" ").escapeQuotedString}".split(" ") val cwd = java.io.File("${cwd.absolutePath.escapeQuotedString}") @@ -131,352 +144,388 @@ abstract class ShellToolAgent( } } """.trimIndent() - val messages = listOf( - "Shell Code: \n```${actor.language}\n${(response.code)/*.indent(" ")*/}\n```" to ApiModel.Role.assistant, - ) + (lastResult?.let { listOf( - "Example Output:\n\n```text\n${it/*.indent(" ")*/}\n```" to ApiModel.Role.assistant - ) } ?: listOf()) + listOf( - "Schema: \n```kotlin\n${schemaCode/*.indent(" ")*/}\n```" to ApiModel.Role.assistant, - "Implement a parsing method to convert the shell output to the requested data structure" to ApiModel.Role.user - ) - displayCodeFeedback( - task, parsedActor(), request.copy( - messages = messages, - codePrefix = codePrefix - ) - ) { parsedCode -> - displayCodeFeedback( - task, servletActor(), request.copy( - messages = listOf( - (codePrefix + "\n\n" + parsedCode) to ApiModel.Role.assistant, - "Reprocess this code prototype into a servlet. " + - "The last line should instantiate the new servlet class and return it via the returnBuffer collection." to ApiModel.Role.user - ), - codePrefix = schemaCode - ) - ) { servletHandler -> - val servletImpl = (schemaCode + "\n\n" + servletHandler).sortCode() - val toolsPrefix = "/tools" - var openAPI = openAPIParsedActor().getParser(api).apply(servletImpl).let { openApi -> - openApi.copy(paths = openApi.paths?.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) - } - task.add(renderMarkdown("```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", ui=ui)) - for (i in 0..5) { - try { - OpenAPIGenerator.main( - arrayOf( - "generate", - "-i", - File.createTempFile("openapi", ".json").apply { - writeText(JsonUtil.toJson(openAPI)) - deleteOnExit() - }.absolutePath, - "-g", - "html2", - "-o", - File(dataStorage.getSessionDir(user, session), "openapi/html2").apply { mkdirs() }.absolutePath, - ) + val messages = listOf( + "Shell Code: \n```${actor.language}\n${(response.code)/*.indent(" ")*/}\n```" to ApiModel.Role.assistant, + ) + (lastResult?.let { + listOf( + "Example Output:\n\n```text\n${it/*.indent(" ")*/}\n```" to ApiModel.Role.assistant + ) + } ?: listOf()) + listOf( + "Schema: \n```kotlin\n${schemaCode/*.indent(" ")*/}\n```" to ApiModel.Role.assistant, + "Implement a parsing method to convert the shell output to the requested data structure" to ApiModel.Role.user ) - task.add("Validated OpenAPI Descriptor - Documentation Saved"); - break; - } catch (e: SpecValidationException) { - val error = """ + displayCodeFeedback( + task, parsedActor(), request.copy( + messages = messages, + codePrefix = codePrefix + ) + ) { parsedCode -> + displayCodeFeedback( + task, servletActor(), request.copy( + messages = listOf( + (codePrefix + "\n\n" + parsedCode) to ApiModel.Role.assistant, + "Reprocess this code prototype into a servlet. " + + "The last line should instantiate the new servlet class and return it via the returnBuffer collection." to ApiModel.Role.user + ), + codePrefix = schemaCode + ) + ) { servletHandler -> + val servletImpl = (schemaCode + "\n\n" + servletHandler).sortCode() + val toolsPrefix = "/tools" + var openAPI = openAPIParsedActor().getParser(api).apply(servletImpl).let { openApi -> + openApi.copy(paths = openApi.paths?.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) + } + task.add(renderMarkdown("```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", ui = ui)) + for (i in 0..5) { + try { + OpenAPIGenerator.main( + arrayOf( + "generate", + "-i", + File.createTempFile("openapi", ".json").apply { + writeText(JsonUtil.toJson(openAPI)) + deleteOnExit() + }.absolutePath, + "-g", + "html2", + "-o", + File( + dataStorage.getSessionDir(user, session), + "openapi/html2" + ).apply { mkdirs() }.absolutePath, + ) + ) + task.add("Validated OpenAPI Descriptor - Documentation Saved") + break + } catch (e: SpecValidationException) { + val error = """ |${e.message} |${e.errors.joinToString("\n") { "ERROR:" + it.toString() }} |${e.warnings.joinToString("\n") { "WARN:" + it.toString() }} """.trimIndent() - task.hideable(ui, renderMarkdown("```\n${error?.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```", ui=ui)) - openAPI = openAPIParsedActor().answer( - listOf( - servletImpl, - JsonUtil.toJson(openAPI), - error - ), api - ).obj.let { openApi -> - val paths = HashMap(openApi.paths) - openApi.copy(paths = paths.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) + task.hideable( + ui, + renderMarkdown( + "```\n${error.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```", + ui = ui + ) + ) + openAPI = openAPIParsedActor().answer( + listOf( + servletImpl, + JsonUtil.toJson(openAPI), + error + ), api + ).obj.let { openApi -> + val paths = HashMap(openApi.paths) + openApi.copy(paths = paths.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) + } + task.hideable( + ui, + renderMarkdown( + "```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", + ui = ui + ) + ) + } + } + if (ApplicationServices.authorizationManager.isAuthorized( + ToolAgent.javaClass, + user, + AuthorizationInterface.OperationType.Admin + ) + ) { + ToolServlet.addTool( + ToolServlet.Tool( + path = openAPI.paths?.entries?.first()?.key?.removePrefix(toolsPrefix) ?: "unknown", + openApiDescription = openAPI, + interpreterString = getInterpreterString(), + servletCode = servletImpl + ) + ) + } + buildTestPage(openAPI, servletImpl, task) + } } - task.hideable(ui, renderMarkdown("```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", ui=ui)) - } - } - if (ApplicationServices.authorizationManager.isAuthorized( - ToolAgent.javaClass, - user, - AuthorizationInterface.OperationType.Admin - ) - ) { - ToolServlet.addTool( - ToolServlet.Tool( - path = openAPI.paths?.entries?.first()?.key?.removePrefix(toolsPrefix) ?: "unknown", - openApiDescription = openAPI, - interpreterString = getInterpreterString(), - servletCode = servletImpl - ) - ) } - buildTestPage(openAPI, servletImpl, task) - } } - } } - } - private fun openAPIParsedActor() = object : ParsedActor( + private fun openAPIParsedActor() = object : ParsedActor( // parserClass = OpenApiParser::class.java, - resultClass = OpenAPI::class.java, - model = model, - prompt = "You are a code documentation assistant. You will create the OpenAPI definition for a servlet handler written in kotlin", - parsingModel = model, - ) { - override val describer: TypeDescriber - get() = object : AbbrevWhitelistYamlDescriber( - //"com.simiacryptus", "com.github.simiacryptus" - ) { - override val includeMethods: Boolean get() = false - } - } + resultClass = OpenAPI::class.java, + model = model, + prompt = "You are a code documentation assistant. You will create the OpenAPI definition for a servlet handler written in kotlin", + parsingModel = model, + ) { + override val describer: TypeDescriber + get() = object : AbbrevWhitelistYamlDescriber( + //"com.simiacryptus", "com.github.simiacryptus" + ) { + override val includeMethods: Boolean get() = false + } + } - private fun servletActor() = object : CodingActor( - interpreterClass = KotlinInterpreter::class, - symbols = actor.symbols + mapOf( - "returnBuffer" to ServletBuffer(), - "json" to JsonUtil, - "req" to Request(null, null), - "resp" to Response(null, null), - ), - describer = object : AbbrevWhitelistYamlDescriber( - "com.simiacryptus", - "com.github.simiacryptus" + private fun servletActor() = object : CodingActor( + interpreterClass = KotlinInterpreter::class, + symbols = actor.symbols + mapOf( + "returnBuffer" to ServletBuffer(), + "json" to JsonUtil, + "req" to Request(null, null), + "resp" to Response(null, null), + ), + describer = object : AbbrevWhitelistYamlDescriber( + "com.simiacryptus", + "com.github.simiacryptus" + ) { + override fun describe( + rawType: Class, + stackMax: Int, + describedTypes: MutableSet + ): String = when (rawType) { + Request::class.java -> describe(HttpServletRequest::class.java) + Response::class.java -> describe(HttpServletResponse::class.java) + else -> super.describe(rawType, stackMax, describedTypes) + } + }, + details = actor.details, + model = actor.model, + fallbackModel = actor.fallbackModel, + temperature = actor.temperature, + runtimeSymbols = actor.runtimeSymbols ) { - override fun describe(rawType: Class, - stackMax: Int, - describedTypes: MutableSet): String = when (rawType) { - Request::class.java -> describe(HttpServletRequest::class.java) - Response::class.java -> describe(HttpServletResponse::class.java) - else -> super.describe(rawType, stackMax, describedTypes) - } - }, - details = actor.details, - model = actor.model, - fallbackModel = actor.fallbackModel, - temperature = actor.temperature, - runtimeSymbols = actor.runtimeSymbols - ) { - override val prompt: String - get() = super.prompt - } + override val prompt: String + get() = super.prompt + } - private fun schemaActor() = object : CodingActor( - interpreterClass = KotlinInterpreter::class, - symbols = mapOf(), - details = actor.details, - model = actor.model, - fallbackModel = actor.fallbackModel, - temperature = actor.temperature, - runtimeSymbols = actor.runtimeSymbols - ) { - override val prompt: String - get() = super.prompt - } + private fun schemaActor() = object : CodingActor( + interpreterClass = KotlinInterpreter::class, + symbols = mapOf(), + details = actor.details, + model = actor.model, + fallbackModel = actor.fallbackModel, + temperature = actor.temperature, + runtimeSymbols = actor.runtimeSymbols + ) { + override val prompt: String + get() = super.prompt + } - private fun parsedActor() = object : CodingActor( - interpreterClass = KotlinInterpreter::class, - symbols = mapOf(), - details = actor.details, - model = actor.model, - fallbackModel = actor.fallbackModel, - temperature = actor.temperature, - runtimeSymbols = actor.runtimeSymbols - ) { - override val prompt: String - get() = super.prompt - } + private fun parsedActor() = object : CodingActor( + interpreterClass = KotlinInterpreter::class, + symbols = mapOf(), + details = actor.details, + model = actor.model, + fallbackModel = actor.fallbackModel, + temperature = actor.temperature, + runtimeSymbols = actor.runtimeSymbols + ) { + override val prompt: String + get() = super.prompt + } - /** - * - * TODO: This method seems redundant. - * - * */ - private fun displayCodeFeedback( - task: SessionTask, - actor: CodingActor, - request: CodingActor.CodeRequest, - response: CodeResult = execWrap { actor.answer(request, api = api) }, - onComplete: (String) -> Unit - ) { - task.hideable(ui, renderMarkdown("```kotlin\n${response.code?.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```", ui=ui)) - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ + /** + * + * TODO: This method seems redundant. + * + * */ + private fun displayCodeFeedback( + task: SessionTask, + actor: CodingActor, + request: CodingActor.CodeRequest, + response: CodeResult = execWrap { actor.answer(request, api = api) }, + onComplete: (String) -> Unit + ) { + task.hideable( + ui, + renderMarkdown("```kotlin\n${response.code.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```", ui = ui) + ) + val formText = StringBuilder() + var formHandle: StringBuilder? = null + formHandle = task.add( + """ |
|${ - super.ui.hrefLink("\uD83D\uDC4D", "href-link play-button"){ - super.responseAction(task, "Accepted...", formHandle!!, formText) { - onComplete(response.code) - } - } - } + super.ui.hrefLink("\uD83D\uDC4D", "href-link play-button") { + super.responseAction(task, "Accepted...", formHandle!!, formText) { + onComplete(response.code) + } + } + } |${ - if (!super.canPlay) "" else - ui.hrefLink("▶", "href-link play-button"){ - execute(ui.newTask(), response) - } - } + if (!super.canPlay) "" else + ui.hrefLink("▶", "href-link play-button") { + execute(ui.newTask(), response) + } + } |${ - super.ui.hrefLink("♻", "href-link regen-button"){ - super.responseAction(task, "Regenerating...", formHandle!!, formText) { - //val task = super.ui.newTask() - val codeRequest = - request.copy(messages = request.messages.dropLastWhile { it.second == ApiModel.Role.assistant }) - try { - val lastUserMessage = codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() - val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { - actor.CodeResultImpl( - messages = actor.chatMessages(codeRequest), - input = codeRequest, - api = super.api as OpenAIClient, - givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") - ) - } else { - actor.answer(codeRequest, api = super.api) - } - super.displayCode(task, codeResponse) - displayCodeFeedback(task, actor, super.append(codeRequest, codeResponse), codeResponse, onComplete) - } catch (e: Throwable) { - log.warn("Error", e) - val error = task.error(super.ui, e) - var regenButton: StringBuilder? = null - regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button"){ - regenButton?.clear() - val header = task.header("Regenerating...") - super.displayCode(task, codeRequest) - header?.clear() - error?.clear() - task.complete() - }) + super.ui.hrefLink("♻", "href-link regen-button") { + super.responseAction(task, "Regenerating...", formHandle!!, formText) { + //val task = super.ui.newTask() + val codeRequest = + request.copy(messages = request.messages.dropLastWhile { it.second == ApiModel.Role.assistant }) + try { + val lastUserMessage = + codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() + val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { + actor.CodeResultImpl( + messages = actor.chatMessages(codeRequest), + input = codeRequest, + api = super.api as OpenAIClient, + givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") + ) + } else { + actor.answer(codeRequest, api = super.api) + } + super.displayCode(task, codeResponse) + displayCodeFeedback( + task, + actor, + super.append(codeRequest, codeResponse), + codeResponse, + onComplete + ) + } catch (e: Throwable) { + log.warn("Error", e) + val error = task.error(super.ui, e) + var regenButton: StringBuilder? = null + regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button") { + regenButton?.clear() + val header = task.header("Regenerating...") + super.displayCode(task, codeRequest) + header?.clear() + error?.clear() + task.complete() + }) + } + } + } } - } - } - } |
|${ - super.ui.textInput { feedback -> - super.responseAction(task, "Revising...", formHandle!!, formText) { - //val task = super.ui.newTask() - try { - task.echo(renderMarkdown(feedback, ui=ui)) - val codeRequest = CodingActor.CodeRequest( - messages = request.messages + - listOf( - response.code to ApiModel.Role.assistant, - feedback to ApiModel.Role.user, - ).filter { it.first.isNotBlank() }.map { it.first to it.second } - ) - try { - val lastUserMessage = codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() - val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { - actor.CodeResultImpl( - messages = actor.chatMessages(codeRequest), - input = codeRequest, - api = super.api as OpenAIClient, - givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") - ) - } else { - actor.answer(codeRequest, api = super.api) + super.ui.textInput { feedback -> + super.responseAction(task, "Revising...", formHandle!!, formText) { + //val task = super.ui.newTask() + try { + task.echo(renderMarkdown(feedback, ui = ui)) + val codeRequest = CodingActor.CodeRequest( + messages = request.messages + + listOf( + response.code to ApiModel.Role.assistant, + feedback to ApiModel.Role.user, + ).filter { it.first.isNotBlank() }.map { it.first to it.second } + ) + try { + val lastUserMessage = + codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() + val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { + actor.CodeResultImpl( + messages = actor.chatMessages(codeRequest), + input = codeRequest, + api = super.api as OpenAIClient, + givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") + ) + } else { + actor.answer(codeRequest, api = super.api) + } + displayCodeFeedback( + task, + actor, + super.append(codeRequest, codeResponse), + codeResponse, + onComplete + ) + } catch (e: Throwable) { + log.warn("Error", e) + val error = task.error(super.ui, e) + var regenButton: StringBuilder? = null + regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button") { + regenButton?.clear() + val header = task.header("Regenerating...") + super.displayCode(task, codeRequest) + header?.clear() + error?.clear() + task.complete() + }) + } + } catch (e: Throwable) { + log.warn("Error", e) + task.error(ui, e) + } + } } - displayCodeFeedback(task, actor, super.append(codeRequest, codeResponse), codeResponse, onComplete) - } catch (e: Throwable) { - log.warn("Error", e) - val error = task.error(super.ui, e) - var regenButton: StringBuilder? = null - regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button"){ - regenButton?.clear() - val header = task.header("Regenerating...") - super.displayCode(task, codeRequest) - header?.clear() - error?.clear() - task.complete() - }) - } - } catch (e: Throwable) { - log.warn("Error", e) - task.error(ui, e) } - } - } - } """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } + ) + formText.append(formHandle.toString()) + formHandle.toString() + task.complete() + } - class ServletBuffer : ArrayList() + class ServletBuffer : ArrayList() - private fun buildTestPage( - openAPI: OpenAPI, - servletImpl: String, - task: SessionTask - ) { - var testPage = SimpleActor( - prompt = "Given the definition for a servlet handler, create a test page that can be used to test the servlet", - model = model, - ).answer( - listOf( - JsonUtil.toJson(openAPI), - servletImpl - ), api - ) - // if ```html unwrap - if (testPage.contains("```html")) testPage = testPage.substringAfter("```html").substringBefore("```") - task.add(renderMarkdown("```html\n${testPage?.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```", ui=ui)) - task.complete( - "Test Page for ${openAPI.paths?.entries?.first()?.key ?: "unknown"} Saved" ) - }'>Test Page for ${openAPI.paths?.entries?.first()?.key ?: "unknown"} Saved" - ) - } + } - abstract fun getInterpreterString(): String; + abstract fun getInterpreterString(): String - private fun answer( - actor: CodingActor, - request: CodingActor.CodeRequest, - task: SessionTask = ui.newTask(), - feedback: Boolean = true, - ): CodeResult { - val response = actor.answer(request, api = api) - if (feedback) displayCodeAndFeedback(task, request, response) - else displayCode(task, response) - return response - } + private fun answer( + actor: CodingActor, + request: CodingActor.CodeRequest, + task: SessionTask = ui.newTask(), + feedback: Boolean = true, + ): CodeResult { + val response = actor.answer(request, api = api) + if (feedback) displayCodeAndFeedback(task, request, response) + else displayCode(task, response) + return response + } - companion object { - val log = LoggerFactory.getLogger(ShellToolAgent::class.java) - fun execWrap(fn: () -> T): T { - val classLoader = Thread.currentThread().contextClassLoader - val prevCL = KotlinInterpreter.classLoader - KotlinInterpreter.classLoader = classLoader //req.javaClass.classLoader - return try { - WebAppClassLoader.runWithServerClassAccess { - require(null != classLoader.loadClass("org.eclipse.jetty.server.Response")) - require(null != classLoader.loadClass("org.eclipse.jetty.server.Request")) - // com.simiacryptus.jopenai.OpenAIClient - require(null != classLoader.loadClass("com.simiacryptus.jopenai.OpenAIClient")) - require(null != classLoader.loadClass("com.simiacryptus.jopenai.API")) - fn() + companion object { + val log = LoggerFactory.getLogger(ShellToolAgent::class.java) + fun execWrap(fn: () -> T): T { + val classLoader = Thread.currentThread().contextClassLoader + val prevCL = KotlinInterpreter.classLoader + KotlinInterpreter.classLoader = classLoader //req.javaClass.classLoader + return try { + WebAppClassLoader.runWithServerClassAccess { + require(null != classLoader.loadClass("org.eclipse.jetty.server.Response")) + require(null != classLoader.loadClass("org.eclipse.jetty.server.Request")) + // com.simiacryptus.jopenai.OpenAIClient + require(null != classLoader.loadClass("com.simiacryptus.jopenai.OpenAIClient")) + require(null != classLoader.loadClass("com.simiacryptus.jopenai.API")) + fn() + } + } finally { + KotlinInterpreter.classLoader = prevCL + } } - } finally { - KotlinInterpreter.classLoader = prevCL - } } - } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/ToolAgent.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/ToolAgent.kt index 9de2649a..e960f07d 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/ToolAgent.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/coding/ToolAgent.kt @@ -34,32 +34,45 @@ import java.io.File import kotlin.reflect.KClass abstract class ToolAgent( - api: API, - dataStorage: StorageInterface, - session: Session, - user: User?, - ui: ApplicationInterface, - interpreter: KClass, - symbols: Map, - temperature: Double = 0.1, - details: String? = null, - model: ChatModels, - mainTask: SessionTask = ui.newTask(), - actorMap: Map = mapOf( - ActorTypes.CodingActor to CodingActor( - interpreter, - symbols = symbols, - temperature = temperature, - details = details, - model = model - ) - ), -) : CodingAgent(api, dataStorage, session, user, ui, interpreter, symbols, temperature, details, model, mainTask, actorMap) { - override fun displayFeedback(task: SessionTask, request: CodingActor.CodeRequest, response: CodeResult) { - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ + api: API, + dataStorage: StorageInterface, + session: Session, + user: User?, + ui: ApplicationInterface, + interpreter: KClass, + symbols: Map, + temperature: Double = 0.1, + details: String? = null, + model: ChatModels, + mainTask: SessionTask = ui.newTask(), + actorMap: Map = mapOf( + ActorTypes.CodingActor to CodingActor( + interpreter, + symbols = symbols, + temperature = temperature, + details = details, + model = model + ) + ), +) : CodingAgent( + api, + dataStorage, + session, + user, + ui, + interpreter, + symbols, + temperature, + details, + model, + mainTask, + actorMap +) { + override fun displayFeedback(task: SessionTask, request: CodingActor.CodeRequest, response: CodeResult) { + val formText = StringBuilder() + var formHandle: StringBuilder? = null + formHandle = task.add( + """ |
|${super.playButton(task, request, response, formText) { formHandle!! }} |${super.regenButton(task, request, formText) { formHandle!! }} @@ -67,336 +80,358 @@ abstract class ToolAgent( |
|${super.reviseMsg(task, request, response, formText) { formHandle!! }} """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } - - private fun createToolButton( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult, - formText: StringBuilder, - formHandle: () -> StringBuilder - ) = ui.hrefLink("\uD83D\uDCE4", "href-link regen-button"){ - val task = ui.newTask() - responseAction(task, "Exporting...", formHandle(), formText) { - displayCodeFeedback( - task, schemaActor(), request.copy( - messages = listOf( - response.code to ApiModel.Role.assistant, - "From the given code prototype, identify input out output data structures and generate Kotlin data classes to define this schema" to ApiModel.Role.user - ) ) - ) { schemaCode -> - displayCodeFeedback( - task, servletActor(), request.copy( - messages = listOf( - response.code to ApiModel.Role.assistant, - "Reprocess this code prototype into a servlet using the given data schema. " + - "The last line should instantiate the new servlet class and return it via the returnBuffer collection." to ApiModel.Role.user - ), - codePrefix = schemaCode - ) - ) { servletHandler -> - val servletImpl = (schemaCode + "\n\n" + servletHandler).sortCode() - val toolsPrefix = "/tools" - var openAPI = openAPIParsedActor().getParser(api).apply(servletImpl).let { openApi -> - openApi.copy(paths = openApi.paths?.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) - } - task.add(renderMarkdown("```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", ui=ui)) - for (i in 0..5) { - try { - OpenAPIGenerator.main( - arrayOf( - "generate", - "-i", - File.createTempFile("openapi", ".json").apply { - writeText(JsonUtil.toJson(openAPI)) - deleteOnExit() - }.absolutePath, - "-g", - "html2", - "-o", - File(dataStorage.getSessionDir(user, session), "openapi/html2").apply { mkdirs() }.absolutePath, + formText.append(formHandle.toString()) + formHandle.toString() + task.complete() + } + + private fun createToolButton( + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult, + formText: StringBuilder, + formHandle: () -> StringBuilder + ) = ui.hrefLink("\uD83D\uDCE4", "href-link regen-button") { + val task = ui.newTask() + responseAction(task, "Exporting...", formHandle(), formText) { + displayCodeFeedback( + task, schemaActor(), request.copy( + messages = listOf( + response.code to ApiModel.Role.assistant, + "From the given code prototype, identify input out output data structures and generate Kotlin data classes to define this schema" to ApiModel.Role.user + ) ) - ) - task.add("Validated OpenAPI Descriptor - Documentation Saved"); - break; - } catch (e: SpecValidationException) { - val error = """ + ) { schemaCode -> + displayCodeFeedback( + task, servletActor(), request.copy( + messages = listOf( + response.code to ApiModel.Role.assistant, + "Reprocess this code prototype into a servlet using the given data schema. " + + "The last line should instantiate the new servlet class and return it via the returnBuffer collection." to ApiModel.Role.user + ), + codePrefix = schemaCode + ) + ) { servletHandler -> + val servletImpl = (schemaCode + "\n\n" + servletHandler).sortCode() + val toolsPrefix = "/tools" + var openAPI = openAPIParsedActor().getParser(api).apply(servletImpl).let { openApi -> + openApi.copy(paths = openApi.paths?.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) + } + task.add(renderMarkdown("```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", ui = ui)) + for (i in 0..5) { + try { + OpenAPIGenerator.main( + arrayOf( + "generate", + "-i", + File.createTempFile("openapi", ".json").apply { + writeText(JsonUtil.toJson(openAPI)) + deleteOnExit() + }.absolutePath, + "-g", + "html2", + "-o", + File( + dataStorage.getSessionDir(user, session), + "openapi/html2" + ).apply { mkdirs() }.absolutePath, + ) + ) + task.add("Validated OpenAPI Descriptor - Documentation Saved") + break + } catch (e: SpecValidationException) { + val error = """ |${e.message} |${e.errors.joinToString("\n") { "ERROR:" + it.toString() }} |${e.warnings.joinToString("\n") { "WARN:" + it.toString() }} """.trimIndent() - task.hideable(ui, renderMarkdown("```\n${error/*.indent(" ")*/}\n```", ui=ui)) - openAPI = openAPIParsedActor().answer( - listOf( - servletImpl, - JsonUtil.toJson(openAPI), - error - ), api - ).obj.let { openApi -> - val paths = HashMap(openApi.paths) - openApi.copy(paths = paths.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) - } - task.hideable(ui, renderMarkdown("```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", ui=ui)) + task.hideable(ui, renderMarkdown("```\n${error/*.indent(" ")*/}\n```", ui = ui)) + openAPI = openAPIParsedActor().answer( + listOf( + servletImpl, + JsonUtil.toJson(openAPI), + error + ), api + ).obj.let { openApi -> + val paths = HashMap(openApi.paths) + openApi.copy(paths = paths.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) + } + task.hideable( + ui, + renderMarkdown("```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", ui = ui) + ) + } + } + if (ApplicationServices.authorizationManager.isAuthorized( + ToolAgent.javaClass, + user, + AuthorizationInterface.OperationType.Admin + ) + ) { + ToolServlet.addTool( + ToolServlet.Tool( + path = openAPI.paths?.entries?.first()?.key?.removePrefix(toolsPrefix) ?: "unknown", + openApiDescription = openAPI, + interpreterString = getInterpreterString(), + servletCode = servletImpl + ) + ) + } + buildTestPage(openAPI, servletImpl, task) + } } - } - if (ApplicationServices.authorizationManager.isAuthorized( - ToolAgent.javaClass, - user, - AuthorizationInterface.OperationType.Admin - ) - ) { - ToolServlet.addTool( - ToolServlet.Tool( - path = openAPI.paths?.entries?.first()?.key?.removePrefix(toolsPrefix) ?: "unknown", - openApiDescription = openAPI, - interpreterString = getInterpreterString(), - servletCode = servletImpl - ) - ) - } - buildTestPage(openAPI, servletImpl, task) } - } } - } - private fun openAPIParsedActor() = object : ParsedActor( + private fun openAPIParsedActor() = object : ParsedActor( // parserClass = OpenApiParser::class.java, - resultClass = OpenAPI::class.java, - model = model, - prompt = "You are a code documentation assistant. You will create the OpenAPI definition for a servlet handler written in kotlin", - parsingModel = model, - ) { - override val describer: TypeDescriber - get() = object : AbbrevWhitelistYamlDescriber( - //"com.simiacryptus", "com.github.simiacryptus" - ) { - override val includeMethods: Boolean get() = false - } - } + resultClass = OpenAPI::class.java, + model = model, + prompt = "You are a code documentation assistant. You will create the OpenAPI definition for a servlet handler written in kotlin", + parsingModel = model, + ) { + override val describer: TypeDescriber + get() = object : AbbrevWhitelistYamlDescriber( + //"com.simiacryptus", "com.github.simiacryptus" + ) { + override val includeMethods: Boolean get() = false + } + } - private fun servletActor() = object : CodingActor( - interpreterClass = actor.interpreterClass, - symbols = actor.symbols + mapOf( - "returnBuffer" to ServletBuffer(), - "json" to JsonUtil, - "req" to Request(null, null), - "resp" to Response(null, null), - ), - describer = object : AbbrevWhitelistYamlDescriber( - "com.simiacryptus", - "com.github.simiacryptus" + private fun servletActor() = object : CodingActor( + interpreterClass = actor.interpreterClass, + symbols = actor.symbols + mapOf( + "returnBuffer" to ServletBuffer(), + "json" to JsonUtil, + "req" to Request(null, null), + "resp" to Response(null, null), + ), + describer = object : AbbrevWhitelistYamlDescriber( + "com.simiacryptus", + "com.github.simiacryptus" + ) { + override fun describe( + rawType: Class, + stackMax: Int, + describedTypes: MutableSet + ): String = when (rawType) { + Request::class.java -> describe(HttpServletRequest::class.java) + Response::class.java -> describe(HttpServletResponse::class.java) + else -> super.describe(rawType, stackMax, describedTypes) + } + }, + details = actor.details, + model = actor.model, + fallbackModel = actor.fallbackModel, + temperature = actor.temperature, + runtimeSymbols = actor.runtimeSymbols ) { - override fun describe(rawType: Class, - stackMax: Int, - describedTypes: MutableSet): String = when (rawType) { - Request::class.java -> describe(HttpServletRequest::class.java) - Response::class.java -> describe(HttpServletResponse::class.java) - else -> super.describe(rawType, stackMax, describedTypes) - } - }, - details = actor.details, - model = actor.model, - fallbackModel = actor.fallbackModel, - temperature = actor.temperature, - runtimeSymbols = actor.runtimeSymbols - ) { - override val prompt: String - get() = super.prompt - } + override val prompt: String + get() = super.prompt + } - private fun schemaActor() = object : CodingActor( - interpreterClass = actor.interpreterClass, - symbols = mapOf(), - details = actor.details, - model = actor.model, - fallbackModel = actor.fallbackModel, - temperature = actor.temperature, - runtimeSymbols = actor.runtimeSymbols - ) { - override val prompt: String - get() = super.prompt - } + private fun schemaActor() = object : CodingActor( + interpreterClass = actor.interpreterClass, + symbols = mapOf(), + details = actor.details, + model = actor.model, + fallbackModel = actor.fallbackModel, + temperature = actor.temperature, + runtimeSymbols = actor.runtimeSymbols + ) { + override val prompt: String + get() = super.prompt + } - private fun displayCodeFeedback( - task: SessionTask, - actor: CodingActor, - request: CodingActor.CodeRequest, - response: CodeResult = execWrap { actor.answer(request, api = api) }, - onComplete: (String) -> Unit - ) { - task.hideable(ui, renderMarkdown("```kotlin\n${/*escapeHtml4*/(response.code)/*.indent(" ")*/}\n```", ui=ui)) - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ + private fun displayCodeFeedback( + task: SessionTask, + actor: CodingActor, + request: CodingActor.CodeRequest, + response: CodeResult = execWrap { actor.answer(request, api = api) }, + onComplete: (String) -> Unit + ) { + task.hideable(ui, renderMarkdown("```kotlin\n${/*escapeHtml4*/(response.code)/*.indent(" ")*/}\n```", ui = ui)) + val formText = StringBuilder() + var formHandle: StringBuilder? = null + formHandle = task.add( + """ |
|${ - if (!super.canPlay) "" else - super.ui.hrefLink("\uD83D\uDC4D", "href-link play-button"){ - super.responseAction(task, "Accepted...", formHandle!!, formText) { - onComplete(response.code) + if (!super.canPlay) "" else + super.ui.hrefLink("\uD83D\uDC4D", "href-link play-button") { + super.responseAction(task, "Accepted...", formHandle!!, formText) { + onComplete(response.code) + } + } } - } - } |${ - super.ui.hrefLink("♻", "href-link regen-button"){ - super.responseAction(task, "Regenerating...", formHandle!!, formText) { - //val task = super.ui.newTask() - val codeRequest = - request.copy(messages = request.messages.dropLastWhile { it.second == ApiModel.Role.assistant }) - try { - val lastUserMessage = codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() - val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { - actor.CodeResultImpl( - messages = actor.chatMessages(codeRequest), - input = codeRequest, - api = super.api as OpenAIClient, - givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") - ) - } else { - actor.answer(codeRequest, api = super.api) - } - super.displayCode(task, codeResponse) - displayCodeFeedback(task, actor, super.append(codeRequest, codeResponse), codeResponse, onComplete) - } catch (e: Throwable) { - log.warn("Error", e) - val error = task.error(super.ui, e) - var regenButton: StringBuilder? = null - regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button"){ - regenButton?.clear() - val header = task.header("Regenerating...") - super.displayCode(task, codeRequest) - header?.clear() - error?.clear() - task.complete() - }) + super.ui.hrefLink("♻", "href-link regen-button") { + super.responseAction(task, "Regenerating...", formHandle!!, formText) { + //val task = super.ui.newTask() + val codeRequest = + request.copy(messages = request.messages.dropLastWhile { it.second == ApiModel.Role.assistant }) + try { + val lastUserMessage = + codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() + val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { + actor.CodeResultImpl( + messages = actor.chatMessages(codeRequest), + input = codeRequest, + api = super.api as OpenAIClient, + givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") + ) + } else { + actor.answer(codeRequest, api = super.api) + } + super.displayCode(task, codeResponse) + displayCodeFeedback( + task, + actor, + super.append(codeRequest, codeResponse), + codeResponse, + onComplete + ) + } catch (e: Throwable) { + log.warn("Error", e) + val error = task.error(super.ui, e) + var regenButton: StringBuilder? = null + regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button") { + regenButton?.clear() + val header = task.header("Regenerating...") + super.displayCode(task, codeRequest) + header?.clear() + error?.clear() + task.complete() + }) + } + } + } } - } - } - } |
|${ - super.ui.textInput { feedback -> - super.responseAction(task, "Revising...", formHandle!!, formText) { - //val task = super.ui.newTask() - try { - task.echo(renderMarkdown(feedback, ui=ui)) - val codeRequest = CodingActor.CodeRequest( - messages = request.messages + - listOf( - response.code to ApiModel.Role.assistant, - feedback to ApiModel.Role.user, - ).filter { it.first.isNotBlank() }.map { it.first to it.second } - ) - try { - val lastUserMessage = codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() - val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { - actor.CodeResultImpl( - messages = actor.chatMessages(codeRequest), - input = codeRequest, - api = super.api as OpenAIClient, - givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") - ) - } else { - actor.answer(codeRequest, api = super.api) + super.ui.textInput { feedback -> + super.responseAction(task, "Revising...", formHandle!!, formText) { + //val task = super.ui.newTask() + try { + task.echo(renderMarkdown(feedback, ui = ui)) + val codeRequest = CodingActor.CodeRequest( + messages = request.messages + + listOf( + response.code to ApiModel.Role.assistant, + feedback to ApiModel.Role.user, + ).filter { it.first.isNotBlank() }.map { it.first to it.second } + ) + try { + val lastUserMessage = + codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() + val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { + actor.CodeResultImpl( + messages = actor.chatMessages(codeRequest), + input = codeRequest, + api = super.api as OpenAIClient, + givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") + ) + } else { + actor.answer(codeRequest, api = super.api) + } + displayCodeFeedback( + task, + actor, + super.append(codeRequest, codeResponse), + codeResponse, + onComplete + ) + } catch (e: Throwable) { + log.warn("Error", e) + val error = task.error(super.ui, e) + var regenButton: StringBuilder? = null + regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button") { + regenButton?.clear() + val header = task.header("Regenerating...") + super.displayCode(task, codeRequest) + header?.clear() + error?.clear() + task.complete() + }) + } + } catch (e: Throwable) { + log.warn("Error", e) + task.error(ui, e) + } + } } - displayCodeFeedback(task, actor, super.append(codeRequest, codeResponse), codeResponse, onComplete) - } catch (e: Throwable) { - log.warn("Error", e) - val error = task.error(super.ui, e) - var regenButton: StringBuilder? = null - regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button"){ - regenButton?.clear() - val header = task.header("Regenerating...") - super.displayCode(task, codeRequest) - header?.clear() - error?.clear() - task.complete() - }) - } - } catch (e: Throwable) { - log.warn("Error", e) - task.error(ui, e) } - } - } - } """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } + ) + formText.append(formHandle.toString()) + formHandle.toString() + task.complete() + } - class ServletBuffer : ArrayList() + class ServletBuffer : ArrayList() - private fun buildTestPage( - openAPI: OpenAPI, - servletImpl: String, - task: SessionTask - ) { - var testPage = SimpleActor( - prompt = "Given the definition for a servlet handler, create a test page that can be used to test the servlet", - model = model, - ).answer( - listOf( - JsonUtil.toJson(openAPI), - servletImpl - ), api - ) - // if ```html unwrap - if (testPage.contains("```html")) testPage = testPage.substringAfter("```html").substringBefore("```") - task.add(renderMarkdown("```html\n$testPage\n```", ui=ui)) - task.complete( - "Test Page for ${openAPI.paths?.entries?.first()?.key ?: "unknown"} Saved" ) - }'>Test Page for ${openAPI.paths?.entries?.first()?.key ?: "unknown"} Saved" - ) - } + } - abstract fun getInterpreterString(): String; + abstract fun getInterpreterString(): String - private fun answer( - actor: CodingActor, - request: CodingActor.CodeRequest, - task: SessionTask = ui.newTask(), - feedback: Boolean = true, - ): CodeResult { - val response = actor.answer(request, api = api) - if (feedback) displayCodeAndFeedback(task, request, response) - else displayCode(task, response) - return response - } + private fun answer( + actor: CodingActor, + request: CodingActor.CodeRequest, + task: SessionTask = ui.newTask(), + feedback: Boolean = true, + ): CodeResult { + val response = actor.answer(request, api = api) + if (feedback) displayCodeAndFeedback(task, request, response) + else displayCode(task, response) + return response + } - companion object { - val log = LoggerFactory.getLogger(ToolAgent::class.java) - fun execWrap(fn: () -> T) : T { - val classLoader = Thread.currentThread().contextClassLoader - val prevCL = KotlinInterpreter.classLoader - KotlinInterpreter.classLoader = classLoader //req.javaClass.classLoader - return try { - WebAppClassLoader.runWithServerClassAccess { - require(null != classLoader.loadClass("org.eclipse.jetty.server.Response")) - require(null != classLoader.loadClass("org.eclipse.jetty.server.Request")) - // com.simiacryptus.jopenai.OpenAIClient - require(null != classLoader.loadClass("com.simiacryptus.jopenai.OpenAIClient")) - require(null != classLoader.loadClass("com.simiacryptus.jopenai.API")) - fn() + companion object { + val log = LoggerFactory.getLogger(ToolAgent::class.java) + fun execWrap(fn: () -> T): T { + val classLoader = Thread.currentThread().contextClassLoader + val prevCL = KotlinInterpreter.classLoader + KotlinInterpreter.classLoader = classLoader //req.javaClass.classLoader + return try { + WebAppClassLoader.runWithServerClassAccess { + require(null != classLoader.loadClass("org.eclipse.jetty.server.Response")) + require(null != classLoader.loadClass("org.eclipse.jetty.server.Request")) + // com.simiacryptus.jopenai.OpenAIClient + require(null != classLoader.loadClass("com.simiacryptus.jopenai.OpenAIClient")) + require(null != classLoader.loadClass("com.simiacryptus.jopenai.API")) + fn() + } + } finally { + KotlinInterpreter.classLoader = prevCL + } } - } finally { - KotlinInterpreter.classLoader = prevCL - } } - } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/WebDevApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/WebDevApp.kt index 86f81eaf..6011269f 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/WebDevApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/WebDevApp.kt @@ -24,82 +24,83 @@ import com.simiacryptus.skyenet.webui.session.SessionTask import com.simiacryptus.skyenet.webui.util.MarkdownUtil.renderMarkdown import org.slf4j.LoggerFactory import java.io.File +import java.nio.file.Path import java.util.concurrent.Semaphore import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicReference open class WebDevApp( - applicationName: String = "Web Dev Assistant v1.1", - open val symbols: Map = mapOf(), - val temperature: Double = 0.1, + applicationName: String = "Web Dev Assistant v1.1", + open val symbols: Map = mapOf(), + val temperature: Double = 0.1, ) : ApplicationServer( - applicationName = applicationName, - path = "/webdev", + applicationName = applicationName, + path = "/webdev", ) { - override fun userMessage( - session: Session, - user: User?, - userMessage: String, - ui: ApplicationInterface, - api: API - ) { - val settings = getSettings(session, user) ?: Settings() - (api as ClientManager.MonitoredClient).budget = settings.budget ?: 2.00 - WebDevAgent( - api = api, - dataStorage = dataStorage, - session = session, - user = user, - ui = ui, - tools = settings.tools, - model = settings.model, - ).start( - userMessage = userMessage, - ) - } + override fun userMessage( + session: Session, + user: User?, + userMessage: String, + ui: ApplicationInterface, + api: API + ) { + val settings = getSettings(session, user) ?: Settings() + (api as ClientManager.MonitoredClient).budget = settings.budget ?: 2.00 + WebDevAgent( + api = api, + dataStorage = dataStorage, + session = session, + user = user, + ui = ui, + tools = settings.tools, + model = settings.model, + ).start( + userMessage = userMessage, + ) + } - data class Settings( - val budget: Double? = 2.00, - val tools: List = emptyList(), - val model: ChatModels = ChatModels.GPT4Turbo, - ) + data class Settings( + val budget: Double? = 2.00, + val tools: List = emptyList(), + val model: ChatModels = ChatModels.GPT4Turbo, + ) - override val settingsClass: Class<*> get() = Settings::class.java + override val settingsClass: Class<*> get() = Settings::class.java - @Suppress("UNCHECKED_CAST") - override fun initSettings(session: Session): T? = Settings() as T + @Suppress("UNCHECKED_CAST") + override fun initSettings(session: Session): T? = Settings() as T } class WebDevAgent( - val api: API, - dataStorage: StorageInterface, - session: Session, - user: User?, - val ui: ApplicationInterface, - val model: ChatModels, - val tools: List = emptyList(), - val actorMap: Map> = mapOf( - ActorTypes.HtmlCodingActor to SimpleActor( - prompt = """ + val api: API, + dataStorage: StorageInterface, + session: Session, + user: User?, + val ui: ApplicationInterface, + val model: ChatModels, + val tools: List = emptyList(), + val actorMap: Map> = mapOf( + ActorTypes.HtmlCodingActor to SimpleActor( + prompt = """ You will translate the user request into a skeleton HTML file for a rich javascript application. The html file can reference needed CSS and JS files, which are will be located in the same directory as the html file. Do not output the content of the resource files, only the html file. """.trimIndent(), model = model - ), - ActorTypes.JavascriptCodingActor to SimpleActor( - prompt = """ + ), + ActorTypes.JavascriptCodingActor to SimpleActor( + prompt = """ You will translate the user request into a javascript file for use in a rich javascript application. """.trimIndent(), model = model - ), - ActorTypes.CssCodingActor to SimpleActor( - prompt = """ + ), + ActorTypes.CssCodingActor to SimpleActor( + prompt = """ You will translate the user request into a CSS file for use in a rich javascript application. """.trimIndent(), model = model - ), - ActorTypes.ArchitectureDiscussionActor to ParsedActor( + ), + ActorTypes.ArchitectureDiscussionActor to ParsedActor( // parserClass = PageResourceListParser::class.java, - resultClass = PageResourceList::class.java, - prompt = """ + resultClass = PageResourceList::class.java, + prompt = """ Translate the user's idea into a detailed architecture for a simple web application. Suggest specific frameworks/libraries to import and provide CDN links for them. Specify user interactions and how the application will respond to them. @@ -107,11 +108,11 @@ class WebDevAgent( Identify coding styles and patterns to be used. List all files to be created, and for each file, describe the public interface / purpose / content summary. """.trimIndent(), - model = model, - parsingModel = model, - ), - ActorTypes.CodeReviewer to SimpleActor( - prompt = """ + model = model, + parsingModel = model, + ), + ActorTypes.CodeReviewer to SimpleActor( + prompt = """ Analyze the code summarized in the user's header-labeled code blocks. Review, look for bugs, and provide fixes. Provide implementations for missing functions. @@ -133,323 +134,321 @@ class WebDevAgent( Continued text """.trimIndent(), - model = model, + model = model, + ), ), - ), ) : ActorSystem(actorMap.map { it.key.name to it.value }.toMap(), dataStorage, user, session) { - enum class ActorTypes { - HtmlCodingActor, - JavascriptCodingActor, - CssCodingActor, - ArchitectureDiscussionActor, - CodeReviewer, - } - - private val architectureDiscussionActor by lazy { getActor(ActorTypes.ArchitectureDiscussionActor) as ParsedActor } - private val htmlActor by lazy { getActor(ActorTypes.HtmlCodingActor) as SimpleActor } - private val javascriptActor by lazy { getActor(ActorTypes.JavascriptCodingActor) as SimpleActor } - private val cssActor by lazy { getActor(ActorTypes.CssCodingActor) as SimpleActor } - private val codeReviewer by lazy { getActor(ActorTypes.CodeReviewer) as SimpleActor } - private val codeFiles = mutableMapOf() + enum class ActorTypes { + HtmlCodingActor, + JavascriptCodingActor, + CssCodingActor, + ArchitectureDiscussionActor, + CodeReviewer, + } - fun start( - userMessage: String, - ) { - val task = ui.newTask() - val toInput = { it: String -> listOf(it) } - val architectureResponse = Acceptable( - task = task, - userMessage = userMessage, - initialResponse = { it: String -> architectureDiscussionActor.answer(toInput(it), api = api) }, - outputFn = { design: ParsedResponse -> - // renderMarkdown("${design.text}\n\n```json\n${JsonUtil.toJson(design.obj)/*.indent(" ")*/}\n```") - AgentPatterns.displayMapInTabs( - mapOf( - "Text" to renderMarkdown(design.text, ui=ui), - "JSON" to renderMarkdown("```json\n${JsonUtil.toJson(design.obj)/*.indent(" ")*/}\n```", ui=ui), - ) - ) - }, - ui = ui, - reviseResponse = { userMessages: List> -> - architectureDiscussionActor.respond( - messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } - .toTypedArray()), - input = toInput(userMessage), - api = api - ) - }, - atomicRef = AtomicReference(), - semaphore = Semaphore(0), - heading = userMessage - ).call() + private val architectureDiscussionActor by lazy { getActor(ActorTypes.ArchitectureDiscussionActor) as ParsedActor } + private val htmlActor by lazy { getActor(ActorTypes.HtmlCodingActor) as SimpleActor } + private val javascriptActor by lazy { getActor(ActorTypes.JavascriptCodingActor) as SimpleActor } + private val cssActor by lazy { getActor(ActorTypes.CssCodingActor) as SimpleActor } + private val codeReviewer by lazy { getActor(ActorTypes.CodeReviewer) as SimpleActor } + private val codeFiles = mutableMapOf() - try { - val toolSpecs = tools.map { ToolServlet.tools.find { t -> t.path == it } } - .joinToString("\n\n") { it?.let { JsonUtil.toJson(it.openApiDescription) } ?: "" } - var messageWithTools = userMessage - if (toolSpecs.isNotBlank()) messageWithTools += "\n\nThese services are available:\n$toolSpecs" - task.echo(renderMarkdown("```json\n${JsonUtil.toJson(architectureResponse.obj)/*.indent(" ")*/}\n```", ui=ui)) - architectureResponse.obj.resources.filter { - !it.path!!.startsWith("http") - }.forEach { (path, description) -> + fun start( + userMessage: String, + ) { val task = ui.newTask() - when (path!!.split(".").last().lowercase()) { + val toInput = { it: String -> listOf(it) } + val architectureResponse = Acceptable( + task = task, + userMessage = userMessage, + initialResponse = { it: String -> architectureDiscussionActor.answer(toInput(it), api = api) }, + outputFn = { design: ParsedResponse -> + // renderMarkdown("${design.text}\n\n```json\n${JsonUtil.toJson(design.obj)/*.indent(" ")*/}\n```") + AgentPatterns.displayMapInTabs( + mapOf( + "Text" to renderMarkdown(design.text, ui = ui), + "JSON" to renderMarkdown( + "```json\n${JsonUtil.toJson(design.obj)/*.indent(" ")*/}\n```", + ui = ui + ), + ) + ) + }, + ui = ui, + reviseResponse = { userMessages: List> -> + architectureDiscussionActor.respond( + messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } + .toTypedArray()), + input = toInput(userMessage), + api = api + ) + }, + atomicRef = AtomicReference(), + semaphore = Semaphore(0), + heading = userMessage + ).call() - "js" -> draftResourceCode( - task, - javascriptActor.chatMessages( - listOf( - messageWithTools, - architectureResponse.text, - "Render $path - $description" - ) - ), - javascriptActor, - path!!, "js", "javascript" - ) + try { + val toolSpecs = tools.map { ToolServlet.tools.find { t -> t.path == it } } + .joinToString("\n\n") { it?.let { JsonUtil.toJson(it.openApiDescription) } ?: "" } + var messageWithTools = userMessage + if (toolSpecs.isNotBlank()) messageWithTools += "\n\nThese services are available:\n$toolSpecs" + task.echo( + renderMarkdown( + "```json\n${JsonUtil.toJson(architectureResponse.obj)/*.indent(" ")*/}\n```", + ui = ui + ) + ) + architectureResponse.obj.resources.filter { + !it.path!!.startsWith("http") + }.forEach { (path, description) -> + val task = ui.newTask() + when (path!!.split(".").last().lowercase()) { - "css" -> draftResourceCode( - task, - cssActor.chatMessages( - listOf( - messageWithTools, - architectureResponse.text, - "Render $path - $description" - ) - ), - cssActor, - path - ) + "js" -> draftResourceCode( + task, + javascriptActor.chatMessages( + listOf( + messageWithTools, + architectureResponse.text, + "Render $path - $description" + ) + ), + javascriptActor, + File(path).toPath(), "js", "javascript" + ) - "html" -> draftResourceCode( - task, - htmlActor.chatMessages( - listOf( - messageWithTools, - architectureResponse.text, - "Render $path - $description" - ) - ), - htmlActor, - path - ) + "css" -> draftResourceCode( + task, + cssActor.chatMessages( + listOf( + messageWithTools, + architectureResponse.text, + "Render $path - $description" + ) + ), + cssActor, + File(path).toPath() + ) - else -> task.add("Resource Type Not Supported: $path - $description") - } - } - // Apply codeReviewer - fun codeSummary() = codeFiles.entries.joinToString("\n\n") { (path, code) -> - "# $path\n```${ - /*escapeHtml4*/(path.split('.').last())/*.indent(" ")*/ - }\n${/*escapeHtml4*/(code)/*.indent(" ")*/}\n```" - } + "html" -> draftResourceCode( + task, + htmlActor.chatMessages( + listOf( + messageWithTools, + architectureResponse.text, + "Render $path - $description" + ) + ), + htmlActor, + File(path).toPath() + ) - fun outputFn(task: SessionTask, design: String): StringBuilder? { - //val task = ui.newTask() - return task.complete( - ui.socketManager.addApplyFileDiffLinks( - root = codeFiles.keys.map { File(it).toPath() }.toTypedArray().commonRoot(), - code = codeFiles, - response = design, - handle = { newCodeMap -> - newCodeMap.forEach { (path, newCode) -> - val prev = codeFiles[path] - if (prev != newCode) { - codeFiles[path] = newCode - task.complete( - "$path Updated" + else -> task.add("Resource Type Not Supported: $path - $description") + } + } + // Apply codeReviewer + fun codeSummary() = codeFiles.entries.joinToString("\n\n") { (path, code) -> + "# $path\n```${ + /*escapeHtml4*/(path.toString().split('.').last())/*.indent(" ")*/ + }\n${/*escapeHtml4*/(code)/*.indent(" ")*/}\n```" + } + + fun outputFn(task: SessionTask, design: String): StringBuilder? { + //val task = ui.newTask() + return task.complete( + ui.socketManager.addApplyFileDiffLinks( + root = codeFiles.keys.map { it }.toTypedArray().commonRoot(), + code = { codeFiles }, + response = design, + handle = { newCodeMap -> + newCodeMap.forEach { (path, newCode) -> + task.complete("$path Updated") + } + }, + ui = ui ) - } + ) + } + try { + var task = ui.newTask() + task.add(message = renderMarkdown(codeSummary(), ui = ui)) + var design = codeReviewer.answer(listOf(element = codeSummary()), api = api) + outputFn(task, design) + var textInputHandle: StringBuilder? = null + var textInput: String? = null + val feedbackGuard = AtomicBoolean(false) + textInput = ui.textInput { userResponse -> + if (feedbackGuard.getAndSet(true)) return@textInput + textInputHandle?.clear() + task.complete() + task = ui.newTask() + task.echo(renderMarkdown(userResponse, ui = ui)) + val codeSummary = codeSummary() + task.add(renderMarkdown(codeSummary, ui = ui)) + design = codeReviewer.respond( + messages = codeReviewer.chatMessages( + listOf( + codeSummary, + userResponse, + ) + ), + input = listOf(element = codeSummary), + api = api + ) + outputFn(task, design) + textInputHandle = task.complete(textInput!!) + feedbackGuard.set(false) } - }, - ui = ui - ) - ) - } - try { - var task = ui.newTask() - task.add(message = renderMarkdown(codeSummary(), ui=ui)) - var design = codeReviewer.answer(listOf(element = codeSummary()), api = api) - outputFn(task, design) - var textInputHandle: StringBuilder? = null - var textInput: String? = null - val feedbackGuard = AtomicBoolean(false) - textInput = ui.textInput { userResponse -> - if (feedbackGuard.getAndSet(true)) return@textInput - textInputHandle?.clear() - task.complete() - task = ui.newTask() - task.echo(renderMarkdown(userResponse, ui=ui)) - val codeSummary = codeSummary() - task.add(renderMarkdown(codeSummary, ui=ui)) - design = codeReviewer.respond( - messages = codeReviewer.chatMessages( - listOf( - codeSummary, - userResponse, - ) - ), - input = listOf(element = codeSummary), - api = api - ) - outputFn(task, design) - textInputHandle = task.complete(textInput!!) - feedbackGuard.set(false) - } - textInputHandle = task.complete(textInput) - } catch (e: Throwable) { - val task = ui.newTask() - task.error(ui = ui, e = e) - throw e - } + textInputHandle = task.complete(textInput) + } catch (e: Throwable) { + val task = ui.newTask() + task.error(ui = ui, e = e) + throw e + } - } catch (e: Throwable) { - log.warn("Error", e) - task.error(ui, e) + } catch (e: Throwable) { + log.warn("Error", e) + task.error(ui, e) + } } - } - private fun draftResourceCode( - task: SessionTask, - request: Array, - actor: SimpleActor, - path: String, - vararg languages: String = arrayOf(path.split(".").last().lowercase()), - ) { - try { - var code = actor.respond(emptyList(), api, *request) - languages.forEach { language -> - if (code.contains("```$language")) code = code.substringAfter("```$language").substringBefore("```") - } - try { - task.add(renderMarkdown("```${languages.first()}\n$code\n```", ui=ui)) - task.add("$path Updated") - codeFiles[path] = code - val request1 = (request.toList() + - listOf( - ApiModel.ChatMessage(ApiModel.Role.assistant, code.toContentList()), - )).toTypedArray() - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ + private fun draftResourceCode( + task: SessionTask, + request: Array, + actor: SimpleActor, + path: Path, + vararg languages: String = arrayOf(path.toString().split(".").last().lowercase()), + ) { + try { + var code = actor.respond(emptyList(), api, *request) + languages.forEach { language -> + if (code.contains("```$language")) code = code.substringAfter("```$language").substringBefore("```") + } + try { + task.add(renderMarkdown("```${languages.first()}\n$code\n```", ui = ui)) + task.add("$path Updated") + codeFiles[path] = code + val request1 = (request.toList() + + listOf( + ApiModel.ChatMessage(Role.assistant, code.toContentList()), + )).toTypedArray() + val formText = StringBuilder() + var formHandle: StringBuilder? = null + formHandle = task.add( + """ |
|${ - ui.hrefLink("♻", "href-link regen-button") { - val task = ui.newTask() - responseAction(task, "Regenerating...", formHandle!!, formText) { - draftResourceCode( - task, - request1.dropLastWhile { it.role == ApiModel.Role.assistant }.toTypedArray(), - actor, path, *languages - ) - } - } - } + ui.hrefLink("♻", "href-link regen-button") { + val task = ui.newTask() + responseAction(task, "Regenerating...", formHandle!!, formText) { + draftResourceCode( + task, + request1.dropLastWhile { it.role == Role.assistant } + .toTypedArray(), + actor, path, *languages + ) + } + } + } |
|${ - ui.textInput { feedback -> - responseAction(task, "Revising...", formHandle!!, formText) { - //val task = ui.newTask() - try { - task.echo(renderMarkdown(feedback, ui=ui)) - draftResourceCode( - task, (request1.toList() + listOf( - code to ApiModel.Role.assistant, - feedback to ApiModel.Role.user, - ).filter { it.first.isNotBlank() } - .map { - ApiModel.ChatMessage( - it.second, - it.first.toContentList() - ) - }).toTypedArray(), actor, path, *languages - ) - } catch (e: Throwable) { - log.warn("Error", e) - task.error(ui, e) - } - } - } - } + ui.textInput { feedback -> + responseAction(task, "Revising...", formHandle!!, formText) { + //val task = ui.newTask() + try { + task.echo(renderMarkdown(feedback, ui = ui)) + draftResourceCode( + task, (request1.toList() + listOf( + code to Role.assistant, + feedback to Role.user, + ).filter { it.first.isNotBlank() } + .map { + ApiModel.ChatMessage( + it.second, + it.first.toContentList() + ) + }).toTypedArray(), actor, path, *languages + ) + } catch (e: Throwable) { + log.warn("Error", e) + task.error(ui, e) + } + } + } + } """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } catch (e: Throwable) { - task.error(ui, e) - log.warn("Error", e) - } - } catch (e: Throwable) { - log.warn("Error", e) - val error = task.error(ui, e) - var regenButton: StringBuilder? = null - regenButton = task.complete(ui.hrefLink("♻", "href-link regen-button") { - regenButton?.clear() - val header = task.header("Regenerating...") - draftResourceCode(task, request, actor, path, *languages) - header?.clear() - error?.clear() - task.complete() - }) + ) + formText.append(formHandle.toString()) + formHandle.toString() + task.complete() + } catch (e: Throwable) { + task.error(ui, e) + log.warn("Error", e) + } + } catch (e: Throwable) { + log.warn("Error", e) + val error = task.error(ui, e) + var regenButton: StringBuilder? = null + regenButton = task.complete(ui.hrefLink("♻", "href-link regen-button") { + regenButton?.clear() + val header = task.header("Regenerating...") + draftResourceCode(task, request, actor, path, *languages) + header?.clear() + error?.clear() + task.complete() + }) + } } - } - private fun responseAction( - task: SessionTask, - message: String, - formHandle: StringBuilder?, - formText: StringBuilder, - fn: () -> Unit = {} - ) { - formHandle?.clear() - val header = task.header(message) - try { - fn() - } finally { - header?.clear() - var revertButton: StringBuilder? = null - revertButton = task.complete(ui.hrefLink("↩", "href-link regen-button") { - revertButton?.clear() - formHandle?.append(formText) - task.complete() - }) + private fun responseAction( + task: SessionTask, + message: String, + formHandle: StringBuilder?, + formText: StringBuilder, + fn: () -> Unit = {} + ) { + formHandle?.clear() + val header = task.header(message) + try { + fn() + } finally { + header?.clear() + var revertButton: StringBuilder? = null + revertButton = task.complete(ui.hrefLink("↩", "href-link regen-button") { + revertButton?.clear() + formHandle?.append(formText) + task.complete() + }) + } } - } - companion object { - private val log = LoggerFactory.getLogger(WebDevAgent::class.java) + companion object { + private val log = LoggerFactory.getLogger(WebDevAgent::class.java) - data class PageResourceList( - @Description("List of resources in this project; don't forget the index.html file!") - val resources: List = emptyList() - ) : ValidatedObject { - override fun validate(): String? = when { - resources.isEmpty() -> "Resources are required" - resources.any { it.validate() != null } -> "Invalid resource" - else -> null - } - } + data class PageResourceList( + @Description("List of resources in this project; don't forget the index.html file!") + val resources: List = emptyList() + ) : ValidatedObject { + override fun validate(): String? = when { + resources.isEmpty() -> "Resources are required" + resources.any { it.validate() != null } -> "Invalid resource" + else -> null + } + } - data class PageResource( - val path: String? = "", - val description: String? = "" - ) : ValidatedObject { - override fun validate(): String? = when { - path.isNullOrBlank() -> "Path is required" - //path.contains(" ") -> "Path cannot contain spaces" - //!path.contains(".") -> "Path must contain a file extension" - else -> null - } - } + data class PageResource( + val path: String? = "", + val description: String? = "" + ) : ValidatedObject { + override fun validate(): String? = when { + path.isNullOrBlank() -> "Path is required" + //path.contains(" ") -> "Path cannot contain spaces" + //!path.contains(".") -> "Path must contain a file extension" + else -> null + } + } - } + } } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/interpreter/ProcessInterpreter.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/interpreter/ProcessInterpreter.kt index 103c941d..2662ba1c 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/interpreter/ProcessInterpreter.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/interpreter/ProcessInterpreter.kt @@ -4,48 +4,48 @@ import org.slf4j.LoggerFactory import java.util.concurrent.TimeUnit open class ProcessInterpreter( - private val defs: Map = mapOf(), + private val defs: Map = mapOf(), ) : Interpreter { - val command: List - get() = defs["command"]?.let { command -> - when (command) { - is String -> command.split(" ") - is List<*> -> command.map { it.toString() } - else -> throw IllegalArgumentException("Invalid command: $command") - } - } ?: listOf("bash") + val command: List + get() = defs["command"]?.let { command -> + when (command) { + is String -> command.split(" ") + is List<*> -> command.map { it.toString() } + else -> throw IllegalArgumentException("Invalid command: $command") + } + } ?: listOf("bash") - final override fun getLanguage(): String = defs["language"]?.toString() ?: "bash" - override fun getSymbols() = defs + final override fun getLanguage(): String = defs["language"]?.toString() ?: "bash" + override fun getSymbols() = defs - override fun validate(code: String): Throwable? { - // Always valid - return null - } + override fun validate(code: String): Throwable? { + // Always valid + return null + } - override fun run(code: String): Any? { - val wrappedCode = wrapCode(code.trim()) - val cmd = command.toTypedArray() - val cwd = defs["workingDir"]?.toString()?.let { java.io.File(it) } ?: java.io.File(".") - val processBuilder = ProcessBuilder(*cmd).directory(cwd) - defs["env"]?.let { env -> processBuilder.environment().putAll((env as Map)) } - val process = processBuilder.start() + override fun run(code: String): Any? { + val wrappedCode = wrapCode(code.trim()) + val cmd = command.toTypedArray() + val cwd = defs["workingDir"]?.toString()?.let { java.io.File(it) } ?: java.io.File(".") + val processBuilder = ProcessBuilder(*cmd).directory(cwd) + defs["env"]?.let { env -> processBuilder.environment().putAll((env as Map)) } + val process = processBuilder.start() - process.outputStream.write(wrappedCode.toByteArray()) - process.outputStream.close() - val output = process.inputStream.bufferedReader().readText() - val error = process.errorStream.bufferedReader().readText() + process.outputStream.write(wrappedCode.toByteArray()) + process.outputStream.close() + val output = process.inputStream.bufferedReader().readText() + val error = process.errorStream.bufferedReader().readText() - val waitFor = process.waitFor(5, TimeUnit.MINUTES) - if (!waitFor) { - process.destroy() - throw RuntimeException("Timeout; output: $output; error: $error") - } else if (error.isNotEmpty()) { - //throw RuntimeException(error) - return ( - """ + val waitFor = process.waitFor(5, TimeUnit.MINUTES) + if (!waitFor) { + process.destroy() + throw RuntimeException("Timeout; output: $output; error: $error") + } else if (error.isNotEmpty()) { + //throw RuntimeException(error) + return ( + """ |ERROR: |```text |$error @@ -56,12 +56,11 @@ open class ProcessInterpreter( |$output |``` """.trimMargin() - ) - } else { - return output + ) + } else { + return output + } } - } - companion object { - } + companion object } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt index ad86d7c6..e844a114 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt @@ -29,198 +29,199 @@ import kotlin.system.exitProcess abstract class ApplicationDirectory( - val localName: String = "localhost", - val publicName: String = "localhost", - val port: Int = 8081, + val localName: String = "localhost", + val publicName: String = "localhost", + val port: Int = 8081, ) { - var domainName: String = "" // Resolved in _main - private set - abstract val childWebApps: List - - data class ChildWebApp( - val path: String, - val server: ChatServer, - ) - - private fun domainName(isServer: Boolean) = - if (isServer) "https://$publicName" else "http://$localName:$port" - - open val welcomeResources = ResourceCollection(allResources("welcome").map(::newResource)) - open val userInfoServlet = UserInfoServlet() - open val userSettingsServlet = UserSettingsServlet() - open val logoutServlet = LogoutServlet() - open val usageServlet = UsageServlet() - open val proxyHttpServlet = ProxyHttpServlet() - open val apiKeyServlet = ApiKeyServlet() - open val welcomeServlet = WelcomeServlet(this) - abstract val toolServlet : ToolServlet? - - open fun authenticatedWebsite(): OAuthBase? = OAuthGoogle( - redirectUri = "$domainName/oauth2callback", - applicationName = "Demo", - key = { - val encryptedData = - javaClass.classLoader!!.getResourceAsStream("client_secret_google_oauth.json.kms")?.readAllBytes() - ?: throw RuntimeException("Unable to load resource: ${"client_secret_google_oauth.json.kms"}") - ApplicationServices.cloud!!.decrypt(encryptedData).byteInputStream() - } - ) + var domainName: String = "" // Resolved in _main + private set + abstract val childWebApps: List + + data class ChildWebApp( + val path: String, + val server: ChatServer, + ) + + private fun domainName(isServer: Boolean) = + if (isServer) "https://$publicName" else "http://$localName:$port" + + open val welcomeResources = ResourceCollection(allResources("welcome").map(::newResource)) + open val userInfoServlet = UserInfoServlet() + open val userSettingsServlet = UserSettingsServlet() + open val logoutServlet = LogoutServlet() + open val usageServlet = UsageServlet() + open val proxyHttpServlet = ProxyHttpServlet() + open val apiKeyServlet = ApiKeyServlet() + open val welcomeServlet = WelcomeServlet(this) + abstract val toolServlet: ToolServlet? + + open fun authenticatedWebsite(): OAuthBase? = OAuthGoogle( + redirectUri = "$domainName/oauth2callback", + applicationName = "Demo", + key = { + val encryptedData = + javaClass.classLoader!!.getResourceAsStream("client_secret_google_oauth.json.kms")?.readAllBytes() + ?: throw RuntimeException("Unable to load resource: ${"client_secret_google_oauth.json.kms"}") + ApplicationServices.cloud!!.decrypt(encryptedData).byteInputStream() + } + ) - open fun setupPlatform() { - ApplicationServices.seleniumFactory = { pool, cookies -> - Selenium2S3(pool, cookies,) + open fun setupPlatform() { + ApplicationServices.seleniumFactory = { pool, cookies -> + Selenium2S3(pool, cookies) + } } - } - - protected open fun _main(args: Array) { - try { - log.info("Starting application with args: ${args.joinToString(", ")}") - setupPlatform() - init(args.contains("--server")) - ClientUtil.keyTxt = run { + + protected open fun _main(args: Array) { try { - val encryptedData = javaClass.classLoader.getResourceAsStream("openai.key.json.kms")?.readAllBytes() - ?: throw RuntimeException("Unable to load resource: ${"openai.key.json.kms"}") - val decrypt = ApplicationServices.cloud!!.decrypt(encryptedData) - JsonUtil.fromJson(decrypt, Map::class.java) + log.info("Starting application with args: ${args.joinToString(", ")}") + setupPlatform() + init(args.contains("--server")) + ClientUtil.keyTxt = run { + try { + val encryptedData = javaClass.classLoader.getResourceAsStream("openai.key.json.kms")?.readAllBytes() + ?: throw RuntimeException("Unable to load resource: ${"openai.key.json.kms"}") + val decrypt = ApplicationServices.cloud!!.decrypt(encryptedData) + JsonUtil.fromJson(decrypt, Map::class.java) + } catch (e: Throwable) { + log.warn("Error loading key.txt", e) + "" + } + } + ApplicationServices.isLocked = true + val server = start( + port, + *(listOfNotNull( + newWebAppContext("/logout", logoutServlet), + newWebAppContext("/proxy", proxyHttpServlet), + toolServlet?.let { newWebAppContext("/tools", it) }, + newWebAppContext("/userInfo", userInfoServlet).let { + authenticatedWebsite()?.configure(it, true) ?: it + }, + newWebAppContext("/userSettings", userSettingsServlet).let { + authenticatedWebsite()?.configure(it, true) ?: it + }, + newWebAppContext("/usage", usageServlet).let { + authenticatedWebsite()?.configure(it, true) ?: it + }, + newWebAppContext("/apiKeys", apiKeyServlet).let { + authenticatedWebsite()?.configure(it, true) ?: it + }, + newWebAppContext("/", welcomeResources, "welcome", welcomeServlet).let { + authenticatedWebsite()?.configure(it, false) ?: it + }, + ).toTypedArray() + childWebApps.map { + newWebAppContext(it.path, it.server) + }) + ) + log.info("Server started successfully on port $port") + try { + Desktop.getDesktop().browse(URI("$domainName/")) + } catch (e: Throwable) { + // Ignore + } + server.join() } catch (e: Throwable) { - log.warn("Error loading key.txt", e) - "" + e.printStackTrace() + log.error("Application encountered an error: ${e.message}", e) + Thread.sleep(1000) + exitProcess(1) + } finally { + Thread.sleep(1000) + exitProcess(0) } - } - ApplicationServices.isLocked = true - val server = start( - port, - *(listOfNotNull( - newWebAppContext("/logout", logoutServlet), - newWebAppContext("/proxy", proxyHttpServlet), - toolServlet?.let { newWebAppContext("/tools", it) }, - newWebAppContext("/userInfo", userInfoServlet).let { - authenticatedWebsite()?.configure(it, true) ?: it - }, - newWebAppContext("/userSettings", userSettingsServlet).let { - authenticatedWebsite()?.configure(it, true) ?: it - }, - newWebAppContext("/usage", usageServlet).let { - authenticatedWebsite()?.configure(it, true) ?: it - }, - newWebAppContext("/apiKeys", apiKeyServlet).let { - authenticatedWebsite()?.configure(it, true) ?: it - }, - newWebAppContext("/", welcomeResources, "welcome", welcomeServlet).let { - authenticatedWebsite()?.configure(it, false) ?: it - }, - ).toTypedArray() + childWebApps.map { - newWebAppContext(it.path, it.server) - }) - ) - log.info("Server started successfully on port $port") - try { - Desktop.getDesktop().browse(URI("$domainName/")) - } catch (e: Throwable) { - // Ignore - } - server.join() - } catch (e: Throwable) { - e.printStackTrace() - log.error("Application encountered an error: ${e.message}", e) - Thread.sleep(1000) - exitProcess(1) - } finally { - Thread.sleep(1000) - exitProcess(0) } - } - - open fun init(isServer: Boolean): ApplicationDirectory { - OutputInterceptor.setupInterceptor() - log.info("Initializing application, isServer: $isServer") - domainName = domainName(isServer) - return this - } - - protected open fun start( - port: Int, - vararg webAppContexts: WebAppContext - ): Server { - val contexts = ContextHandlerCollection() + + open fun init(isServer: Boolean): ApplicationDirectory { + OutputInterceptor.setupInterceptor() + log.info("Initializing application, isServer: $isServer") + domainName = domainName(isServer) + return this + } + + protected open fun start( + port: Int, + vararg webAppContexts: WebAppContext + ): Server { + val contexts = ContextHandlerCollection() // val stats = StatisticsHandler() - log.info("Starting server on port: $port") - contexts.handlers = ( - listOf( - newWebAppContext("/stats", StatisticsServlet())) + - webAppContexts.map { - it.addFilter(FilterHolder(CorsFilter()), "/*", EnumSet.of(DispatcherType.REQUEST)) - it - } - ).toTypedArray() - val server = Server(port) - // Increase the number of acceptors and selectors for better scalability in a non-blocking model - val serverConnector = ServerConnector(server, 4, 8, httpConnectionFactory()) - serverConnector.port = port - serverConnector.acceptQueueSize = 1000 - serverConnector.idleTimeout = 30000 // Set idle timeout to 30 seconds - server.connectors = arrayOf(serverConnector) - server.handler = contexts - server.start() - if (!server.isStarted) throw IllegalStateException("Server failed to start") - log.info("Server initialization completed successfully.") - return server - } - - protected open fun httpConnectionFactory(): HttpConnectionFactory { - val httpConfig = HttpConfiguration() - httpConfig.addCustomizer(ForwardedRequestCustomizer()) - log.debug("HTTP connection factory created with custom configuration.") - return HttpConnectionFactory(httpConfig) - } - - protected open fun newWebAppContext(path: String, server: ChatServer): WebAppContext { - val baseResource = server.baseResource ?: throw IllegalStateException("No base resource") - val webAppContext = newWebAppContext(path, baseResource, resourceBase = "applicaton") - server.configure(webAppContext) - log.info("WebAppContext configured for path: $path with ChatServer") - return webAppContext - } - - protected open fun newWebAppContext( - path: String, - baseResource: Resource, - resourceBase: String, - indexServlet: Servlet? = null - ): WebAppContext { - val context = WebAppContext() - JettyWebSocketServletContainerInitializer.configure(context, null) - context.classLoader = WebAppClassLoader(ApplicationServices::class.java.classLoader, context) - context.isParentLoaderPriority = true - context.baseResource = baseResource - log.debug("New WebAppContext created for path: $path") - context.contextPath = path - context.welcomeFiles = arrayOf("index.html") - if (indexServlet != null) { - context.addServlet(ServletHolder("$path/index", indexServlet), "/index.html") + log.info("Starting server on port: $port") + contexts.handlers = ( + listOf( + newWebAppContext("/stats", StatisticsServlet()) + ) + + webAppContexts.map { + it.addFilter(FilterHolder(CorsFilter()), "/*", EnumSet.of(DispatcherType.REQUEST)) + it + } + ).toTypedArray() + val server = Server(port) + // Increase the number of acceptors and selectors for better scalability in a non-blocking model + val serverConnector = ServerConnector(server, 4, 8, httpConnectionFactory()) + serverConnector.port = port + serverConnector.acceptQueueSize = 1000 + serverConnector.idleTimeout = 30000 // Set idle timeout to 30 seconds + server.connectors = arrayOf(serverConnector) + server.handler = contexts + server.start() + if (!server.isStarted) throw IllegalStateException("Server failed to start") + log.info("Server initialization completed successfully.") + return server + } + + protected open fun httpConnectionFactory(): HttpConnectionFactory { + val httpConfig = HttpConfiguration() + httpConfig.addCustomizer(ForwardedRequestCustomizer()) + log.debug("HTTP connection factory created with custom configuration.") + return HttpConnectionFactory(httpConfig) + } + + protected open fun newWebAppContext(path: String, server: ChatServer): WebAppContext { + val baseResource = server.baseResource ?: throw IllegalStateException("No base resource") + val webAppContext = newWebAppContext(path, baseResource, resourceBase = "applicaton") + server.configure(webAppContext) + log.info("WebAppContext configured for path: $path with ChatServer") + return webAppContext + } + + protected open fun newWebAppContext( + path: String, + baseResource: Resource, + resourceBase: String, + indexServlet: Servlet? = null + ): WebAppContext { + val context = WebAppContext() + JettyWebSocketServletContainerInitializer.configure(context, null) + context.classLoader = WebAppClassLoader(ApplicationServices::class.java.classLoader, context) + context.isParentLoaderPriority = true + context.baseResource = baseResource + log.debug("New WebAppContext created for path: $path") + context.contextPath = path + context.welcomeFiles = arrayOf("index.html") + if (indexServlet != null) { + context.addServlet(ServletHolder("$path/index", indexServlet), "/index.html") + } + return context + } + + protected open fun newWebAppContext(path: String, servlet: Servlet): WebAppContext { + val context = WebAppContext() + JettyWebSocketServletContainerInitializer.configure(context, null) + context.classLoader = WebAppClassLoader(ApplicationServices::class.java.classLoader, context) + context.isParentLoaderPriority = true + context.contextPath = path + log.debug("New WebAppContext created for servlet at path: $path") + context.resourceBase = "application" + context.welcomeFiles = arrayOf("index.html") + context.addServlet(ServletHolder(servlet), "/") + return context + } + + + companion object { + private val log = LoggerFactory.getLogger(ApplicationDirectory::class.java) + fun allResources(resourceName: String) = + Thread.currentThread().contextClassLoader.getResources(resourceName).toList() } - return context - } - - protected open fun newWebAppContext(path: String, servlet: Servlet): WebAppContext { - val context = WebAppContext() - JettyWebSocketServletContainerInitializer.configure(context, null) - context.classLoader = WebAppClassLoader(ApplicationServices::class.java.classLoader, context) - context.isParentLoaderPriority = true - context.contextPath = path - log.debug("New WebAppContext created for servlet at path: $path") - context.resourceBase = "application" - context.welcomeFiles = arrayOf("index.html") - context.addServlet(ServletHolder(servlet), "/") - return context - } - - - companion object { - private val log = LoggerFactory.getLogger(ApplicationDirectory::class.java) - fun allResources(resourceName: String) = - Thread.currentThread().contextClassLoader.getResources(resourceName).toList() - } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationInterface.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationInterface.kt index 44431ec1..58589c93 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationInterface.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationInterface.kt @@ -7,39 +7,39 @@ import java.util.concurrent.atomic.AtomicBoolean import java.util.function.Consumer open class ApplicationInterface(val socketManager: SocketManagerBase) { - @Description("Returns html for a link that will trigger the given handler when clicked.") - open fun hrefLink( - @Description("The text to display in the link") - linkText: String, - @Description("The css class to apply to the link") - classname: String = """href-link""", - @Description("The id to apply to the link") - id: String? = null, - @Description("The handler to trigger when the link is clicked") - handler: Consumer, - ) = socketManager.hrefLink(linkText, classname, id, oneAtATime(handler)) + @Description("Returns html for a link that will trigger the given handler when clicked.") + open fun hrefLink( + @Description("The text to display in the link") + linkText: String, + @Description("The css class to apply to the link") + classname: String = """href-link""", + @Description("The id to apply to the link") + id: String? = null, + @Description("The handler to trigger when the link is clicked") + handler: Consumer, + ) = socketManager.hrefLink(linkText, classname, id, oneAtATime(handler)) - @Description("Returns html for a text input form that will trigger the given handler when submitted.") - open fun textInput( - @Description("The handler to trigger when the form is submitted") - handler: Consumer - ): String = socketManager.textInput(oneAtATime(handler)) + @Description("Returns html for a text input form that will trigger the given handler when submitted.") + open fun textInput( + @Description("The handler to trigger when the form is submitted") + handler: Consumer + ): String = socketManager.textInput(oneAtATime(handler)) - @Description("Creates a new 'task' that can be used to display the progress of a long-running operation.") - open fun newTask( - //cancelable: Boolean = false - root: Boolean = true - ): SessionTask = socketManager.newTask(cancelable = false, root = root) + @Description("Creates a new 'task' that can be used to display the progress of a long-running operation.") + open fun newTask( + //cancelable: Boolean = false + root: Boolean = true + ): SessionTask = socketManager.newTask(cancelable = false, root = root) - companion object { - fun oneAtATime(handler: Consumer): Consumer { - val guard = AtomicBoolean(false) - return Consumer { t -> - if (guard.getAndSet(true)) return@Consumer - handler.accept(t) - guard.set(false) - } + companion object { + fun oneAtATime(handler: Consumer): Consumer { + val guard = AtomicBoolean(false) + return Consumer { t -> + if (guard.getAndSet(true)) return@Consumer + handler.accept(t) + guard.set(false) + } + } } - } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationSocketManager.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationSocketManager.kt index 4fe7e71f..2d9b4dc9 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationSocketManager.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationSocketManager.kt @@ -6,49 +6,47 @@ import com.simiacryptus.skyenet.core.platform.Session import com.simiacryptus.skyenet.core.platform.StorageInterface import com.simiacryptus.skyenet.core.platform.User import com.simiacryptus.skyenet.webui.chat.ChatSocket -import com.simiacryptus.skyenet.webui.session.SessionTask import com.simiacryptus.skyenet.webui.session.SocketManagerBase abstract class ApplicationSocketManager( - session: Session, - owner: User?, - dataStorage: StorageInterface?, - applicationClass: Class<*>, + session: Session, + owner: User?, + dataStorage: StorageInterface?, + applicationClass: Class<*>, ) : SocketManagerBase( - session = session, - dataStorage = dataStorage, - owner = owner, - applicationClass = applicationClass, + session = session, + dataStorage = dataStorage, + owner = owner, + applicationClass = applicationClass, ) { - override fun onRun(userMessage: String, socket: ChatSocket) { - userMessage( - session = session, - user = socket.user, - userMessage = userMessage, - socketManager = this, - api = ApplicationServices.clientManager.getClient( - session, - socket.user, - dataStorage ?: throw IllegalStateException("No data storage") - ) - ) - } + override fun onRun(userMessage: String, socket: ChatSocket) { + userMessage( + session = session, + user = socket.user, + userMessage = userMessage, + socketManager = this, + api = ApplicationServices.clientManager.getClient( + session, + socket.user, + dataStorage ?: throw IllegalStateException("No data storage") + ) + ) + } - open val applicationInterface by lazy { ApplicationInterface(this) } + open val applicationInterface by lazy { ApplicationInterface(this) } - abstract fun userMessage( - session: Session, - user: User?, - userMessage: String, - socketManager: ApplicationSocketManager, - api: API - ) + abstract fun userMessage( + session: Session, + user: User?, + userMessage: String, + socketManager: ApplicationSocketManager, + api: API + ) - companion object { - val spinner: String get() = """
${SessionTask.spinner}
""" -// val playButton: String get() = """""" + companion object { + // val playButton: String get() = """""" // val cancelButton: String get() = """""" // val regenButton: String get() = """""" - } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatServer.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatServer.kt index 42911815..1f087211 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatServer.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatServer.kt @@ -42,7 +42,8 @@ abstract class ChatServer(private val resourceBase: String) { if (sessions.containsKey(session)) { sessions[session]!! } else { - val user = authenticationManager.getUser(req.getCookie(AuthenticationInterface.AUTH_COOKIE)) + val user = + authenticationManager.getUser(req.getCookie(AuthenticationInterface.AUTH_COOKIE)) val sessionState = newSession(user, session) sessions[session] = sessionState sessionState @@ -70,7 +71,7 @@ abstract class ChatServer(private val resourceBase: String) { open fun configure(webAppContext: WebAppContext) { webAppContext.addServlet(ServletHolder(javaClass.simpleName + "/default", defaultServlet), "/") webAppContext.addServlet(ServletHolder(javaClass.simpleName + "/ws", webSocketHandler), "/ws") - webAppContext.addServlet(ServletHolder(javaClass.simpleName + "/newSession", newSessionServlet),"/newSession") + webAppContext.addServlet(ServletHolder(javaClass.simpleName + "/newSession", newSessionServlet), "/newSession") } companion object { diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocket.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocket.kt index 4c099c47..7f1444ee 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocket.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocket.kt @@ -35,8 +35,7 @@ class ChatSocket( sessionState.removeSocket(this) } - companion object { - } + companion object } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocketManager.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocketManager.kt index 5b48d556..42896994 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocketManager.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocketManager.kt @@ -14,67 +14,68 @@ import com.simiacryptus.skyenet.webui.session.SocketManagerBase import com.simiacryptus.skyenet.webui.util.MarkdownUtil open class ChatSocketManager( - session: Session, - val model: ChatModels, - val userInterfacePrompt: String, - open val initialAssistantPrompt: String = "", - open val systemPrompt: String, - val api: OpenAIClient, - val temperature: Double = 0.3, - applicationClass: Class, - val storage: StorageInterface?, + session: Session, + val model: ChatModels, + val userInterfacePrompt: String, + open val initialAssistantPrompt: String = "", + open val systemPrompt: String, + val api: OpenAIClient, + val temperature: Double = 0.3, + applicationClass: Class, + val storage: StorageInterface?, ) : SocketManagerBase(session, storage, owner = null, applicationClass = applicationClass) { - init { - if (userInterfacePrompt.isNotBlank()) { - send("""aaa,
${MarkdownUtil.renderMarkdown(userInterfacePrompt)}
""") + init { + if (userInterfacePrompt.isNotBlank()) { + send("""aaa,
${MarkdownUtil.renderMarkdown(userInterfacePrompt)}
""") + } } - } - protected val messages by lazy { - val list = listOf( - ApiModel.ChatMessage(ApiModel.Role.system, systemPrompt.toContentList()), - ).toMutableList() - if (initialAssistantPrompt.isNotBlank()) list += - ApiModel.ChatMessage(ApiModel.Role.assistant, initialAssistantPrompt.toContentList()) - list - } + protected val messages by lazy { + val list = listOf( + ApiModel.ChatMessage(ApiModel.Role.system, systemPrompt.toContentList()), + ).toMutableList() + if (initialAssistantPrompt.isNotBlank()) list += + ApiModel.ChatMessage(ApiModel.Role.assistant, initialAssistantPrompt.toContentList()) + list + } - @Synchronized - override fun onRun(userMessage: String, socket: ChatSocket) { - val task = newTask() - val responseContents = renderResponse(userMessage, task) - task.echo(responseContents) - messages += ApiModel.ChatMessage(ApiModel.Role.user, userMessage.toContentList()) - val messagesCopy = messages.toList() - try { - val ui = ApplicationInterface(this) - val process = { it: StringBuilder -> - val response = (api.chat( - ApiModel.ChatRequest( - messages = messagesCopy, - temperature = temperature, - model = model.modelName, - ), model - ).choices.first().message?.content.orEmpty()) - messages.dropLastWhile { it.role == ApiModel.Role.assistant } - messages += ApiModel.ChatMessage(ApiModel.Role.assistant, response.toContentList()) - val renderResponse = renderResponse(response, task) - onResponse(renderResponse, responseContents) - renderResponse - } - Retryable(ui, task, process).apply { set(label(size), process(container!!)) } - } catch (e: Exception) { - log.info("Error in chat", e) - task.error(ApplicationInterface(this), e) + @Synchronized + override fun onRun(userMessage: String, socket: ChatSocket) { + val task = newTask() + val responseContents = renderResponse(userMessage, task) + task.echo(responseContents) + messages += ApiModel.ChatMessage(ApiModel.Role.user, userMessage.toContentList()) + val messagesCopy = messages.toList() + try { + val ui = ApplicationInterface(this) + val process = { it: StringBuilder -> + val response = (api.chat( + ApiModel.ChatRequest( + messages = messagesCopy, + temperature = temperature, + model = model.modelName, + ), model + ).choices.first().message?.content.orEmpty()) + messages.dropLastWhile { it.role == ApiModel.Role.assistant } + messages += ApiModel.ChatMessage(ApiModel.Role.assistant, response.toContentList()) + val renderResponse = renderResponse(response, task) + onResponse(renderResponse, responseContents) + renderResponse + } + Retryable(ui, task, process).apply { set(label(size), process(container)) } + } catch (e: Exception) { + log.info("Error in chat", e) + task.error(ApplicationInterface(this), e) + } } - } - open fun renderResponse(response: String, task: SessionTask) = """
${MarkdownUtil.renderMarkdown(response)}
""" + open fun renderResponse(response: String, task: SessionTask) = + """
${MarkdownUtil.renderMarkdown(response)}
""" - open fun onResponse(response: String, responseContents: String) {} + open fun onResponse(response: String, responseContents: String) {} - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(ChatSocketManager::class.java) - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(ChatSocketManager::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ApiKeyServlet.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ApiKeyServlet.kt index b538bcc3..9a2f688a 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ApiKeyServlet.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ApiKeyServlet.kt @@ -19,142 +19,143 @@ import kotlin.reflect.typeOf class ApiKeyServlet : HttpServlet() { - data class ApiKeyRecord( - val owner: String, - val apiKey: String, - val mappedKey: String, - val budget: Double, - val comment: String, - val welcomeMessage: String = "Welcome to our service!" - ) + data class ApiKeyRecord( + val owner: String, + val apiKey: String, + val mappedKey: String, + val budget: Double, + val comment: String, + val welcomeMessage: String = "Welcome to our service!" + ) - override fun doGet(req: HttpServletRequest, resp: HttpServletResponse) { - // Log received parameters for debugging + override fun doGet(req: HttpServletRequest, resp: HttpServletResponse) { + // Log received parameters for debugging // println("Action: $action, API Key: $apiKey, Mapped Key: $mappedKey, Budget: $budget, Comment: $comment, User: ${user?.email}") - resp.contentType = "text/html" - val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) ?: return resp.sendError( - HttpServletResponse.SC_UNAUTHORIZED - ) - val action = req.getParameter("action") - val apiKey = req.getParameter("apiKey") + resp.contentType = "text/html" + val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) ?: return resp.sendError( + HttpServletResponse.SC_UNAUTHORIZED + ) + val action = req.getParameter("action") + val apiKey = req.getParameter("apiKey") - when (action?.toLowerCase(Locale.ROOT)) { - "edit" -> { - val record = apiKeyRecords.find { it.apiKey == apiKey && it.owner == user.email } - if (record != null) { - serveEditPage(resp, record) - } else { - resp.writer.write("API Key record not found") - } - } + when (action.lowercase(Locale.ROOT)) { + "edit" -> { + val record = apiKeyRecords.find { it.apiKey == apiKey && it.owner == user.email } + if (record != null) { + serveEditPage(resp, record) + } else { + resp.writer.write("API Key record not found") + } + } - "delete" -> { // Fix the null safety check consistency - val record = apiKeyRecords.find { it.apiKey == apiKey && it.owner == user?.email } - if (record != null) { - apiKeyRecords.remove(record) - saveRecords() - resp.writer.write("API Key record deleted") - } else { - resp.writer.write("API Key record not found") - } - } + "delete" -> { // Fix the null safety check consistency + val record = apiKeyRecords.find { it.apiKey == apiKey && it.owner == user.email } + if (record != null) { + apiKeyRecords.remove(record) + saveRecords() + resp.writer.write("API Key record deleted") + } else { + resp.writer.write("API Key record not found") + } + } - "create" -> { - // Reuse the serveEditPage function but with an empty record for creation - serveEditPage( - resp, - ApiKeyRecord( - user.email, - UUID.randomUUID().toString(), - userSettingsManager.getUserSettings(user).apiKeys[APIProvider.OpenAI] ?: "", // TODO: Expand support for other providers - 0.0, - "" - ) - ) - } + "create" -> { + // Reuse the serveEditPage function but with an empty record for creation + serveEditPage( + resp, + ApiKeyRecord( + user.email, + UUID.randomUUID().toString(), + userSettingsManager.getUserSettings(user).apiKeys[APIProvider.OpenAI] + ?: "", // TODO: Expand support for other providers + 0.0, + "" + ) + ) + } - "invite" -> { - val record = apiKeyRecords.find { it.apiKey == apiKey /*&& it.owner != user.email*/ } - if (record == null) { - throw IllegalArgumentException("API Key record not found, or you do not have permission to access it, or you are the owner.") + "invite" -> { + val record = apiKeyRecords.find { it.apiKey == apiKey /*&& it.owner != user.email*/ } + if (record == null) { + throw IllegalArgumentException("API Key record not found, or you do not have permission to access it, or you are the owner.") + } + // Display a confirmation page instead of directly applying the settings + serveInviteConfirmationPage(resp, record, user) + } + + else -> { + resp.writer.write(indexPage(req)) + } } - // Display a confirmation page instead of directly applying the settings - serveInviteConfirmationPage(resp, record, user) - } - else -> { - resp.writer.write(indexPage(req)) - } } - } - - override fun doPost(req: HttpServletRequest, resp: HttpServletResponse) { - val action = req.getParameter("action") - val apiKey = req.getParameter("apiKey") - val mappedKey = req.getParameter("mappedKey") - val budget = req.getParameter("budget")?.toDoubleOrNull() - val comment = req.getParameter("comment") - // welcomeMessage - val welcomeMessage = req.getParameter("welcomeMessage") - val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) - val record = apiKeyRecords.find { it.apiKey == apiKey } + override fun doPost(req: HttpServletRequest, resp: HttpServletResponse) { + val action = req.getParameter("action") + val apiKey = req.getParameter("apiKey") + val mappedKey = req.getParameter("mappedKey") + val budget = req.getParameter("budget")?.toDoubleOrNull() + val comment = req.getParameter("comment") + // welcomeMessage + val welcomeMessage = req.getParameter("welcomeMessage") + val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) + val record = apiKeyRecords.find { it.apiKey == apiKey } - if (action == "acceptInvite") { - if (apiKey.isNullOrEmpty()) { - resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "API Key is missing") - } else if (user == null) { - resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "User not found") - } else if (record == null) { - resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid API Key or User not found") - } else { - userSettingsManager.updateUserSettings( - user, userSettingsManager.getUserSettings(user).copy( - apiKeys = mapOf(APIProvider.OpenAI to apiKey), // TODO: Expand support for other providers - apiBase = mapOf(APIProvider.OpenAI to "https://apps.simiacrypt.us/proxy") - ) - ) - resp.sendRedirect("/") // Redirect to a success page or another relevant page - } - } else if (record != null && budget != null && user != null) { // Ensure user is not null before proceeding - apiKeyRecords.remove(record) - apiKeyRecords.add( - record.copy( - mappedKey = mappedKey ?: record.mappedKey, - budget = budget, - comment = comment ?: "" - ) - ) - saveRecords() - resp.sendRedirect("?action=edit&apiKey=$apiKey&editSuccess=true") - } else if (apiKey != null && budget != null) { - // Create a new record if apiKey is not found - val newRecord = ApiKeyRecord( - owner = user?.email ?: "", - apiKey = apiKey, - mappedKey = mappedKey ?: "", - budget = budget, - comment = comment ?: "", - welcomeMessage = welcomeMessage ?: "Welcome to our service!" - ) - apiKeyRecords.add(newRecord) - saveRecords() - resp.sendRedirect( - "?action=edit&apiKey=${ - URLEncoder.encode( - apiKey, - "UTF-8" - ) - }&creationSuccess=true" - ) // Encode apiKey to prevent URL manipulation - } else { - resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid input") + if (action == "acceptInvite") { + if (apiKey.isNullOrEmpty()) { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "API Key is missing") + } else if (user == null) { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "User not found") + } else if (record == null) { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid API Key or User not found") + } else { + userSettingsManager.updateUserSettings( + user, userSettingsManager.getUserSettings(user).copy( + apiKeys = mapOf(APIProvider.OpenAI to apiKey), // TODO: Expand support for other providers + apiBase = mapOf(APIProvider.OpenAI to "https://apps.simiacrypt.us/proxy") + ) + ) + resp.sendRedirect("/") // Redirect to a success page or another relevant page + } + } else if (record != null && budget != null && user != null) { // Ensure user is not null before proceeding + apiKeyRecords.remove(record) + apiKeyRecords.add( + record.copy( + mappedKey = mappedKey ?: record.mappedKey, + budget = budget, + comment = comment ?: "" + ) + ) + saveRecords() + resp.sendRedirect("?action=edit&apiKey=$apiKey&editSuccess=true") + } else if (apiKey != null && budget != null) { + // Create a new record if apiKey is not found + val newRecord = ApiKeyRecord( + owner = user?.email ?: "", + apiKey = apiKey, + mappedKey = mappedKey ?: "", + budget = budget, + comment = comment ?: "", + welcomeMessage = welcomeMessage ?: "Welcome to our service!" + ) + apiKeyRecords.add(newRecord) + saveRecords() + resp.sendRedirect( + "?action=edit&apiKey=${ + URLEncoder.encode( + apiKey, + "UTF-8" + ) + }&creationSuccess=true" + ) // Encode apiKey to prevent URL manipulation + } else { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid input") + } } - } - private fun indexPage(req: HttpServletRequest): String { - val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) ?: return "" - return """ + private fun indexPage(req: HttpServletRequest): String { + val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) ?: return "" + return """ API Key Records @@ -169,21 +170,21 @@ class ApiKeyServlet : HttpServlet() {

API Key Records

${ - apiKeyRecords.filter { it.owner == user.email }.joinToString("\n") { record -> - "" - } - } + apiKeyRecords.filter { it.owner == user.email }.joinToString("\n") { record -> + "" + } + }
Create New API Key Record """.trimIndent() - } + } - private fun serveInviteConfirmationPage(resp: HttpServletResponse, record: ApiKeyRecord, user: User) { - //language=HTML - resp.writer.write( - """ + private fun serveInviteConfirmationPage(resp: HttpServletResponse, record: ApiKeyRecord, user: User) { + //language=HTML + resp.writer.write( + """ Accept API Key Invitation @@ -200,14 +201,14 @@ class ApiKeyServlet : HttpServlet() { """.trimIndent() - ) - } + ) + } - private fun serveEditPage(resp: HttpServletResponse, record: ApiKeyRecord) { - val usageSummary = ApplicationServices.usageManager.getUserUsageSummary(record.apiKey) - //language=HTML - resp.writer.write( - """ + private fun serveEditPage(resp: HttpServletResponse, record: ApiKeyRecord) { + val usageSummary = ApplicationServices.usageManager.getUserUsageSummary(record.apiKey) + //language=HTML + resp.writer.write( + """ Edit API Key Record: ${record.apiKey} @@ -274,16 +275,16 @@ class ApiKeyServlet : HttpServlet() {

Usage Summary

${ - usageSummary.entries.joinToString { (model: OpenAIModel, usage: ApiModel.Usage) -> - """ + usageSummary.entries.joinToString { (model: OpenAIModel, usage: ApiModel.Usage) -> + """

${model.modelName}

total_tokens: ${usage.total_tokens}

Cost: ${usage.cost}

""" - } - } + } + } | | """.trimMargin() ) - // (pool.threadFactory as RecordingThreadFactory).threads + // (pool.threadFactory as RecordingThreadFactory).threads } else { resp.status = HttpServletResponse.SC_BAD_REQUEST resp.writer.write("Session ID is required") diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ToolServlet.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ToolServlet.kt index 2f286417..63ebed06 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ToolServlet.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ToolServlet.kt @@ -26,15 +26,15 @@ import kotlin.reflect.typeOf abstract class ToolServlet(val app: ApplicationDirectory) : HttpServlet() { - data class Tool( - val path: String, - val openApiDescription: OpenAPI, - val interpreterString: String, - val servletCode: String, - ) + data class Tool( + val path: String, + val openApiDescription: OpenAPI, + val interpreterString: String, + val servletCode: String, + ) - @Language("HTML") - private fun indexPage() = """ + @Language("HTML") + private fun indexPage() = """ Tools @@ -67,8 +67,8 @@ abstract class ToolServlet(val app: ApplicationDirectory) : HttpServlet() { """.trimIndent() - @Language("HTML") - private fun toolDetailsPage(tool: Tool) = """ + @Language("HTML") + private fun toolDetailsPage(tool: Tool) = """ Tool Details @@ -126,9 +126,9 @@ abstract class ToolServlet(val app: ApplicationDirectory) : HttpServlet() { """.trimIndent() - private val header - @Language("HTML") - get() = """ + private val header + @Language("HTML") + get() = """ @@ -207,9 +207,9 @@ abstract class ToolServlet(val app: ApplicationDirectory) : HttpServlet() { """.trimIndent() - private fun serveEditPage(req: HttpServletRequest, resp: HttpServletResponse, tool: Tool) { - resp.contentType = "text/html" - val formHtml = """ + private fun serveEditPage(req: HttpServletRequest, resp: HttpServletResponse, tool: Tool) { + resp.contentType = "text/html" + val formHtml = """ Edit Tool: ${tool.path} @@ -232,193 +232,193 @@ abstract class ToolServlet(val app: ApplicationDirectory) : HttpServlet() { """.trimIndent() - resp.writer.write(formHtml) - resp.writer.close() - } + resp.writer.write(formHtml) + resp.writer.close() + } - override fun doGet(req: HttpServletRequest?, resp: HttpServletResponse?) { + override fun doGet(req: HttpServletRequest?, resp: HttpServletResponse?) { - val user = authenticationManager.getUser(req?.getCookie()) - if (!authorizationManager.isAuthorized(ToolServlet.javaClass, user, OperationType.Admin)) { - resp?.sendError(403) - return - } + val user = authenticationManager.getUser(req?.getCookie()) + if (!authorizationManager.isAuthorized(ToolServlet.javaClass, user, OperationType.Admin)) { + resp?.sendError(403) + return + } - resp?.contentType = "text/html" + resp?.contentType = "text/html" - val path = req?.getParameter("path") - if (req?.getParameter("edit") != null) { - val tool = tools.find { it.path == path } - if (tool != null) { - serveEditPage(req, resp!!, tool) - } else { - resp!!.writer.write("Tool not found") - } - return - } + val path = req?.getParameter("path") + if (req?.getParameter("edit") != null) { + val tool = tools.find { it.path == path } + if (tool != null) { + serveEditPage(req, resp!!, tool) + } else { + resp!!.writer.write("Tool not found") + } + return + } - if (req?.getParameter("delete") != null) { - val tool = tools.find { it.path == path } - if (tool != null) { - tools.remove(tool) - File(userRoot, "tools.json").writeText(JsonUtil.toJson(tools)) - resp!!.sendRedirect("?") - } else { - resp!!.writer.write("Tool not found") - } - return - } + if (req?.getParameter("delete") != null) { + val tool = tools.find { it.path == path } + if (tool != null) { + tools.remove(tool) + File(userRoot, "tools.json").writeText(JsonUtil.toJson(tools)) + resp!!.sendRedirect("?") + } else { + resp!!.writer.write("Tool not found") + } + return + } - if (req?.getParameter("export") != null) { - resp?.contentType = "application/json" - resp?.addHeader("Content-Disposition", "attachment; filename=\"tools.json\"") - resp?.writer?.write(JsonUtil.toJson(tools)) - return - } + if (req?.getParameter("export") != null) { + resp?.contentType = "application/json" + resp?.addHeader("Content-Disposition", "attachment; filename=\"tools.json\"") + resp?.writer?.write(JsonUtil.toJson(tools)) + return + } - if (path != null) { - // Display details for a single tool - val tool = tools.find { it.path == path } - if (tool != null) { - resp?.writer?.write(toolDetailsPage(tool)) - } else { - resp?.writer?.write("Tool not found") - } - } else { - // Display index page - resp?.writer?.write(indexPage()) + if (path != null) { + // Display details for a single tool + val tool = tools.find { it.path == path } + if (tool != null) { + resp?.writer?.write(toolDetailsPage(tool)) + } else { + resp?.writer?.write("Tool not found") + } + } else { + // Display index page + resp?.writer?.write(indexPage()) + } + resp?.writer?.close() } - resp?.writer?.close() - } - override fun doPost(req: HttpServletRequest?, resp: HttpServletResponse?) { - req ?: return - resp ?: return + override fun doPost(req: HttpServletRequest?, resp: HttpServletResponse?) { + req ?: return + resp ?: return - val path = req.getParameter("path") - val tool = tools.find { it.path == path } - if (tool != null) { - tools.remove(tool) - val newpath = req.getParameter("newpath") ?: req.getParameter("path") - tools.add( - tool.copy( - path = newpath, - interpreterString = req.getParameter("interpreterString"), - servletCode = req.getParameter("servletCode"), - openApiDescription = JsonUtil.fromJson(req.getParameter("openApiDescription"), OpenAPI::class.java) - ) - ) - File(userRoot, "tools.json").writeText(JsonUtil.toJson(tools)) - resp.sendRedirect("?path=$newpath&editSuccess=true") // Redirect to the tool's detail page or an edit success page - } else { - if (req.getParameter("import") != null) { - val inputStream = req.getPart("file")?.inputStream - val toolsJson = inputStream?.bufferedReader().use { it?.readText() } - if (toolsJson != null) { - val importedTools: List = JsonUtil.fromJson(toolsJson, typeOf>().javaType) - tools.clear() - tools.addAll(importedTools) - File(userRoot, "tools.json").writeText(JsonUtil.toJson(tools)) - resp.sendRedirect("?importSuccess=true") + val path = req.getParameter("path") + val tool = tools.find { it.path == path } + if (tool != null) { + tools.remove(tool) + val newpath = req.getParameter("newpath") ?: req.getParameter("path") + tools.add( + tool.copy( + path = newpath, + interpreterString = req.getParameter("interpreterString"), + servletCode = req.getParameter("servletCode"), + openApiDescription = JsonUtil.fromJson(req.getParameter("openApiDescription"), OpenAPI::class.java) + ) + ) + File(userRoot, "tools.json").writeText(JsonUtil.toJson(tools)) + resp.sendRedirect("?path=$newpath&editSuccess=true") // Redirect to the tool's detail page or an edit success page } else { - resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid file or format") + if (req.getParameter("import") != null) { + val inputStream = req.getPart("file")?.inputStream + val toolsJson = inputStream?.bufferedReader().use { it?.readText() } + if (toolsJson != null) { + val importedTools: List = JsonUtil.fromJson(toolsJson, typeOf>().javaType) + tools.clear() + tools.addAll(importedTools) + File(userRoot, "tools.json").writeText(JsonUtil.toJson(tools)) + resp.sendRedirect("?importSuccess=true") + } else { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid file or format") + } + return + } + resp.sendError(HttpServletResponse.SC_NOT_FOUND, "Tool not found") } - return - } - resp.sendError(HttpServletResponse.SC_NOT_FOUND, "Tool not found") } - } - companion object { - private val userRoot by lazy { - File( - File(ApplicationServices.dataStorageRoot, ".skyenet"), - "tools" - ).apply { mkdirs() } - } + companion object { + private val userRoot by lazy { + File( + File(ApplicationServices.dataStorageRoot, ".skyenet"), + "tools" + ).apply { mkdirs() } + } - @OptIn(ExperimentalStdlibApi::class) - val tools by lazy { - val file = File(userRoot, "tools.json") - if (file.exists()) try { - return@lazy JsonUtil.fromJson(file.readText(), typeOf>().javaType) - } catch (e: Throwable) { - e.printStackTrace() - } - mutableListOf() - } + @OptIn(ExperimentalStdlibApi::class) + val tools by lazy { + val file = File(userRoot, "tools.json") + if (file.exists()) try { + return@lazy JsonUtil.fromJson(file.readText(), typeOf>().javaType) + } catch (e: Throwable) { + e.printStackTrace() + } + mutableListOf() + } - fun addTool(element: Tool) { - tools += element - File(userRoot, "tools.json").writeText(JsonUtil.toJson(tools)) - } + fun addTool(element: Tool) { + tools += element + File(userRoot, "tools.json").writeText(JsonUtil.toJson(tools)) + } - val apiKey = UUID.randomUUID().toString() - val instanceCache = mutableMapOf() + val apiKey = UUID.randomUUID().toString() + val instanceCache = mutableMapOf() - } + } - override fun service(req: HttpServletRequest?, resp: HttpServletResponse?) { - req ?: return - resp ?: return - val path = req.servletPath ?: "/" - val tool = tools.find { it.path == path } - if (tool != null) { - // TODO: Isolate tools per user - val user = authenticationManager.getUser(req.getCookie()) - val isAdmin = authorizationManager.isAuthorized( - ToolServlet.javaClass, user, OperationType.Admin - ) - val isHeaderAuth = apiKey == req.getHeader("Authorization")?.removePrefix("Bearer ") - if (!isAdmin && !isHeaderAuth) { - resp.sendError(403) - } else { - try { - val servlet = instanceCache.computeIfAbsent(tool) { construct(user!!, tool) } - servlet.service(req, resp) - } catch (e: RuntimeException) { - throw e - } catch (e: Throwable) { - throw RuntimeException(e) + override fun service(req: HttpServletRequest?, resp: HttpServletResponse?) { + req ?: return + resp ?: return + val path = req.servletPath ?: "/" + val tool = tools.find { it.path == path } + if (tool != null) { + // TODO: Isolate tools per user + val user = authenticationManager.getUser(req.getCookie()) + val isAdmin = authorizationManager.isAuthorized( + ToolServlet.javaClass, user, OperationType.Admin + ) + val isHeaderAuth = apiKey == req.getHeader("Authorization")?.removePrefix("Bearer ") + if (!isAdmin && !isHeaderAuth) { + resp.sendError(403) + } else { + try { + val servlet = instanceCache.computeIfAbsent(tool) { construct(user!!, tool) } + servlet.service(req, resp) + } catch (e: RuntimeException) { + throw e + } catch (e: Throwable) { + throw RuntimeException(e) + } + } + } else { + super.service(req, resp) } - } - } else { - super.service(req, resp) } - } - private fun construct(user: User, tool: Tool): HttpServlet { - val returnBuffer = ToolAgent.ServletBuffer() - val classLoader = Thread.currentThread().contextClassLoader - val prevCL = KotlinInterpreter.classLoader - KotlinInterpreter.classLoader = classLoader //req.javaClass.classLoader - try { - WebAppClassLoader.runWithServerClassAccess { - require(null != classLoader.loadClass("org.eclipse.jetty.server.Response")) - require(null != classLoader.loadClass("org.eclipse.jetty.server.Request")) - this.fromString(user, tool.interpreterString).let { (interpreterClass, symbols) -> - val effectiveSymbols = (symbols + mapOf( - "returnBuffer" to returnBuffer, - "json" to JsonUtil, - )).filterKeys { !it.isNullOrBlank() } - interpreterClass.getConstructor(Map::class.java).newInstance(effectiveSymbols).run(tool.servletCode) + private fun construct(user: User, tool: Tool): HttpServlet { + val returnBuffer = ToolAgent.ServletBuffer() + val classLoader = Thread.currentThread().contextClassLoader + val prevCL = KotlinInterpreter.classLoader + KotlinInterpreter.classLoader = classLoader //req.javaClass.classLoader + try { + WebAppClassLoader.runWithServerClassAccess { + require(null != classLoader.loadClass("org.eclipse.jetty.server.Response")) + require(null != classLoader.loadClass("org.eclipse.jetty.server.Request")) + this.fromString(user, tool.interpreterString).let { (interpreterClass, symbols) -> + val effectiveSymbols = (symbols + mapOf( + "returnBuffer" to returnBuffer, + "json" to JsonUtil, + )).filterKeys { !it.isNullOrBlank() } + interpreterClass.getConstructor(Map::class.java).newInstance(effectiveSymbols).run(tool.servletCode) + } + } + } finally { + KotlinInterpreter.classLoader = prevCL } - } - } finally { - KotlinInterpreter.classLoader = prevCL - } - val first = returnBuffer.first() - return first - } + val first = returnBuffer.first() + return first + } - abstract fun fromString(user: User, str: String): InterpreterAndTools + abstract fun fromString(user: User, str: String): InterpreterAndTools } data class InterpreterAndTools( - val interpreterClass: Class, - val symbols: Map = mapOf(), + val interpreterClass: Class, + val symbols: Map = mapOf(), ) diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/UsageServlet.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/UsageServlet.kt index 03d5f3ed..2abe563a 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/UsageServlet.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/UsageServlet.kt @@ -57,8 +57,8 @@ class UsageServlet : HttpServlet() { Cost ${ - usage.entries.joinToString("\n") { (model, count) -> - """ + usage.entries.joinToString("\n") { (model, count) -> + """ $model ${count.prompt_tokens} @@ -66,8 +66,8 @@ class UsageServlet : HttpServlet() { ${"%.4f".format(count.cost ?: 0.0)} """.trimIndent() - } } + } Total $totalPromptTokens @@ -80,8 +80,6 @@ class UsageServlet : HttpServlet() { """.trimIndent()) } - companion object { - - } + companion object } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/UserSettingsServlet.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/UserSettingsServlet.kt index fc39f85a..433a7ed2 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/UserSettingsServlet.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/UserSettingsServlet.kt @@ -20,15 +20,19 @@ class UserSettingsServlet : HttpServlet() { } else { val settings = ApplicationServices.userSettingsManager.getUserSettings(userinfo) val visibleSettings = settings.copy( - apiKeys = settings.apiKeys.mapValues { when(it.value) { - "" -> "" - else -> mask - } }, - apiBase = settings.apiBase.mapValues { when(it.value) { - null -> "https://api.openai.com/v1" - "" -> "https://api.openai.com/v1" - else -> settings.apiBase[it.key]!! - } }, + apiKeys = settings.apiKeys.mapValues { + when (it.value) { + "" -> "" + else -> mask + } + }, + apiBase = settings.apiBase.mapValues { + when (it.value) { + null -> "https://api.openai.com/v1" + "" -> "https://api.openai.com/v1" + else -> settings.apiBase[it.key]!! + } + }, ) val json = JsonUtil.toJson(visibleSettings) //language=HTML @@ -60,22 +64,25 @@ class UserSettingsServlet : HttpServlet() { val settings = JsonUtil.fromJson(req.getParameter("settings"), UserSettings::class.java) val prevSettings = ApplicationServices.userSettingsManager.getUserSettings(userinfo) val reconstructedSettings = prevSettings.copy( - apiKeys = settings.apiKeys.mapValues { when(it.value) { - "" -> "" - mask -> prevSettings.apiKeys[it.key]!! - else -> settings.apiKeys[it.key]!! - } }, - apiBase = settings.apiBase.mapValues { when(it.value) { - null -> "https://api.openai.com/v1" - "" -> "https://api.openai.com/v1" - else -> settings.apiBase[it.key]!! - } }, + apiKeys = settings.apiKeys.mapValues { + when (it.value) { + "" -> "" + mask -> prevSettings.apiKeys[it.key]!! + else -> settings.apiKeys[it.key]!! + } + }, + apiBase = settings.apiBase.mapValues { + when (it.value) { + null -> "https://api.openai.com/v1" + "" -> "https://api.openai.com/v1" + else -> settings.apiBase[it.key]!! + } + }, ) ApplicationServices.userSettingsManager.updateUserSettings(userinfo, reconstructedSettings) resp.sendRedirect("/") } } - companion object { - } + companion object } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/WelcomeServlet.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/WelcomeServlet.kt index f1c89552..40cdec30 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/WelcomeServlet.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/WelcomeServlet.kt @@ -15,46 +15,47 @@ import jakarta.servlet.http.HttpServletResponse import org.intellij.lang.annotations.Language import java.nio.file.NoSuchFileException -open class WelcomeServlet(private val parent: com.simiacryptus.skyenet.webui.application.ApplicationDirectory) : - HttpServlet() { - override fun doGet(req: HttpServletRequest?, resp: HttpServletResponse?) { - val user = ApplicationServices.authenticationManager.getUser(req!!.getCookie()) - val requestURI = req.requestURI ?: "/" - resp?.contentType = when (requestURI) { - "/" -> "text/html" - else -> ApplicationServer.getMimeType(requestURI) - } - when { - requestURI == "/" -> resp?.writer?.write(homepage(user).trimIndent()) - requestURI == "/index.html" -> resp?.writer?.write(homepage(user).trimIndent()) - requestURI.startsWith("/userInfo") -> { - parent.userInfoServlet.doGet(req, resp!!) - } - else -> try { - val inputStream = parent.welcomeResources.addPath(requestURI)?.inputStream - inputStream?.copyTo(resp?.outputStream!!) - } catch (e: NoSuchFileException) { - resp?.sendError(404) - } +open class WelcomeServlet(private val parent: ApplicationDirectory) : + HttpServlet() { + override fun doGet(req: HttpServletRequest?, resp: HttpServletResponse?) { + val user = ApplicationServices.authenticationManager.getUser(req!!.getCookie()) + val requestURI = req.requestURI ?: "/" + resp?.contentType = when (requestURI) { + "/" -> "text/html" + else -> ApplicationServer.getMimeType(requestURI) + } + when { + requestURI == "/" -> resp?.writer?.write(homepage(user).trimIndent()) + requestURI == "/index.html" -> resp?.writer?.write(homepage(user).trimIndent()) + requestURI.startsWith("/userInfo") -> { + parent.userInfoServlet.doGet(req, resp!!) + } + + else -> try { + val inputStream = parent.welcomeResources.addPath(requestURI)?.inputStream + inputStream?.copyTo(resp?.outputStream!!) + } catch (e: NoSuchFileException) { + resp?.sendError(404) + } + } } - } - override fun doPost(req: HttpServletRequest?, resp: HttpServletResponse?) { - val requestURI = req?.requestURI ?: "/" - when { - requestURI.startsWith("/userSettings") -> parent.userSettingsServlet.doPost(req!!, resp!!) - else -> resp?.sendError(404) + override fun doPost(req: HttpServletRequest?, resp: HttpServletResponse?) { + val requestURI = req?.requestURI ?: "/" + when { + requestURI.startsWith("/userSettings") -> parent.userSettingsServlet.doPost(req!!, resp!!) + else -> resp?.sendError(404) + } } - } - @Language("Markdown") - protected open val welcomeMarkdown = """""".trimIndent() + @Language("Markdown") + protected open val welcomeMarkdown = """""".trimIndent() - @Language("Markdown") - protected open val postAppMarkdown = """""".trimIndent() + @Language("Markdown") + protected open val postAppMarkdown = """""".trimIndent() - @Language("HTML") - protected open fun homepage(user: User?) = """ + @Language("HTML") + protected open fun homepage(user: User?) = """ @@ -122,12 +123,12 @@ open class WelcomeServlet(private val parent: com.simiacryptus.skyenet.webui.app """.trimIndent() - protected open fun appRow( - app: ApplicationDirectory.ChildWebApp, - user: User? - ) = when { - !authorizationManager.isAuthorized(app.server.javaClass, user, OperationType.Read) -> "" - else -> """ + protected open fun appRow( + app: ApplicationDirectory.ChildWebApp, + user: User? + ) = when { + !authorizationManager.isAuthorized(app.server.javaClass, user, OperationType.Read) -> "" + else -> """ ${app.server.applicationName} @@ -137,23 +138,23 @@ open class WelcomeServlet(private val parent: com.simiacryptus.skyenet.webui.app ${ - when { - !authorizationManager.isAuthorized(app.server.javaClass, user, OperationType.Public) -> "" - else -> - """New Public Session""" - } - } + when { + !authorizationManager.isAuthorized(app.server.javaClass, user, OperationType.Public) -> "" + else -> + """New Public Session""" + } + } ${ - when { - !authorizationManager.isAuthorized(app.server.javaClass, user, OperationType.Write) -> "" - else -> - """New Private Session""" - } - } + when { + !authorizationManager.isAuthorized(app.server.javaClass, user, OperationType.Write) -> "" + else -> + """New Private Session""" + } + } """.trimIndent() - } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ZipServlet.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ZipServlet.kt index d0d2306d..812807ca 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ZipServlet.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ZipServlet.kt @@ -17,7 +17,8 @@ class ZipServlet(val dataStorage: StorageInterface) : HttpServlet() { val path = req.parameterMap.get("path")?.find { it.isNotBlank() } ?: "/" FileServlet.parsePath(path) // Validate path val sessionDir = dataStorage.getSessionDir( - ApplicationServices.authenticationManager.getUser(req.getCookie()), session) + ApplicationServices.authenticationManager.getUser(req.getCookie()), session + ) val file = File(sessionDir, path) val zipFile = File.createTempFile("skyenet", ".zip") try { diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SessionTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SessionTask.kt index 22872156..f01854cd 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SessionTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SessionTask.kt @@ -10,121 +10,123 @@ import org.slf4j.LoggerFactory import java.awt.image.BufferedImage abstract class SessionTask( - val operationID: String, - private var buffer: MutableList = mutableListOf(), - private val spinner: String = SessionTask.spinner + val operationID: String, + private var buffer: MutableList = mutableListOf(), + private val spinner: String = SessionTask.spinner ) { - val placeholder : String get() = "
" - - private val currentText: String - get() = buffer.filter { it.isNotBlank() }.joinToString("") - - private fun append( - htmlToAppend: String, - showSpinner: Boolean - ): StringBuilder? { - val stringBuilder: StringBuilder? - if (htmlToAppend.isNotBlank()) { - stringBuilder = StringBuilder("
$htmlToAppend
") - buffer += stringBuilder - } else { - stringBuilder = null + val placeholder: String get() = "
" + + private val currentText: String + get() = buffer.filter { it.isNotBlank() }.joinToString("") + + private fun append( + htmlToAppend: String, + showSpinner: Boolean + ): StringBuilder? { + val stringBuilder: StringBuilder? + if (htmlToAppend.isNotBlank()) { + stringBuilder = StringBuilder("
$htmlToAppend
") + buffer += stringBuilder + } else { + stringBuilder = null + } + send(currentText + if (showSpinner) "
$spinner
" else "") + return stringBuilder } - send(currentText + if (showSpinner) "
$spinner
" else "") - return stringBuilder - } - - protected abstract fun send( - html: String = currentText - ) - - @Description("Saves the given data to a file and returns the url of the file.") - abstract fun saveFile( - @Description("The name of the file to save") - relativePath: String, - @Description("The data to save") - data: ByteArray - ): String - - @Description("Adds a message to the task output.") - fun add( - @Description("The message to add") - message: String, - @Description("Whether to show the spinner for the task (default: true)") - showSpinner: Boolean = true, - @Description("The html tag to wrap the message in (default: div)") - tag: String = "div", - @Description("The css class to apply to the message (default: response-message)") - className: String = "response-message" - ) = append("""<$tag class="$className">$message""", showSpinner) - - @Description("Adds a hideable message to the task output.") - fun hideable( - ui: ApplicationInterface?, - @Description("The message to add") - message: String, - @Description("Whether to show the spinner for the task (default: true)") - showSpinner: Boolean = true, - @Description("The html tag to wrap the message in (default: div)") - tag: String = "div", - @Description("The css class to apply to the message (default: response-message)") - className: String = "response-message" - ): StringBuilder? { - var windowBuffer: StringBuilder? = null - val closeButton = """${ - ui?.hrefLink("×", "close-button href-link"){ - windowBuffer?.clear() - send() - } - }""" - windowBuffer = append("""<$tag class="$className">$closeButton$message""", showSpinner) - return windowBuffer - } - - @Description("Echos a user message to the task output.") - fun echo( - @Description("The message to echo") - message: String, - @Description("Whether to show the spinner for the task (default: true)") - showSpinner: Boolean = true, - @Description("The html tag to wrap the message in (default: div)") - tag: String = "div" - ) = add(message, showSpinner, tag, "user-message") - - @Description("Adds a header to the task output.") - fun header( - @Description("The message to add") - message: String, - @Description("Whether to show the spinner for the task (default: true)") - showSpinner: Boolean = true, - @Description("The html tag to wrap the message in (default: div)") - tag: String = "div", - classname: String = "response-header" - ) = add(message, showSpinner, tag, classname) - - @Description("Adds a verbose message to the task output; verbose messages are hidden by default.") - fun verbose( - @Description("The message to add") - message: String, - @Description("Whether to show the spinner for the task (default: true)") - showSpinner: Boolean = true, - @Description("The html tag to wrap the message in (default: pre)") - tag: String = "pre" - ) = add(message, showSpinner, tag, "verbose") - - @Description("Displays an error in the task output.") - fun error( - ui: ApplicationInterface?, - @Description("The error to display") - e: Throwable, - @Description("Whether to show the spinner for the task (default: false)") - showSpinner: Boolean = false, - @Description("The html tag to wrap the message in (default: div)") - tag: String = "div" - ) = hideable(ui, - when { - e is ValidatedObject.ValidationError -> renderMarkdown( """ + + protected abstract fun send( + html: String = currentText + ) + + @Description("Saves the given data to a file and returns the url of the file.") + abstract fun saveFile( + @Description("The name of the file to save") + relativePath: String, + @Description("The data to save") + data: ByteArray + ): String + + @Description("Adds a message to the task output.") + fun add( + @Description("The message to add") + message: String, + @Description("Whether to show the spinner for the task (default: true)") + showSpinner: Boolean = true, + @Description("The html tag to wrap the message in (default: div)") + tag: String = "div", + @Description("The css class to apply to the message (default: response-message)") + className: String = "response-message" + ) = append("""<$tag class="$className">$message""", showSpinner) + + @Description("Adds a hideable message to the task output.") + fun hideable( + ui: ApplicationInterface?, + @Description("The message to add") + message: String, + @Description("Whether to show the spinner for the task (default: true)") + showSpinner: Boolean = true, + @Description("The html tag to wrap the message in (default: div)") + tag: String = "div", + @Description("The css class to apply to the message (default: response-message)") + className: String = "response-message" + ): StringBuilder? { + var windowBuffer: StringBuilder? = null + val closeButton = """${ + ui?.hrefLink("×", "close-button href-link") { + windowBuffer?.clear() + send() + } + }""" + windowBuffer = append("""<$tag class="$className">$closeButton$message""", showSpinner) + return windowBuffer + } + + @Description("Echos a user message to the task output.") + fun echo( + @Description("The message to echo") + message: String, + @Description("Whether to show the spinner for the task (default: true)") + showSpinner: Boolean = true, + @Description("The html tag to wrap the message in (default: div)") + tag: String = "div" + ) = add(message, showSpinner, tag, "user-message") + + @Description("Adds a header to the task output.") + fun header( + @Description("The message to add") + message: String, + @Description("Whether to show the spinner for the task (default: true)") + showSpinner: Boolean = true, + @Description("The html tag to wrap the message in (default: div)") + tag: String = "div", + classname: String = "response-header" + ) = add(message, showSpinner, tag, classname) + + @Description("Adds a verbose message to the task output; verbose messages are hidden by default.") + fun verbose( + @Description("The message to add") + message: String, + @Description("Whether to show the spinner for the task (default: true)") + showSpinner: Boolean = true, + @Description("The html tag to wrap the message in (default: pre)") + tag: String = "pre" + ) = add(message, showSpinner, tag, "verbose") + + @Description("Displays an error in the task output.") + fun error( + ui: ApplicationInterface?, + @Description("The error to display") + e: Throwable, + @Description("Whether to show the spinner for the task (default: false)") + showSpinner: Boolean = false, + @Description("The html tag to wrap the message in (default: div)") + tag: String = "div" + ) = hideable( + ui, + when { + e is ValidatedObject.ValidationError -> renderMarkdown( + """ |**Data Validation Error** | |${e.message} @@ -134,10 +136,11 @@ abstract class SessionTask( |${e.stackTraceTxt/*.indent(" ")*/} |``` | - |""".trimMargin(), ui=ui - ) - e is CodingActor.FailedToImplementException -> renderMarkdown( - """ + |""".trimMargin(), ui = ui + ) + + e is CodingActor.FailedToImplementException -> renderMarkdown( + """ |**Failed to Implement** | |${e.message} @@ -152,57 +155,57 @@ abstract class SessionTask( |${/*escapeHtml4*/(e.code/*?.indent(" ")*/ ?: "")} |``` | - |""".trimMargin(), ui=ui - ) + |""".trimMargin(), ui = ui + ) - else -> renderMarkdown( - """ + else -> renderMarkdown( + """ |**Error `${e.javaClass.name}`** | |```text - |${e.stackTraceToString()?.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} + |${e.stackTraceToString().let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} |``` - |""".trimMargin(), ui=ui - ) - }, showSpinner, tag, "error" - ) - - @Description("Displays a final message in the task output. This will hide the spinner.") - fun complete( - @Description("The message to display") - message: String = "", - @Description("The html tag to wrap the message in (default: div)") - tag: String = "div", - @Description("The css class to apply to the message (default: response-message)") - className: String = "response-message" - ) = append(if (message.isNotBlank()) """<$tag class="$className">$message""" else "", false) - - @Description("Displays an image to the task output.") - fun image( - @Description("The image to display") - image: BufferedImage - ) = add("""""") - - companion object { - val log = LoggerFactory.getLogger(SessionTask::class.java) - - const val spinner = - """
Loading...
""" - - fun BufferedImage.toPng(): ByteArray { - java.io.ByteArrayOutputStream().use { os -> - javax.imageio.ImageIO.write(this, "png", os) - return os.toByteArray() - } - } + |""".trimMargin(), ui = ui + ) + }, showSpinner, tag, "error" + ) + + @Description("Displays a final message in the task output. This will hide the spinner.") + fun complete( + @Description("The message to display") + message: String = "", + @Description("The html tag to wrap the message in (default: div)") + tag: String = "div", + @Description("The css class to apply to the message (default: response-message)") + className: String = "response-message" + ) = append(if (message.isNotBlank()) """<$tag class="$className">$message""" else "", false) + + @Description("Displays an image to the task output.") + fun image( + @Description("The image to display") + image: BufferedImage + ) = add("""""") + + companion object { + val log = LoggerFactory.getLogger(SessionTask::class.java) + + const val spinner = + """
Loading...
""" + + fun BufferedImage.toPng(): ByteArray { + java.io.ByteArrayOutputStream().use { os -> + javax.imageio.ImageIO.write(this, "png", os) + return os.toByteArray() + } + } - } + } } val Throwable.stackTraceTxt: String - get() { - val sw = java.io.StringWriter() - val pw = java.io.PrintWriter(sw) - printStackTrace(pw) - return sw.toString() - } + get() { + val sw = java.io.StringWriter() + val pw = java.io.PrintWriter(sw) + printStackTrace(pw) + return sw.toString() + } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SocketManagerBase.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SocketManagerBase.kt index 5284d1b8..af00afa7 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SocketManagerBase.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SocketManagerBase.kt @@ -14,259 +14,259 @@ import java.util.concurrent.atomic.AtomicInteger import java.util.function.Consumer abstract class SocketManagerBase( - protected val session: Session, - protected val dataStorage: StorageInterface?, - protected val owner: User? = null, - private val messageStates: LinkedHashMap = dataStorage?.getMessages( - owner, session - ) ?: LinkedHashMap(), - private val applicationClass: Class<*>, + protected val session: Session, + protected val dataStorage: StorageInterface?, + protected val owner: User? = null, + private val messageStates: LinkedHashMap = dataStorage?.getMessages( + owner, session + ) ?: LinkedHashMap(), + private val applicationClass: Class<*>, ) : SocketManager { - private val sockets: MutableMap = mutableMapOf() - private val sendQueues: MutableMap> = mutableMapOf() - private val messageVersions = HashMap() - protected val pool get() = clientManager.getPool(session, owner, dataStorage) - val sendQueue = ConcurrentLinkedDeque() + private val sockets: MutableMap = mutableMapOf() + private val sendQueues: MutableMap> = mutableMapOf() + private val messageVersions = HashMap() + protected val pool get() = clientManager.getPool(session, owner, dataStorage) + val sendQueue = ConcurrentLinkedDeque() - override fun removeSocket(socket: ChatSocket) { - synchronized(sockets) { - sockets.remove(socket)?.close() + override fun removeSocket(socket: ChatSocket) { + synchronized(sockets) { + sockets.remove(socket)?.close() + } } - } - override fun addSocket(socket: ChatSocket, session: org.eclipse.jetty.websocket.api.Session) { - val user = getUser(session) - if (!ApplicationServices.authorizationManager.isAuthorized( - applicationClass = applicationClass, - user = user, - operationType = OperationType.Read - ) - ) throw IllegalArgumentException("Unauthorized") - synchronized(sockets) { - sockets[socket] = session + override fun addSocket(socket: ChatSocket, session: org.eclipse.jetty.websocket.api.Session) { + val user = getUser(session) + if (!ApplicationServices.authorizationManager.isAuthorized( + applicationClass = applicationClass, + user = user, + operationType = OperationType.Read + ) + ) throw IllegalArgumentException("Unauthorized") + synchronized(sockets) { + sockets[socket] = session + } } - } - private fun publish( - out: String, - ) { - synchronized(sockets) { - sockets.keys.forEach { chatSocket -> - try { - sendQueues.computeIfAbsent(chatSocket) { ConcurrentLinkedDeque() }.add(out) - } catch (e: Exception) { - log.info("Error sending message", e) - } - pool.submit { - try { - val deque = sendQueues[chatSocket]!! - synchronized(deque) { - while (true) { - val msg = deque.poll() ?: break - chatSocket.remote.sendString(msg) - } - chatSocket.remote.flush() + private fun publish( + out: String, + ) { + synchronized(sockets) { + sockets.keys.forEach { chatSocket -> + try { + sendQueues.computeIfAbsent(chatSocket) { ConcurrentLinkedDeque() }.add(out) + } catch (e: Exception) { + log.info("Error sending message", e) + } + pool.submit { + try { + val deque = sendQueues[chatSocket]!! + synchronized(deque) { + while (true) { + val msg = deque.poll() ?: break + chatSocket.remote.sendString(msg) + } + chatSocket.remote.flush() + } + } catch (e: Exception) { + log.info("Error sending message", e) + } + } } - } catch (e: Exception) { - log.info("Error sending message", e) - } } - } } - } - fun newTask( - cancelable: Boolean = false, - root : Boolean = true - ): SessionTask { - val operationID = randomID(root) - var responseContents = divInitializer(operationID, cancelable) - send(responseContents) - return SessionTaskImpl(operationID, responseContents, SessionTask.spinner) - } + fun newTask( + cancelable: Boolean = false, + root: Boolean = true + ): SessionTask { + val operationID = randomID(root) + var responseContents = divInitializer(operationID, cancelable) + send(responseContents) + return SessionTaskImpl(operationID, responseContents, SessionTask.spinner) + } - inner class SessionTaskImpl( - operationID: String, - responseContents: String, - spinner: String = SessionTask.spinner, - private val buffer: MutableList = mutableListOf(StringBuilder(responseContents)) - ) : SessionTask( - operationID = operationID, buffer = buffer, spinner = spinner - ) { + inner class SessionTaskImpl( + operationID: String, + responseContents: String, + spinner: String = SessionTask.spinner, + private val buffer: MutableList = mutableListOf(StringBuilder(responseContents)) + ) : SessionTask( + operationID = operationID, buffer = buffer, spinner = spinner + ) { - override fun send(html: String) = this@SocketManagerBase.send(html) - override fun saveFile(relativePath: String, data: ByteArray): String { - dataStorage?.getSessionDir(owner, session)?.let { dir -> - dir.mkdirs() - val resolve = dir.resolve(relativePath) - resolve.parentFile.mkdirs() - resolve.writeBytes(data) - } - return "fileIndex/$session/$relativePath" + override fun send(html: String) = this@SocketManagerBase.send(html) + override fun saveFile(relativePath: String, data: ByteArray): String { + dataStorage?.getSessionDir(owner, session)?.let { dir -> + dir.mkdirs() + val resolve = dir.resolve(relativePath) + resolve.parentFile.mkdirs() + resolve.writeBytes(data) + } + return "fileIndex/$session/$relativePath" + } } - } - fun send(out: String) { - try { - val split = out.split(',', ignoreCase = false, limit = 2) - val messageID = split[0] - val newValue = split[1] - if (setMessage(messageID, newValue) < 0) { - log.debug("Skipping duplicate message - Key: {}, Value: {} bytes", messageID, newValue.length) - return - } - if (sendQueue.contains(messageID)) { - log.debug("Skipping already queued message - Key: {}, Value: {} bytes", messageID, newValue.length) - return - } - if(0==out.length) { - log.debug("Skipping empty message - Key: {}, Value: {} bytes", messageID, newValue.length) - return - } - log.debug("Queue Send Msg: {} - {} - {} bytes", session, messageID, out.length) - sendQueue.add(messageID) - scheduledThreadPoolExecutor.schedule( - { - try { - while (sendQueue.isNotEmpty()) { - val messageID = sendQueue.poll() ?: return@schedule - val ver = messageVersions[messageID]?.get() - val v = messageStates[messageID] - log.debug("Wire Send Msg: {} - {} - {} - {} bytes", session, messageID, ver, v?.length) - publish(messageID + "," + ver + "," + v) + fun send(out: String) { + try { + val split = out.split(',', ignoreCase = false, limit = 2) + val messageID = split[0] + val newValue = split[1] + if (setMessage(messageID, newValue) < 0) { + log.debug("Skipping duplicate message - Key: {}, Value: {} bytes", messageID, newValue.length) + return + } + if (sendQueue.contains(messageID)) { + log.debug("Skipping already queued message - Key: {}, Value: {} bytes", messageID, newValue.length) + return } - } catch (e: Exception) { + if (0 == out.length) { + log.debug("Skipping empty message - Key: {}, Value: {} bytes", messageID, newValue.length) + return + } + log.debug("Queue Send Msg: {} - {} - {} bytes", session, messageID, out.length) + sendQueue.add(messageID) + scheduledThreadPoolExecutor.schedule( + { + try { + while (sendQueue.isNotEmpty()) { + val messageID = sendQueue.poll() ?: return@schedule + val ver = messageVersions[messageID]?.get() + val v = messageStates[messageID] + log.debug("Wire Send Msg: {} - {} - {} - {} bytes", session, messageID, ver, v?.length) + publish(messageID + "," + ver + "," + v) + } + } catch (e: Exception) { + log.debug("$session - $out", e) + } + }, + 50, java.util.concurrent.TimeUnit.MILLISECONDS + ) + } catch (e: Exception) { log.debug("$session - $out", e) - } - }, - 50, java.util.concurrent.TimeUnit.MILLISECONDS - ) - } catch (e: Exception) { - log.debug("$session - $out", e) + } } - } - final override fun getReplay(): List { - return messageStates.entries.map { - "${it.key},${messageVersions.computeIfAbsent(it.key) { AtomicInteger(1) }.get()},${it.value}" + final override fun getReplay(): List { + return messageStates.entries.map { + "${it.key},${messageVersions.computeIfAbsent(it.key) { AtomicInteger(1) }.get()},${it.value}" + } } - } - private fun setMessage(key: String, value: String): Int { - if (messageStates.containsKey(key)) { - if (messageStates[key] == value) { - return -1 - } + private fun setMessage(key: String, value: String): Int { + if (messageStates.containsKey(key)) { + if (messageStates[key] == value) { + return -1 + } + } + dataStorage?.updateMessage(owner, session, key, value) + messageStates.put(key, value) + val incrementAndGet = synchronized(messageVersions) + { messageVersions.getOrPut(key) { AtomicInteger(0) } }.incrementAndGet() + log.debug("Setting message - Key: {}, v{}, Value: {} bytes", key, incrementAndGet, value.length) + return incrementAndGet } - dataStorage?.updateMessage(owner, session, key, value) - messageStates.put(key, value) - val incrementAndGet = synchronized(messageVersions) - { messageVersions.getOrPut(key) { AtomicInteger(0) } }.incrementAndGet() - log.debug("Setting message - Key: {}, v{}, Value: {} bytes", key, incrementAndGet, value.length) - return incrementAndGet - } - final override fun onWebSocketText(socket: ChatSocket, message: String) { - if (canWrite(socket.user)) pool.submit { - log.debug("{} - Received message: {}", session, message) - try { - val opCmdPattern = """![a-z]{3,7},.*""".toRegex() - if (opCmdPattern.matches(message)) { - val id = message.substring(1, message.indexOf(",")) - val code = message.substring(id.length + 2) - onCmd(id, code) + final override fun onWebSocketText(socket: ChatSocket, message: String) { + if (canWrite(socket.user)) pool.submit { + log.debug("{} - Received message: {}", session, message) + try { + val opCmdPattern = """![a-z]{3,7},.*""".toRegex() + if (opCmdPattern.matches(message)) { + val id = message.substring(1, message.indexOf(",")) + val code = message.substring(id.length + 2) + onCmd(id, code) + } else { + onRun(message, socket) + } + } catch (e: Throwable) { + log.error("$session - Error processing message: $message", e) + send("""${randomID()},
${MarkdownUtil.renderMarkdown(e.message ?: "")}
""") + } } else { - onRun(message, socket) + log.warn("$session - Unauthorized message: $message") + send("""${randomID()},
Unauthorized message
""") } - } catch (e: Throwable) { - log.error("$session - Error processing message: $message", e) - send("""${randomID()},
${MarkdownUtil.renderMarkdown(e.message ?: "")}
""") - } - } else { - log.warn("$session - Unauthorized message: $message") - send("""${randomID()},
Unauthorized message
""") } - } - open fun canWrite(user: User?) = ApplicationServices.authorizationManager.isAuthorized( - applicationClass = applicationClass, - user = user, - operationType = OperationType.Write - ) + open fun canWrite(user: User?) = ApplicationServices.authorizationManager.isAuthorized( + applicationClass = applicationClass, + user = user, + operationType = OperationType.Write + ) - private val linkTriggers = mutableMapOf>() - private val txtTriggers = mutableMapOf>() - private fun onCmd(id: String, code: String) { - log.debug("Processing command - ID: {}, Code: {}", id, code) - if (code == "link") { - val consumer = linkTriggers[id] - consumer ?: throw IllegalArgumentException("No link handler found") - consumer.accept(Unit) - } else if (code.startsWith("userTxt,")) { - val consumer = txtTriggers[id] - consumer ?: throw IllegalArgumentException("No input handler found") - val text = code.substringAfter("userTxt,") - val unencoded = URLDecoder.decode(text, "UTF-8") - consumer.accept(unencoded) - } else { - throw IllegalArgumentException("Unknown command: $code") + private val linkTriggers = mutableMapOf>() + private val txtTriggers = mutableMapOf>() + private fun onCmd(id: String, code: String) { + log.debug("Processing command - ID: {}, Code: {}", id, code) + if (code == "link") { + val consumer = linkTriggers[id] + consumer ?: throw IllegalArgumentException("No link handler found") + consumer.accept(Unit) + } else if (code.startsWith("userTxt,")) { + val consumer = txtTriggers[id] + consumer ?: throw IllegalArgumentException("No input handler found") + val text = code.substringAfter("userTxt,") + val unencoded = URLDecoder.decode(text, "UTF-8") + consumer.accept(unencoded) + } else { + throw IllegalArgumentException("Unknown command: $code") + } } - } - fun hrefLink( - linkText: String, - classname: String = "href-link", - id: String? = null, - handler: Consumer - ): String { - val operationID = randomID() - linkTriggers[operationID] = handler - return """ """ id="$id"""" - else -> "" - } - }>$linkText""" - } + fun hrefLink( + linkText: String, + classname: String = "href-link", + id: String? = null, + handler: Consumer + ): String { + val operationID = randomID() + linkTriggers[operationID] = handler + return """ """ id="$id"""" + else -> "" + } + }>$linkText""" + } - fun textInput(handler: Consumer): String { - val operationID = randomID() - txtTriggers[operationID] = handler - //language=HTML - return """
+ fun textInput(handler: Consumer): String { + val operationID = randomID() + txtTriggers[operationID] = handler + //language=HTML + return """
""".trimIndent() - } + } - protected abstract fun onRun( - userMessage: String, - socket: ChatSocket, - ) + protected abstract fun onRun( + userMessage: String, + socket: ChatSocket, + ) - companion object { - private val log = LoggerFactory.getLogger(ChatServer::class.java) + companion object { + private val log = LoggerFactory.getLogger(ChatServer::class.java) - private val range1 = ('a'..'y').toList().toTypedArray() - private val range2 = range1 + 'z' - fun randomID(root : Boolean = true): String { - val random = java.util.Random() - val joinToString = (if (root) range1[random.nextInt(range1.size)] else "z").toString() + - (0..4).map { range2[random.nextInt(range2.size)] }.joinToString("") - return joinToString - } + private val range1 = ('a'..'y').toList().toTypedArray() + private val range2 = range1 + 'z' + fun randomID(root: Boolean = true): String { + val random = Random() + val joinToString = (if (root) range1[random.nextInt(range1.size)] else "z").toString() + + (0..4).map { range2[random.nextInt(range2.size)] }.joinToString("") + return joinToString + } - fun divInitializer(operationID: String = randomID(), cancelable: Boolean): String = - if (!cancelable) """$operationID,""" else - """$operationID,""" + fun divInitializer(operationID: String = randomID(), cancelable: Boolean): String = + if (!cancelable) """$operationID,""" else + """$operationID,""" - fun getUser(session: org.eclipse.jetty.websocket.api.Session): User? = - session.upgradeRequest.cookies?.find { it.name == AuthenticationInterface.AUTH_COOKIE }?.value.let { - ApplicationServices.authenticationManager.getUser(it) - } + fun getUser(session: org.eclipse.jetty.websocket.api.Session): User? = + session.upgradeRequest.cookies?.find { it.name == AuthenticationInterface.AUTH_COOKIE }?.value.let { + ApplicationServices.authenticationManager.getUser(it) + } - val scheduledThreadPoolExecutor = java.util.concurrent.Executors.newScheduledThreadPool(1) - } + val scheduledThreadPoolExecutor = java.util.concurrent.Executors.newScheduledThreadPool(1) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/CodingActorTestApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/CodingActorTestApp.kt index dc0e8219..f455775d 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/CodingActorTestApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/CodingActorTestApp.kt @@ -20,8 +20,8 @@ open class CodingActorTestApp( applicationName: String = "CodingActorTest_" + actor.name, temperature: Double = 0.3, ) : ApplicationServer( - applicationName = applicationName, - path = "/codingActorTest", + applicationName = applicationName, + path = "/codingActorTest", ) { override fun userMessage( session: Session, @@ -30,26 +30,26 @@ open class CodingActorTestApp( ui: ApplicationInterface, api: API ) { - (api as ClientManager.MonitoredClient).budget = 2.00 + (api as ClientManager.MonitoredClient).budget = 2.00 val message = ui.newTask() try { - message.echo(renderMarkdown(userMessage, ui=ui)) + message.echo(renderMarkdown(userMessage, ui = ui)) val response = actor.answer(CodingActor.CodeRequest(listOf(userMessage to ApiModel.Role.user)), api = api) val canPlay = ApplicationServices.authorizationManager.isAuthorized( this::class.java, user, - OperationType.Execute + OperationType.Execute ) val playLink = if (!canPlay) "" else { - ui.hrefLink("▶", "href-link play-button"){ - message.add("Running...") - val result = response.result - message.complete( - """ + ui.hrefLink("▶", "href-link play-button") { + message.add("Running...") + val result = response.result + message.complete( + """ |
${result.resultValue}
|
${result.resultOutput}
""".trimMargin() - ) + ) } } message.complete( @@ -59,11 +59,10 @@ open class CodingActorTestApp( |${/*escapeHtml4*/(response.code)/*.indent(" ")*/} |``` |$playLink - """.trimMargin().trim(), ui=ui + """.trimMargin().trim(), ui = ui ) ) - } - catch (e: Throwable) { + } catch (e: Throwable) { log.warn("Error", e) message.error(ui, e) } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/ImageActorTestApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/ImageActorTestApp.kt index 94d7b0b0..ec471316 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/ImageActorTestApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/ImageActorTestApp.kt @@ -15,15 +15,17 @@ open class ImageActorTestApp( applicationName: String = "ImageActorTest_" + actor.javaClass.simpleName, temperature: Double = 0.3, ) : ApplicationServer( - applicationName = applicationName, + applicationName = applicationName, path = "/imageActorTest", ) { data class Settings( val actor: ImageActor? = null, ) + override val settingsClass: Class<*> get() = Settings::class.java - @Suppress("UNCHECKED_CAST") override fun initSettings(session: Session): T? = Settings(actor=actor) as T + @Suppress("UNCHECKED_CAST") + override fun initSettings(session: Session): T? = Settings(actor = actor) as T override fun userMessage( session: Session, @@ -36,9 +38,10 @@ open class ImageActorTestApp( val message = ui.newTask() try { val actor = getSettings(session, user)?.actor ?: actor - message.echo(renderMarkdown(userMessage, ui=ui)) + message.echo(renderMarkdown(userMessage, ui = ui)) val response = actor.answer( - listOf(userMessage), api = api) + listOf(userMessage), api = api + ) message.verbose(response.text) message.image(response.image) message.complete() diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/ParsedActorTestApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/ParsedActorTestApp.kt index 3a207db7..b6dd8ba7 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/ParsedActorTestApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/ParsedActorTestApp.kt @@ -16,7 +16,7 @@ open class ParsedActorTestApp( applicationName: String = "ParsedActorTest_" + actor.resultClass?.simpleName, temperature: Double = 0.3, ) : ApplicationServer( - applicationName = applicationName, + applicationName = applicationName, path = "/parsedActorTest", ) { override fun userMessage( @@ -29,7 +29,7 @@ open class ParsedActorTestApp( (api as ClientManager.MonitoredClient).budget = 2.00 val message = ui.newTask() try { - message.echo(renderMarkdown(userMessage, ui=ui)) + message.echo(renderMarkdown(userMessage, ui = ui)) val response = actor.answer(listOf(userMessage), api = api) message.complete( renderMarkdown( @@ -38,8 +38,8 @@ open class ParsedActorTestApp( |``` |${JsonUtil.toJson(response.obj)} |``` - """.trimMargin().trim(), ui=ui - ) + """.trimMargin().trim(), ui = ui + ) ) } catch (e: Throwable) { log.warn("Error", e) diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/SimpleActorTestApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/SimpleActorTestApp.kt index 021ec0c4..7e4bbcd0 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/SimpleActorTestApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/SimpleActorTestApp.kt @@ -15,15 +15,17 @@ open class SimpleActorTestApp( applicationName: String = "SimpleActorTest_" + actor.javaClass.simpleName, temperature: Double = 0.3, ) : ApplicationServer( - applicationName = applicationName, + applicationName = applicationName, path = "/simpleActorTest", ) { data class Settings( val actor: SimpleActor? = null, ) + override val settingsClass: Class<*> get() = Settings::class.java - @Suppress("UNCHECKED_CAST") override fun initSettings(session: Session): T? = Settings(actor=actor) as T + @Suppress("UNCHECKED_CAST") + override fun initSettings(session: Session): T? = Settings(actor = actor) as T override fun userMessage( session: Session, @@ -36,9 +38,9 @@ open class SimpleActorTestApp( val message = ui.newTask() try { val actor = getSettings(session, user)?.actor ?: actor - message.echo(renderMarkdown(userMessage, ui=ui)) + message.echo(renderMarkdown(userMessage, ui = ui)) val response = actor.answer(listOf(userMessage), api = api) - message.complete(renderMarkdown(response, ui=ui)) + message.complete(renderMarkdown(response, ui = ui)) } catch (e: Throwable) { log.warn("Error", e) message.error(ui, e) diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/EncryptFiles.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/EncryptFiles.kt index 3055a7c4..c07fa65b 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/EncryptFiles.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/EncryptFiles.kt @@ -6,17 +6,17 @@ import java.nio.file.Paths object EncryptFiles { - @JvmStatic - fun main(args: Array) { - "".encrypt("arn:aws:kms:us-east-1:470240306861:key/a1340b89-64e6-480c-a44c-e7bc0c70dcb1") - .write("""C:\Users\andre\code\SkyenetApps\src\main\resources\patreon.json.kms""") - } + @JvmStatic + fun main(args: Array) { + "".encrypt("arn:aws:kms:us-east-1:470240306861:key/a1340b89-64e6-480c-a44c-e7bc0c70dcb1") + .write("""C:\Users\andre\code\SkyenetApps\src\main\resources\patreon.json.kms""") + } } fun String.write(outpath: String) { - Files.write(Paths.get(outpath), toByteArray()) + Files.write(Paths.get(outpath), toByteArray()) } fun String.encrypt(keyId: String) = ApplicationServices.cloud!!.encrypt(encodeToByteArray(), keyId) - ?: throw RuntimeException("Unable to encrypt data") \ No newline at end of file + ?: throw RuntimeException("Unable to encrypt data") \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/MarkdownUtil.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/MarkdownUtil.kt index b2bc468d..07a4bb68 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/MarkdownUtil.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/MarkdownUtil.kt @@ -28,12 +28,11 @@ object MarkdownUtil { matches.forEach { match -> var mermaidCode = match.groups[1]!!.value // HTML Decode mermaidCode - mermaidCode = mermaidCode val fixedMermaidCode = fixupMermaidCode(mermaidCode) var mermaidDiagramHTML = """
$fixedMermaidCode
""" try { val svg = renderMermaidToSVG(fixedMermaidCode) - if(null != ui) { + if (null != ui) { val newTask = ui.newTask(false) newTask.complete(svg) mermaidDiagramHTML = newTask.placeholder @@ -100,7 +99,7 @@ object MarkdownUtil { if (errorOutput.isNotEmpty()) { log.error("Mermaid CLI Error: $errorOutput") } - if(svgContent.isNullOrBlank()) { + if (svgContent.isNullOrBlank()) { throw RuntimeException("Mermaid CLI failed to generate SVG") } return svgContent @@ -110,6 +109,7 @@ object MarkdownUtil { enum class State { DEFAULT, IN_NODE, IN_EDGE, IN_LABEL, IN_KEYWORD } + fun fixupMermaidCode(code: String): String { val stringBuilder = StringBuilder() var index = 0 @@ -127,19 +127,20 @@ object MarkdownUtil { currentState = State.IN_KEYWORD stringBuilder.append(code[index]) } else - if (code[index] == '[' || code[index] == '(' || code[index] == '{') { - // Possible start of a label - currentState = State.IN_LABEL - labelStart = index - } else if (code[index].isWhitespace() || code[index] == '-') { - // Continue in default state, possibly an edge - stringBuilder.append(code[index]) - } else { - // Start of a node - currentState = State.IN_NODE - stringBuilder.append(code[index]) - } + if (code[index] == '[' || code[index] == '(' || code[index] == '{') { + // Possible start of a label + currentState = State.IN_LABEL + labelStart = index + } else if (code[index].isWhitespace() || code[index] == '-') { + // Continue in default state, possibly an edge + stringBuilder.append(code[index]) + } else { + // Start of a node + currentState = State.IN_NODE + stringBuilder.append(code[index]) + } } + State.IN_KEYWORD -> { if (code[index].isWhitespace()) { // End of a keyword @@ -147,6 +148,7 @@ object MarkdownUtil { } stringBuilder.append(code[index]) } + State.IN_NODE -> { if (code[index] == '-' || code[index] == '>' || code[index].isWhitespace()) { // End of a node, start of an edge or space @@ -157,6 +159,7 @@ object MarkdownUtil { stringBuilder.append(code[index]) } } + State.IN_EDGE -> { if (!code[index].isWhitespace() && code[index] != '-' && code[index] != '>') { // End of an edge, start of a node @@ -167,6 +170,7 @@ object MarkdownUtil { stringBuilder.append(code[index]) } } + State.IN_LABEL -> { if (code[index] == ']' || code[index] == ')' || code[index] == '}') { // End of a label @@ -183,6 +187,7 @@ object MarkdownUtil { return stringBuilder.toString() } + private fun defaultOptions(): MutableDataSet { val options = MutableDataSet() options.set(Parser.EXTENSIONS, listOf(TablesExtension.create())) diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/OpenAPI.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/OpenAPI.kt index f42de3d1..f7a3e99a 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/OpenAPI.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/OpenAPI.kt @@ -6,103 +6,102 @@ import com.fasterxml.jackson.module.kotlin.readValue // OpenAPI root document data class OpenAPI( - val openapi: String = "3.0.0", - val info: Info? = null, - val paths: Map? = emptyMap(), - @JsonInclude(JsonInclude.Include.NON_NULL) - val components: Components? = null -) { -} + val openapi: String = "3.0.0", + val info: Info? = null, + val paths: Map? = emptyMap(), + @JsonInclude(JsonInclude.Include.NON_NULL) + val components: Components? = null +) // Metadata about the API data class Info( - val title: String? = null, - val version: String? = null, - val description: String? = null, - val termsOfService: String? = null, - val contact: Contact? = null, - val license: License? = null + val title: String? = null, + val version: String? = null, + val description: String? = null, + val termsOfService: String? = null, + val contact: Contact? = null, + val license: License? = null ) // Contact information data class Contact( - val name: String? = null, - val url: String? = null, - val email: String? = null + val name: String? = null, + val url: String? = null, + val email: String? = null ) // License information data class License( - val name: String? = null, - val url: String? = null + val name: String? = null, + val url: String? = null ) // Paths and operations data class PathItem( - val get: Operation? = null, - val put: Operation? = null, - val post: Operation? = null, - val delete: Operation? = null, - val options: Operation? = null, - val head: Operation? = null, - val patch: Operation? = null + val get: Operation? = null, + val put: Operation? = null, + val post: Operation? = null, + val delete: Operation? = null, + val options: Operation? = null, + val head: Operation? = null, + val patch: Operation? = null ) // An API operation data class Operation( - val summary: String? = null, - val description: String? = null, - val responses: Map? = emptyMap(), - val parameters: List? = emptyList(), - val operationId: String? = null, - val requestBody: RequestBody? = null, - val security: List>>? = emptyList(), - val tags: List? = emptyList(), - val callbacks: Map? = emptyMap(), - val deprecated: Boolean? = null, + val summary: String? = null, + val description: String? = null, + val responses: Map? = emptyMap(), + val parameters: List? = emptyList(), + val operationId: String? = null, + val requestBody: RequestBody? = null, + val security: List>>? = emptyList(), + val tags: List? = emptyList(), + val callbacks: Map? = emptyMap(), + val deprecated: Boolean? = null, ) // Operation response data class Response( - val description: String? = null, - @JsonInclude(JsonInclude.Include.NON_NULL) - val content: Map? = emptyMap() + val description: String? = null, + @JsonInclude(JsonInclude.Include.NON_NULL) + val content: Map? = emptyMap() ) // Components for reusable objects data class Components( - val schemas: Map? = emptyMap(), - val responses: Map? = emptyMap(), - val parameters: Map? = emptyMap(), - val examples: Map? = emptyMap(), - val requestBodies: Map? = emptyMap(), - val headers: Map? = emptyMap(), - val securitySchemes: Map? = emptyMap(), - val links: Map? = emptyMap(), - val callbacks: Map? = emptyMap() -) { -} + val schemas: Map? = emptyMap(), + val responses: Map? = emptyMap(), + val parameters: Map? = emptyMap(), + val examples: Map? = emptyMap(), + val requestBodies: Map? = emptyMap(), + val headers: Map? = emptyMap(), + val securitySchemes: Map? = emptyMap(), + val links: Map? = emptyMap(), + val callbacks: Map? = emptyMap() +) // Simplified examples of component objects data class Schema( - val type: String? = null, - val properties: Map? = emptyMap(), - val items: Schema? = null, - val `$ref`: String? = null, - val format: String? = null, - val description: String? = null, + val type: String? = null, + val properties: Map? = emptyMap(), + val items: Schema? = null, + val `$ref`: String? = null, + val format: String? = null, + val description: String? = null, - ) + ) data class Parameter( - val name: String? = null, - val `in`: String? = null, - val description: String? = null, - val required: Boolean? = null, - val schema: Schema? = null, - val content: Map? = null, - val example: Any? = null, + val name: String? = null, + val `in`: String? = null, + val description: String? = null, + val required: Boolean? = null, + val schema: Schema? = null, + val content: Map? = null, + val example: Any? = null, ) + data class Example(val summary: String? = null, val description: String? = null) data class RequestBody(val description: String? = null, val content: Map? = null) data class Header(val description: String? = null) diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/Selenium2S3.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/Selenium2S3.kt index 7c6fed85..4623d8e9 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/Selenium2S3.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/Selenium2S3.kt @@ -28,436 +28,436 @@ import java.util.concurrent.Semaphore import java.util.concurrent.ThreadPoolExecutor open class Selenium2S3( - val pool: ThreadPoolExecutor = Executors.newCachedThreadPool() as ThreadPoolExecutor, - private val cookies: Array?, + val pool: ThreadPoolExecutor = Executors.newCachedThreadPool() as ThreadPoolExecutor, + private val cookies: Array?, ) : Selenium { - var loadImages: Boolean = false - open val driver: WebDriver by lazy { - chromeDriver(loadImages = loadImages).apply { - Companion.setCookies( - this, - cookies - ) - } - } - - private val httpClient by lazy { - HttpAsyncClientBuilder.create() - .useSystemProperties() - .setDefaultCookieStore(BasicCookieStore().apply { - cookies?.forEach { cookie -> addCookie(BasicClientCookie(cookie.name, cookie.value)) } - }) - .setThreadFactory(pool.threadFactory) - .build() - .also { it.start() } - } - - private val linkReplacements = mutableMapOf() - private val htmlPages: MutableMap = mutableMapOf() - private val jsonPages = mutableMapOf() - private val links: MutableList = mutableListOf() - - override fun save( - url: URL, - currentFilename: String?, - saveRoot: String - ) { - log.info("Saving URL: $url") - log.info("Current filename: $currentFilename") - log.info("Save root: $saveRoot") - driver.navigate().to(url) - driver.navigate().refresh() - Thread.sleep(5000) // Wait for javascript to load - - htmlPages += mutableMapOf((currentFilename ?: url.file.split("/").last()) to editPage(driver.pageSource)) - val baseUrl = url.toString().split("#").first() - links += toAbsolute(baseUrl, *currentPageLinks(driver).map { link -> - val relative = toRelative(baseUrl, link) ?: return@map link - linkReplacements[link] = "${cloud!!.shareBase}/$saveRoot/${toArchivePath(relative)}" - linkReplacements[relative] = "${cloud!!.shareBase}/$saveRoot/${toArchivePath(relative)}" - link - }.toTypedArray()).toMutableList() - val completionSemaphores = mutableListOf() - - log.info("Fetching page source") - log.info("Base URL: $baseUrl") - val coveredLinks = mutableSetOf() - log.info("Processing links") - while (links.isNotEmpty()) { - val href = links.removeFirst() - try { - if (coveredLinks.contains(href)) continue - coveredLinks += href - log.debug("Processing $href") - process(url, href, completionSemaphores, saveRoot) - } catch (e: Exception) { - log.warn("Error processing $href", e) - } + var loadImages: Boolean = false + open val driver: WebDriver by lazy { + chromeDriver(loadImages = loadImages).apply { + Companion.setCookies( + this, + cookies + ) + } } - log.info("Fetching current page links") - log.debug("Waiting for completion") - completionSemaphores.forEach { it.acquire(); it.release() } - - log.debug("Saving") - saveAll(saveRoot) - log.debug("Done") - } - - open protected fun process( - url: URL, - href: String, - completionSemaphores: MutableList, - saveRoot: String - ): Boolean { - val base = url.toString().split("/").dropLast(1).joinToString("/") - val relative = toArchivePath(toRelative(base, href) ?: return true) - when (val mimeType = mimeType(relative)) { - - "text/html" -> { - if (htmlPages.containsKey(relative)) return true - log.info("Fetching $href") - val semaphore = Semaphore(0) - completionSemaphores += semaphore - getHtml(href, htmlPages, relative, links, saveRoot, semaphore) - } - - "application/json" -> { - if (jsonPages.containsKey(relative)) return true - log.info("Fetching $href") - val semaphore = Semaphore(0) - completionSemaphores += semaphore - getJson(href, jsonPages, relative, semaphore) - } - - else -> { - val semaphore = Semaphore(0) - completionSemaphores += semaphore - getMedia(href, mimeType, saveRoot, relative, semaphore) - } + private val httpClient by lazy { + HttpAsyncClientBuilder.create() + .useSystemProperties() + .setDefaultCookieStore(BasicCookieStore().apply { + cookies?.forEach { cookie -> addCookie(BasicClientCookie(cookie.name, cookie.value)) } + }) + .setThreadFactory(pool.threadFactory) + .build() + .also { it.start() } } - return false - } - - open protected fun getHtml( - href: String, - htmlPages: MutableMap, - relative: String, - links: MutableList, - saveRoot: String, - semaphore: Semaphore - ) { - httpClient.execute(get(href), object : FutureCallback { - - override fun completed(p0: SimpleHttpResponse?) { - log.debug("Fetched $href") - val html = p0?.body?.bodyText ?: "" - htmlPages[relative] = html - links += toAbsolute(href, *currentPageLinks(html).map { link -> - val relative = toArchivePath(toRelative(href, link) ?: return@map link) - linkReplacements[link] = "${cloud!!.shareBase}/$saveRoot/$relative" - link - }.toTypedArray()) - semaphore.release() - } - - override fun failed(p0: java.lang.Exception?) { - log.info("Error fetching $href", p0) - semaphore.release() - } - - override fun cancelled() { - log.info("Cancelled fetching $href") - semaphore.release() - } - - }) - } - - open protected fun getJson( - href: String, - jsonPages: MutableMap, - relative: String, - semaphore: Semaphore - ) { - httpClient.execute(get(href), object : FutureCallback { - - override fun completed(p0: SimpleHttpResponse?) { - log.debug("Fetched $href") - jsonPages[relative] = p0?.body?.bodyText ?: "" - semaphore.release() - } - - override fun failed(p0: java.lang.Exception?) { - log.info("Error fetching $href", p0) - semaphore.release() - } - - override fun cancelled() { - log.info("Cancelled fetching $href") - semaphore.release() - } - - }) - } - - open protected fun getMedia( - href: String, - mimeType: String, - saveRoot: String, - relative: String, - semaphore: Semaphore - ) { - val request = get(href) - httpClient.execute(request, object : FutureCallback { - - override fun completed(p0: SimpleHttpResponse?) { - try { - log.debug("Fetched $request") - val bytes = p0?.body?.bodyBytes ?: return - if (validate(mimeType, p0.body.contentType.mimeType, bytes)) - cloud!!.upload( - path = "/$saveRoot/$relative", - contentType = mimeType, - bytes = bytes - ) - } finally { - semaphore.release() + + private val linkReplacements = mutableMapOf() + private val htmlPages: MutableMap = mutableMapOf() + private val jsonPages = mutableMapOf() + private val links: MutableList = mutableListOf() + + override fun save( + url: URL, + currentFilename: String?, + saveRoot: String + ) { + log.info("Saving URL: $url") + log.info("Current filename: $currentFilename") + log.info("Save root: $saveRoot") + driver.navigate().to(url) + driver.navigate().refresh() + Thread.sleep(5000) // Wait for javascript to load + + htmlPages += mutableMapOf((currentFilename ?: url.file.split("/").last()) to editPage(driver.pageSource)) + val baseUrl = url.toString().split("#").first() + links += toAbsolute(baseUrl, *currentPageLinks(driver).map { link -> + val relative = toRelative(baseUrl, link) ?: return@map link + linkReplacements[link] = "${cloud!!.shareBase}/$saveRoot/${toArchivePath(relative)}" + linkReplacements[relative] = "${cloud!!.shareBase}/$saveRoot/${toArchivePath(relative)}" + link + }.toTypedArray()).toMutableList() + val completionSemaphores = mutableListOf() + + log.info("Fetching page source") + log.info("Base URL: $baseUrl") + val coveredLinks = mutableSetOf() + log.info("Processing links") + while (links.isNotEmpty()) { + val href = links.removeFirst() + try { + if (coveredLinks.contains(href)) continue + coveredLinks += href + log.debug("Processing $href") + process(url, href, completionSemaphores, saveRoot) + } catch (e: Exception) { + log.warn("Error processing $href", e) + } } - } - - override fun failed(p0: java.lang.Exception?) { - log.info("Error fetching $href", p0) - semaphore.release() - } - - override fun cancelled() { - log.info("Cancelled fetching $href") - semaphore.release() - } - - }) - } - - private fun saveAll( - saveRoot: String - ) { - (htmlPages.map { (filename, html) -> - pool.submit { - try { - saveHTML(html, saveRoot, filename) - } catch (e: Exception) { - log.warn("Error processing $filename", e) + + log.info("Fetching current page links") + log.debug("Waiting for completion") + completionSemaphores.forEach { it.acquire(); it.release() } + + log.debug("Saving") + saveAll(saveRoot) + log.debug("Done") + } + + protected open fun process( + url: URL, + href: String, + completionSemaphores: MutableList, + saveRoot: String + ): Boolean { + val base = url.toString().split("/").dropLast(1).joinToString("/") + val relative = toArchivePath(toRelative(base, href) ?: return true) + when (val mimeType = mimeType(relative)) { + + "text/html" -> { + if (htmlPages.containsKey(relative)) return true + log.info("Fetching $href") + val semaphore = Semaphore(0) + completionSemaphores += semaphore + getHtml(href, htmlPages, relative, links, saveRoot, semaphore) + } + + "application/json" -> { + if (jsonPages.containsKey(relative)) return true + log.info("Fetching $href") + val semaphore = Semaphore(0) + completionSemaphores += semaphore + getJson(href, jsonPages, relative, semaphore) + } + + else -> { + val semaphore = Semaphore(0) + completionSemaphores += semaphore + getMedia(href, mimeType, saveRoot, relative, semaphore) + } } - } - } + jsonPages.map { (filename, js) -> - pool.submit { - try { - saveJS(js, saveRoot, filename) - } catch (e: Exception) { - log.warn("Error processing $filename", e) + return false + } + + protected open fun getHtml( + href: String, + htmlPages: MutableMap, + relative: String, + links: MutableList, + saveRoot: String, + semaphore: Semaphore + ) { + httpClient.execute(get(href), object : FutureCallback { + + override fun completed(p0: SimpleHttpResponse?) { + log.debug("Fetched $href") + val html = p0?.body?.bodyText ?: "" + htmlPages[relative] = html + links += toAbsolute(href, *currentPageLinks(html).map { link -> + val relative = toArchivePath(toRelative(href, link) ?: return@map link) + linkReplacements[link] = "${cloud!!.shareBase}/$saveRoot/$relative" + link + }.toTypedArray()) + semaphore.release() + } + + override fun failed(p0: java.lang.Exception?) { + log.info("Error fetching $href", p0) + semaphore.release() + } + + override fun cancelled() { + log.info("Cancelled fetching $href") + semaphore.release() + } + + }) + } + + protected open fun getJson( + href: String, + jsonPages: MutableMap, + relative: String, + semaphore: Semaphore + ) { + httpClient.execute(get(href), object : FutureCallback { + + override fun completed(p0: SimpleHttpResponse?) { + log.debug("Fetched $href") + jsonPages[relative] = p0?.body?.bodyText ?: "" + semaphore.release() + } + + override fun failed(p0: java.lang.Exception?) { + log.info("Error fetching $href", p0) + semaphore.release() + } + + override fun cancelled() { + log.info("Cancelled fetching $href") + semaphore.release() + } + + }) + } + + protected open fun getMedia( + href: String, + mimeType: String, + saveRoot: String, + relative: String, + semaphore: Semaphore + ) { + val request = get(href) + httpClient.execute(request, object : FutureCallback { + + override fun completed(p0: SimpleHttpResponse?) { + try { + log.debug("Fetched $request") + val bytes = p0?.body?.bodyBytes ?: return + if (validate(mimeType, p0.body.contentType.mimeType, bytes)) + cloud!!.upload( + path = "/$saveRoot/$relative", + contentType = mimeType, + bytes = bytes + ) + } finally { + semaphore.release() + } + } + + override fun failed(p0: java.lang.Exception?) { + log.info("Error fetching $href", p0) + semaphore.release() + } + + override fun cancelled() { + log.info("Cancelled fetching $href") + semaphore.release() + } + + }) + } + + private fun saveAll( + saveRoot: String + ) { + (htmlPages.map { (filename, html) -> + pool.submit { + try { + saveHTML(html, saveRoot, filename) + } catch (e: Exception) { + log.warn("Error processing $filename", e) + } + } + } + jsonPages.map { (filename, js) -> + pool.submit { + try { + saveJS(js, saveRoot, filename) + } catch (e: Exception) { + log.warn("Error processing $filename", e) + } + } + }).forEach { + try { + it.get() + } catch (e: Exception) { + log.warn("Error processing", e) + } } - } - }).forEach { - try { - it.get() - } catch (e: Exception) { - log.warn("Error processing", e) - } } - } - - open protected fun saveJS(js: String, saveRoot: String, filename: String) { - val finalJs = linkReplacements.toList().sortedBy { it.first.length } - .fold(js) { acc, (href, relative) -> //language=RegExp - acc.replace("""(? acc.replace("""(? - request.addHeader("Cookie", "${cookie.name}=${cookie.value}") + + protected open fun saveJS(js: String, saveRoot: String, filename: String) { + val finalJs = linkReplacements.toList().sortedBy { it.first.length } + .fold(js) { acc, (href, relative) -> //language=RegExp + acc.replace("""(? { it?.getAttribute("href") }.toSet(), - driver.findElements(By.xpath("//img[@src]")).map { it?.getAttribute("src") }.toSet(), - driver.findElements(By.xpath("//link[@href]")).map { it?.getAttribute("href") }.toSet(), - driver.findElements(By.xpath("//script[@src]")).map { it?.getAttribute("src") }.toSet(), - driver.findElements(By.xpath("//source[@src]")).map { it?.getAttribute("src") }.toSet(), - ).flatten().filterNotNull() - - private fun currentPageLinks(html: String) = listOf( - Jsoup.parse(html).select("a[href]").map { it.attr("href") }.toSet(), - Jsoup.parse(html).select("img[src]").map { it.attr("src") }.toSet(), - Jsoup.parse(html).select("link[href]").map { it.attr("href") }.toSet(), - Jsoup.parse(html).select("script[src]").map { it.attr("src") }.toSet(), - Jsoup.parse(html).select("source[src]").map { it.attr("src") }.toSet(), - ).flatten().filterNotNull() - - protected open fun toAbsolute(base: String, vararg links: String) = links - .map { it.split("#").first() }.filter { it.isNotBlank() }.distinct() - .map { link -> - val newLink = when { - link.startsWith("http") -> link - else -> URI.create(base).resolve(link).toString() - } - newLink + + protected open fun saveHTML(html: String, saveRoot: String, filename: String) { + val finalHtml = linkReplacements.toList().filter { it.first.isNotEmpty() }.fold(html) + { acc, (href, relative) -> acc.replace("""(? toRelative( - base, - link.removePrefix(base).replace("/{2,}".toRegex(), "/").removePrefix("/") - ) // relativize - link.startsWith("http") -> null // absolute - else -> link // relative - } - - protected open fun toArchivePath(link: String): String = when { - link.startsWith("fileIndex") -> link.split("/").drop(2).joinToString("/") // rm file segment - else -> link - } - - protected open fun validate( - expected: String, - actual: String, - bytes: ByteArray - ): Boolean { - if (!actual.startsWith(expected)) { - log.warn("Content type mismatch: $actual != $expected") - if (actual.startsWith("text/html")) { - log.warn("Response Error: ${String(bytes)}", Exception()) - } - return false + protected open fun get(href: String): SimpleHttpRequest { + val request = SimpleHttpRequest(Method.GET, URI(href)) + cookies?.forEach { cookie -> + request.addHeader("Cookie", "${cookie.name}=${cookie.value}") + } + return request } - return true - } - - protected open fun mimeType(relative: String): String { - val extension = relative.split(".").last().split("?").first() - val contentType = when (extension) { - "css" -> "text/css" - "js" -> "text/javascript" - "json" -> "application/json" - "pdf" -> "application/pdf" - "zip" -> "application/zip" - "tar" -> "application/x-tar" - "gz" -> "application/gzip" - "bz2" -> "application/bzip2" - "mp3" -> "audio/mpeg" - //"tsv" -> "text/tab-separated-values" - "csv" -> "text/csv" - "txt" -> "text/plain" - "xml" -> "text/xml" - "svg" -> "image/svg+xml" - "png" -> "image/png" - "jpg" -> "image/jpeg" - "jpeg" -> "image/jpeg" - "gif" -> "image/gif" - "ico" -> "image/x-icon" - "html" -> "text/html" - "htm" -> "text/html" - else -> "text/plain" + + protected open fun currentPageLinks(driver: WebDriver) = listOf( + driver.findElements(By.xpath("//a[@href]")).map { it?.getAttribute("href") }.toSet(), + driver.findElements(By.xpath("//img[@src]")).map { it?.getAttribute("src") }.toSet(), + driver.findElements(By.xpath("//link[@href]")).map { it?.getAttribute("href") }.toSet(), + driver.findElements(By.xpath("//script[@src]")).map { it?.getAttribute("src") }.toSet(), + driver.findElements(By.xpath("//source[@src]")).map { it?.getAttribute("src") }.toSet(), + ).flatten().filterNotNull() + + private fun currentPageLinks(html: String) = listOf( + Jsoup.parse(html).select("a[href]").map { it.attr("href") }.toSet(), + Jsoup.parse(html).select("img[src]").map { it.attr("src") }.toSet(), + Jsoup.parse(html).select("link[href]").map { it.attr("href") }.toSet(), + Jsoup.parse(html).select("script[src]").map { it.attr("src") }.toSet(), + Jsoup.parse(html).select("source[src]").map { it.attr("src") }.toSet(), + ).flatten().filterNotNull() + + protected open fun toAbsolute(base: String, vararg links: String) = links + .map { it.split("#").first() }.filter { it.isNotBlank() }.distinct() + .map { link -> + val newLink = when { + link.startsWith("http") -> link + else -> URI.create(base).resolve(link).toString() + } + newLink + } + + protected open fun toRelative(base: String, link: String): String? = when { + link.startsWith(base) -> toRelative( + base, + link.removePrefix(base).replace("/{2,}".toRegex(), "/").removePrefix("/") + ) // relativize + link.startsWith("http") -> null // absolute + else -> link // relative + } + + protected open fun toArchivePath(link: String): String = when { + link.startsWith("fileIndex") -> link.split("/").drop(2).joinToString("/") // rm file segment + else -> link } - return contentType - } - - protected open fun editPage(html: String): String { - val doc = org.jsoup.Jsoup.parse(html) - doc.select("#toolbar").remove() - doc.select("#namebar").remove() - doc.select("#main-input").remove() - doc.select("#footer").remove() - return doc.toString() - } - - override fun close() { - log.debug("Closing", Exception()) - driver.quit() - httpClient.close() - //driver.close() - //Companion.chromeDriverService.close() - } - - - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(Selenium2S3::class.java) - - init { - Runtime.getRuntime().addShutdownHook(Thread { - try { - } catch (e: Exception) { - log.warn("Error closing com.simiacryptus.skyenet.webui.util.Selenium2S3", e) + + protected open fun validate( + expected: String, + actual: String, + bytes: ByteArray + ): Boolean { + if (!actual.startsWith(expected)) { + log.warn("Content type mismatch: $actual != $expected") + if (actual.startsWith("text/html")) { + log.warn("Response Error: ${String(bytes)}", Exception()) + } + return false } - }) + return true } - fun chromeDriver(headless: Boolean = true, loadImages: Boolean = !headless): ChromeDriver { - val osname = System.getProperty("os.name") - val chromePath = when { - // Windows - osname.contains("Windows") -> listOf( - "C:\\Program Files\\Google\\Chrome\\Application\\chromedriver.exe", - "C:\\Program Files (x86)\\Google\\Chrome\\Application\\chromedriver.exe" - ) - // Ubuntu - osname.contains("Linux") -> listOf("/usr/bin/chromedriver") - else -> throw RuntimeException("Not implemented for $osname") - } - System.setProperty("webdriver.chrome.driver", - chromePath.find { File(it).exists() } ?: throw RuntimeException("Chrome not found")) - val options = ChromeOptions() - val args = mutableListOf() - if (headless) args += "--headless" - if (loadImages) args += "--blink-settings=imagesEnabled=false" - options.addArguments(*args.toTypedArray()) - options.setPageLoadTimeout(Duration.of(90, ChronoUnit.SECONDS)) - return ChromeDriver(chromeDriverService, options) + protected open fun mimeType(relative: String): String { + val extension = relative.split(".").last().split("?").first() + val contentType = when (extension) { + "css" -> "text/css" + "js" -> "text/javascript" + "json" -> "application/json" + "pdf" -> "application/pdf" + "zip" -> "application/zip" + "tar" -> "application/x-tar" + "gz" -> "application/gzip" + "bz2" -> "application/bzip2" + "mp3" -> "audio/mpeg" + //"tsv" -> "text/tab-separated-values" + "csv" -> "text/csv" + "txt" -> "text/plain" + "xml" -> "text/xml" + "svg" -> "image/svg+xml" + "png" -> "image/png" + "jpg" -> "image/jpeg" + "jpeg" -> "image/jpeg" + "gif" -> "image/gif" + "ico" -> "image/x-icon" + "html" -> "text/html" + "htm" -> "text/html" + else -> "text/plain" + } + return contentType } - private val chromeDriverService by lazy { ChromeDriverService.createDefaultService() } - fun setCookies( - driver: WebDriver, - cookies: Array?, - domain: String? = null - ) { - cookies?.forEach { cookie -> - try { - driver.manage().addCookie( - Cookie( - /* name = */ cookie.name, - /* value = */ cookie.value, - /* domain = */ cookie.domain ?: domain, - /* path = */ cookie.path, - /* expiry = */ if (cookie.maxAge <= 0) null else Date(cookie.maxAge * 1000L), - /* isSecure = */ cookie.secure, - /* isHttpOnly = */ cookie.isHttpOnly - ) - ) - } catch (e: Exception) { - log.warn("Error setting cookie: $cookie", e) + protected open fun editPage(html: String): String { + val doc = Jsoup.parse(html) + doc.select("#toolbar").remove() + doc.select("#namebar").remove() + doc.select("#main-input").remove() + doc.select("#footer").remove() + return doc.toString() + } + + override fun close() { + log.debug("Closing", Exception()) + driver.quit() + httpClient.close() + //driver.close() + //Companion.chromeDriverService.close() + } + + + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(Selenium2S3::class.java) + + init { + Runtime.getRuntime().addShutdownHook(Thread { + try { + } catch (e: Exception) { + log.warn("Error closing com.simiacryptus.skyenet.webui.util.Selenium2S3", e) + } + }) + } + + fun chromeDriver(headless: Boolean = true, loadImages: Boolean = !headless): ChromeDriver { + val osname = System.getProperty("os.name") + val chromePath = when { + // Windows + osname.contains("Windows") -> listOf( + "C:\\Program Files\\Google\\Chrome\\Application\\chromedriver.exe", + "C:\\Program Files (x86)\\Google\\Chrome\\Application\\chromedriver.exe" + ) + // Ubuntu + osname.contains("Linux") -> listOf("/usr/bin/chromedriver") + else -> throw RuntimeException("Not implemented for $osname") + } + System.setProperty("webdriver.chrome.driver", + chromePath.find { File(it).exists() } ?: throw RuntimeException("Chrome not found")) + val options = ChromeOptions() + val args = mutableListOf() + if (headless) args += "--headless" + if (loadImages) args += "--blink-settings=imagesEnabled=false" + options.addArguments(*args.toTypedArray()) + options.setPageLoadTimeout(Duration.of(90, ChronoUnit.SECONDS)) + return ChromeDriver(chromeDriverService, options) + } + + private val chromeDriverService by lazy { ChromeDriverService.createDefaultService() } + fun setCookies( + driver: WebDriver, + cookies: Array?, + domain: String? = null + ) { + cookies?.forEach { cookie -> + try { + driver.manage().addCookie( + Cookie( + /* name = */ cookie.name, + /* value = */ cookie.value, + /* domain = */ cookie.domain ?: domain, + /* path = */ cookie.path, + /* expiry = */ if (cookie.maxAge <= 0) null else Date(cookie.maxAge * 1000L), + /* isSecure = */ cookie.secure, + /* isHttpOnly = */ cookie.isHttpOnly + ) + ) + } catch (e: Exception) { + log.warn("Error setting cookie: $cookie", e) + } + } } - } } - } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/TensorflowProjector.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/TensorflowProjector.kt index 1a8d9943..53ed506d 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/TensorflowProjector.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/TensorflowProjector.kt @@ -11,16 +11,15 @@ import com.simiacryptus.skyenet.core.platform.User import com.simiacryptus.skyenet.webui.application.ApplicationInterface class TensorflowProjector( - val api: API, - val dataStorage: StorageInterface, - val sessionID: Session, - val host: String, - val session: ApplicationInterface, - val userId: User?, + val api: API, + val dataStorage: StorageInterface, + val sessionID: Session, + val session: ApplicationInterface, + val userId: User?, ) { private fun toVectorMap(vararg words: String): Map { - val vectors = words.map {word -> + val vectors = words.map { word -> word to (api as OpenAIClient).createEmbedding( com.simiacryptus.jopenai.ApiModel.EmbeddingRequest( model = EmbeddingModels.AdaEmbedding.modelName, diff --git a/webui/src/test/kotlin/com/github/simiacryptus/diff/DiffUtilTest.kt b/webui/src/test/kotlin/com/github/simiacryptus/diff/DiffUtilTest.kt index b0c87fd4..fd7858c8 100644 --- a/webui/src/test/kotlin/com/github/simiacryptus/diff/DiffUtilTest.kt +++ b/webui/src/test/kotlin/com/github/simiacryptus/diff/DiffUtilTest.kt @@ -5,125 +5,67 @@ import org.junit.jupiter.api.Test class DiffUtilTest { -/* - @Test - fun testNoChanges() { - val original = listOf("line1", "line2", "line3") - val modified = listOf("line1", "line2", "line3") - val diffResults = DiffUtil.generateDiff(original, modified) - val formattedDiff = DiffUtil.formatDiff(diffResults) - assertEquals("", formattedDiff, "There should be no diff for identical inputs.") - } -*/ + /* + @Test + fun testNoChanges() { + val original = listOf("line1", "line2", "line3") + val modified = listOf("line1", "line2", "line3") + val diffResults = DiffUtil.generateDiff(original, modified) + val formattedDiff = DiffUtil.formatDiff(diffResults) + assertEquals("", formattedDiff, "There should be no diff for identical inputs.") + } + */ -// @Test - fun testAdditions() { - val original = listOf("line1", "line2") - val modified = listOf("line1", "line2", "line3") - val diffResults = DiffUtil.generateDiff(original, modified) - val formattedDiff = DiffUtil.formatDiff(diffResults) - val expectedDiff = """ - line1 - line2 - + line3 - """.trimIndent() - Assertions.assertEquals(expectedDiff, formattedDiff, "The diff should correctly represent an addition.") - } - -// @Test - fun testDeletions() { - val original = listOf("line1", "line2", "line3") - val modified = listOf("line1", "line3") - val diffResults = DiffUtil.generateDiff(original, modified) - val formattedDiff = DiffUtil.formatDiff(diffResults) - val expectedDiff = """ - line1 - - line2 - line3 - """.trimIndent() - Assertions.assertEquals(expectedDiff, formattedDiff, "The diff should correctly represent a deletion.") - } - - @Test - fun testMixedChanges() { - val original = listOf("line1", "line2", "line4") - val modified = listOf("line1", "line3", "line4") - val diffResults = DiffUtil.generateDiff(original, modified) - val formattedDiff = DiffUtil.formatDiff(diffResults) - val expectedDiff = """ + @Test + fun testMixedChanges() { + val original = listOf("line1", "line2", "line4") + val modified = listOf("line1", "line3", "line4") + val diffResults = DiffUtil.generateDiff(original, modified) + val formattedDiff = DiffUtil.formatDiff(diffResults) + val expectedDiff = """ line1 - line2 + line3 line4 """.trimIndent() - Assertions.assertEquals(expectedDiff, formattedDiff, "The diff should correctly represent mixed changes.") - } + Assertions.assertEquals(expectedDiff, formattedDiff, "The diff should correctly represent mixed changes.") + } -// @Test - fun testContextLines() { - val original = listOf("line0", "line1", "line2", "line3", "line4") - val modified = listOf("line0", "line1", "changed_line2", "line3", "line4") - val diffResults = DiffUtil.generateDiff(original, modified) - val formattedDiff = DiffUtil.formatDiff(diffResults, 1) - val expectedDiff = """ - line1 - - line2 - + changed_line2 - line3 - """.trimIndent() - Assertions.assertEquals(expectedDiff, formattedDiff, "The diff should correctly include context lines.") - } - - @Test - fun testStartWithChange() { - val original = listOf("line1", "line2") - val modified = listOf("changed_line1", "line2") - val diffResults = DiffUtil.generateDiff(original, modified) - val formattedDiff = DiffUtil.formatDiff(diffResults) - val expectedDiff = """ + @Test + fun testStartWithChange() { + val original = listOf("line1", "line2") + val modified = listOf("changed_line1", "line2") + val diffResults = DiffUtil.generateDiff(original, modified) + val formattedDiff = DiffUtil.formatDiff(diffResults) + val expectedDiff = """ - line1 + changed_line1 line2 """.trimIndent() - Assertions.assertEquals(expectedDiff, formattedDiff, "The diff should correctly represent changes at the start.") - } - - @Test - fun testEndWithChange() { - val original = listOf("line1", "line2") - val modified = listOf("line1", "changed_line2") - val diffResults = DiffUtil.generateDiff(original, modified) - val formattedDiff = DiffUtil.formatDiff(diffResults) - val expectedDiff = """ - line1 - - line2 - + changed_line2 - """.trimIndent() - Assertions.assertEquals(expectedDiff, formattedDiff, "The diff should correctly represent changes at the end.") - } + Assertions.assertEquals( + expectedDiff, + formattedDiff, + "The diff should correctly represent changes at the start." + ) + } -// @Test - fun testNoContextNeeded() { - val original = listOf("line1", "line2", "line3") - val modified = listOf("line1", "changed_line2", "line3") - val diffResults = DiffUtil.generateDiff(original, modified) - val formattedDiff = DiffUtil.formatDiff(diffResults, 0) - val expectedDiff = """ + @Test + fun testEndWithChange() { + val original = listOf("line1", "line2") + val modified = listOf("line1", "changed_line2") + val diffResults = DiffUtil.generateDiff(original, modified) + val formattedDiff = DiffUtil.formatDiff(diffResults) + val expectedDiff = """ line1 - line2 + changed_line2 - line3 """.trimIndent() - Assertions.assertEquals( - expectedDiff, - formattedDiff, - "The diff should correctly handle cases with no context lines." - ) - } + Assertions.assertEquals(expectedDiff, formattedDiff, "The diff should correctly represent changes at the end.") + } - @Test - fun testVerifyLLMPatch() { - val originalCode = """ + @Test + fun testVerifyLLMPatch() { + val originalCode = """ @@ -149,7 +91,7 @@ class DiffUtilTest { """.trimIndent() - val llmPatch = """ + val llmPatch = """ @@ -178,13 +120,13 @@ class DiffUtilTest { """.trimIndent() - val reconstructed = ApxPatchUtil.patch(originalCode, llmPatch) + val reconstructed = ApxPatchUtil.patch(originalCode, llmPatch) - val patchLines = DiffUtil.generateDiff(originalCode.lines(), reconstructed.lines()) + val patchLines = DiffUtil.generateDiff(originalCode.lines(), reconstructed.lines()) // println("\n\nPatched:\n\n") // patchLines.forEach { println(it) } - println("\n\nEcho Patch:\n\n") - DiffUtil.formatDiff(patchLines).lines().forEach { println(it) } - } + println("\n\nEcho Patch:\n\n") + DiffUtil.formatDiff(patchLines).lines().forEach { println(it) } + } } \ No newline at end of file diff --git a/webui/src/test/kotlin/com/github/simiacryptus/diff/IterativePatchUtilTest.kt b/webui/src/test/kotlin/com/github/simiacryptus/diff/IterativePatchUtilTest.kt index 7e4e2463..405f4385 100644 --- a/webui/src/test/kotlin/com/github/simiacryptus/diff/IterativePatchUtilTest.kt +++ b/webui/src/test/kotlin/com/github/simiacryptus/diff/IterativePatchUtilTest.kt @@ -66,28 +66,29 @@ class IterativePatchUtilTest { Assertions.assertEquals(expected.replace("\r\n", "\n"), result.replace("\r\n", "\n")) } - @Test - fun testPatchRemoveLine() { - val source = """ + @Test + fun testPatchRemoveLine() { + val source = """ line1 line2 line3 """.trimIndent() - val patch = """ + val patch = """ line1 - line2 line3 """.trimIndent() - val expected = """ + val expected = """ line1 line3 """.trimIndent() - val result = IterativePatchUtil.patch(source, patch) - Assertions.assertEquals(expected.replace("\r\n", "\n"), result.replace("\r\n", "\n")) - } - @Test - fun testFromData() { - val source = """ + val result = IterativePatchUtil.patch(source, patch) + Assertions.assertEquals(expected.replace("\r\n", "\n"), result.replace("\r\n", "\n")) + } + + @Test + fun testFromData() { + val source = """ function updateTabs() { document.querySelectorAll('.tab-button').forEach(button => { button.addEventListener('click', (event) => { // Ensure the event is passed as a parameter @@ -108,7 +109,7 @@ class IterativePatchUtilTest { }); } """.trimIndent() - val patch = """ + val patch = """ tabsParent.querySelectorAll('.tab-content').forEach(content => { const contentParent = content.closest('.tabs-container'); if (contentParent === tabsParent) { @@ -122,7 +123,7 @@ class IterativePatchUtilTest { } }); """.trimIndent() - val expected = """ + val expected = """ function updateTabs() { document.querySelectorAll('.tab-button').forEach(button => { button.addEventListener('click', (event) => { // Ensure the event is passed as a parameter @@ -145,10 +146,10 @@ class IterativePatchUtilTest { }); } """.trimIndent() - val result = IterativePatchUtil.patch(source, patch) - Assertions.assertEquals( - expected.replace("\r\n", "\n").replace("\\s{1,}".toRegex(), " "), - result.replace("\r\n", "\n").replace("\\s{1,}".toRegex(), " ") - ) - } + val result = IterativePatchUtil.patch(source, patch) + Assertions.assertEquals( + expected.replace("\r\n", "\n").replace("\\s{1,}".toRegex(), " "), + result.replace("\r\n", "\n").replace("\\s{1,}".toRegex(), " ") + ) + } } \ No newline at end of file diff --git a/webui/src/test/kotlin/com/simiacryptus/skyenet/webui/ActorTestAppServer.kt b/webui/src/test/kotlin/com/simiacryptus/skyenet/webui/ActorTestAppServer.kt index ec034340..a0586041 100644 --- a/webui/src/test/kotlin/com/simiacryptus/skyenet/webui/ActorTestAppServer.kt +++ b/webui/src/test/kotlin/com/simiacryptus/skyenet/webui/ActorTestAppServer.kt @@ -17,7 +17,6 @@ import com.simiacryptus.skyenet.webui.test.CodingActorTestApp import com.simiacryptus.skyenet.webui.test.ImageActorTestApp import com.simiacryptus.skyenet.webui.test.ParsedActorTestApp import com.simiacryptus.skyenet.webui.test.SimpleActorTestApp -import java.util.function.Function object ActorTestAppServer : com.simiacryptus.skyenet.webui.application.ApplicationDirectory(port = 8082) { @@ -28,22 +27,43 @@ object ActorTestAppServer : com.simiacryptus.skyenet.webui.application.Applicati val type: String? = null, ) - interface JokeParser : Function - override val childWebApps by lazy { listOf( - ChildWebApp("/test_simple", SimpleActorTestApp(SimpleActor("Translate the user's request into pig latin.", "PigLatin", model = ChatModels.GPT35Turbo))), - ChildWebApp("/test_parsed_joke", ParsedActorTestApp(ParsedActor( - resultClass = TestJokeDataStructure::class.java, - prompt = "Tell me a joke", - parsingModel = ChatModels.GPT35Turbo, - model = ChatModels.GPT35Turbo, - ))), + ChildWebApp( + "/test_simple", + SimpleActorTestApp( + SimpleActor( + "Translate the user's request into pig latin.", + "PigLatin", + model = ChatModels.GPT35Turbo + ) + ) + ), + ChildWebApp( + "/test_parsed_joke", ParsedActorTestApp( + ParsedActor( + resultClass = TestJokeDataStructure::class.java, + prompt = "Tell me a joke", + parsingModel = ChatModels.GPT35Turbo, + model = ChatModels.GPT35Turbo, + ) + ) + ), ChildWebApp("/images", ImageActorTestApp(ImageActor(textModel = ChatModels.GPT35Turbo))), - ChildWebApp("/test_coding_scala", CodingActorTestApp(CodingActor(ScalaLocalInterpreter::class, model = ChatModels.GPT35Turbo))), - ChildWebApp("/test_coding_kotlin", CodingActorTestApp(CodingActor(KotlinInterpreter::class, model = ChatModels.GPT35Turbo))), - ChildWebApp("/test_coding_groovy", CodingActorTestApp(CodingActor(GroovyInterpreter::class, model = ChatModels.GPT35Turbo))), - )} + ChildWebApp( + "/test_coding_scala", + CodingActorTestApp(CodingActor(ScalaLocalInterpreter::class, model = ChatModels.GPT35Turbo)) + ), + ChildWebApp( + "/test_coding_kotlin", + CodingActorTestApp(CodingActor(KotlinInterpreter::class, model = ChatModels.GPT35Turbo)) + ), + ChildWebApp( + "/test_coding_groovy", + CodingActorTestApp(CodingActor(GroovyInterpreter::class, model = ChatModels.GPT35Turbo)) + ), + ) + } override val toolServlet: ToolServlet? get() = null @JvmStatic