From 96720aa7ea22536a7eee5f7f6de6f2fed6cbfa96 Mon Sep 17 00:00:00 2001 From: Andrew Charneski Date: Sun, 7 Jul 2024 16:44:35 -0400 Subject: [PATCH] 1.0.82 (#87) * 1.0.82 * improved diff/patch * Update AddApplyFileDiffLinks.kt * Update functions.js * wip * Update IterativePatchUtil.kt * wip * wip * wip * Update IterativePatchUtil.kt * Update IterativePatchUtil.kt * wip * wip * wip * Update SocketManagerBase.kt * wip * wip * Update SocketManagerBase.kt * Update SocketManagerBase.kt * Update IterativePatchUtil.kt * Update IterativePatchUtil.kt * wip * Update IterativePatchUtil.kt * Update IterativePatchUtil.kt * Update SocketManagerBase.kt * Update SocketManagerBase.kt * Update SocketManagerBase.kt * wip * Update IterativePatchUtilTest.kt * wip * diff patch test * fix * Update SocketManagerBase.kt * Update FilePatchTestApp.kt * Update IterativePatchUtil.kt * Update IterativePatchUtil.kt --- .gitignore | 1 + .../skyenet/core/platform/AwsPlatform.kt | 2 +- gradle.properties | 3 +- webui/build.gradle.kts | 27 +- .../simiacryptus/diff/AddApplyDiffLinks.kt | 9 +- .../diff/AddApplyFileDiffLinks.kt | 60 +- .../simiacryptus/diff/IterativePatchUtil.kt | 798 +++++++++++++----- .../webui/application/ApplicationDirectory.kt | 2 +- .../webui/session/SocketManagerBase.kt | 95 +-- .../skyenet/webui/test/FilePatchTestApp.kt | 66 ++ .../skyenet/webui/util/EncryptFiles.kt | 8 +- webui/src/main/resources/application/chat.js | 75 +- .../main/resources/application/functions.js | 20 +- .../src/main/resources/application/index.html | 14 +- webui/src/main/resources/application/main.js | 18 +- webui/src/main/resources/application/tabs.js | 99 ++- .../diff/IterativePatchUtilTest.kt | 141 +++- .../skyenet/webui/ActorTestAppServer.kt | 25 +- .../client_secret_google_oauth.json.kms | 2 +- 19 files changed, 1035 insertions(+), 430 deletions(-) create mode 100644 webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/FilePatchTestApp.kt diff --git a/.gitignore b/.gitignore index 6babf813..b925fc56 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ build/ openai.key *.log *.log.* +client_secret_google_oauth.json diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt index 1b594d10..947b930d 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt @@ -84,7 +84,7 @@ open class AwsPlatform( fun get() = try { AwsPlatform() } catch (e: Throwable) { - log.info("Error initializing AWS platform", e) + log.warn("Error initializing AWS platform", e) null } } diff --git a/gradle.properties b/gradle.properties index 72fd3df2..f782053d 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,4 +1,5 @@ # Gradle Releases -> https://github.com/gradle/gradle/releases libraryGroup = com.simiacryptus.skyenet -libraryVersion = 1.0.81 +libraryVersion = 1.0.82 gradleVersion = 7.6.1 +kotlin.daemon.jvmargs=-Xmx2g diff --git a/webui/build.gradle.kts b/webui/build.gradle.kts index 083bb6ff..7ff978d6 100644 --- a/webui/build.gradle.kts +++ b/webui/build.gradle.kts @@ -35,22 +35,31 @@ val jetty_version = "11.0.18" val jackson_version = "2.17.0" dependencies { - implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.0.63") + implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.0.63") { + exclude(group = "org.slf4j", module = "slf4j-api") + } implementation(project(":core")) implementation(project(":kotlin")) implementation("org.seleniumhq.selenium:selenium-chrome-driver:4.16.1") - compileOnly(group = "software.amazon.awssdk", name = "aws-sdk-java", version = "2.21.9") compileOnly("org.jsoup:jsoup:1.17.2") - implementation("com.google.zxing:core:3.5.3") implementation("com.google.zxing:javase:3.5.3") - implementation("org.openapitools:openapi-generator:7.3.0") - implementation("org.openapitools:openapi-generator-cli:7.3.0") + compileOnly(group = "software.amazon.awssdk", name = "aws-sdk-java", version = "2.21.9") + runtimeOnly(group = "software.amazon.awssdk", name = "aws-sdk-java", version = "2.21.9") + implementation("org.openapitools:openapi-generator:7.3.0") { + exclude(group = "org.slf4j", module = "slf4j-api") + exclude(group = "org.slf4j", module = "slf4j-ext") + exclude(group = "org.slf4j", module = "slf4j-simple") + } + compileOnly("org.openapitools:openapi-generator-cli:7.3.0") { + exclude(group = "org.slf4j", module = "slf4j-api") + } + testRuntimeOnly("org.openapitools:openapi-generator-cli:7.3.0") implementation(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version) implementation(group = "org.eclipse.jetty", name = "jetty-servlet", version = jetty_version) @@ -71,7 +80,9 @@ dependencies { testImplementation(project(":kotlin")) testImplementation(project(":scala")) - implementation(group = "org.apache.httpcomponents.client5", name = "httpclient5", version = "5.2.3") + implementation(group = "org.apache.httpcomponents.client5", name = "httpclient5", version = "5.3.1") { + exclude(group = "org.slf4j", module = "slf4j-api") + } implementation(group = "com.fasterxml.jackson.core", name = "jackson-core", version = jackson_version) implementation(group = "com.fasterxml.jackson.core", name = "jackson-databind", version = jackson_version) @@ -85,7 +96,9 @@ dependencies { implementation(group = "commons-codec", name = "commons-codec", version = "1.16.0") implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9") - testImplementation(group = "org.slf4j", name = "slf4j-simple", version = "2.0.9") + runtimeOnly(group = "org.slf4j", name = "slf4j-simple", version = "2.0.9") + testImplementation(group = "ch.qos.logback", name = "logback-classic", version = "1.4.14") + testImplementation(group = "ch.qos.logback", name = "logback-core", version = "1.4.14") testImplementation(kotlin("script-runtime")) testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyDiffLinks.kt b/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyDiffLinks.kt index 7ff76c44..1898b6ae 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyDiffLinks.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyDiffLinks.kt @@ -22,7 +22,7 @@ fun SocketManagerBase.addApplyDiffLinks( val isParenthesisBalanced = FileValidationUtils.isParenthesisBalanced(code) val isQuoteBalanced = FileValidationUtils.isQuoteBalanced(code) val isSingleQuoteBalanced = FileValidationUtils.isSingleQuoteBalanced(code) - var newCode = IterativePatchUtil.patch(code, diff) + var newCode = IterativePatchUtil.applyPatch(code, diff) newCode = newCode.replace("\r", "") val isCurlyBalancedNew = FileValidationUtils.isCurlyBalanced(newCode) val isSquareBalancedNew = FileValidationUtils.isSquareBalanced(newCode) @@ -58,12 +58,7 @@ fun SocketManagerBase.addApplyDiffLinks( } })!! val patch = patch(code(), diffVal).newCode - val test1 = DiffUtil.formatDiff( - DiffUtil.generateDiff( - code().replace("\r", "").lines(), - patch.lines() - ) - ) + val test1 = IterativePatchUtil.generatePatch(code().replace("\r", ""), patch) val patchRev = patch( code().lines().reversed().joinToString("\n"), diffVal.lines().reversed().joinToString("\n") diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyFileDiffLinks.kt b/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyFileDiffLinks.kt index 4d55217f..a2344e97 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyFileDiffLinks.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyFileDiffLinks.kt @@ -3,7 +3,6 @@ package com.simiacryptus.diff import com.simiacryptus.jopenai.API import com.simiacryptus.jopenai.OpenAIClient import com.simiacryptus.jopenai.models.ChatModels -import com.simiacryptus.skyenet.AgentPatterns import com.simiacryptus.skyenet.AgentPatterns.displayMapInTabs import com.simiacryptus.skyenet.core.actors.SimpleActor import com.simiacryptus.skyenet.set @@ -28,7 +27,7 @@ fun SocketManagerBase.addApplyFileDiffLinks( api: API, ): String { val initiator = "(?s)```[\\w]*\n".toRegex() - if(response.contains(initiator) && !response.split(initiator, 2)[1].contains("\n```(?![^\n])".toRegex())) { + if (response.contains(initiator) && !response.split(initiator, 2)[1].contains("\n```(?![^\n])".toRegex())) { // Single diff block without the closing ``` due to LLM limitations... add it back return addApplyFileDiffLinks( root, @@ -45,7 +44,7 @@ fun SocketManagerBase.addApplyFileDiffLinks( val diffs: List> = findAll.filter { block -> val header = headers.lastOrNull { it.first.endInclusive < block.range.start } val filename = resolve(root, header?.second ?: "Unknown") - when { + when { !root.toFile().resolve(filename).exists() -> false //block.groupValues[1] == "diff" -> true else -> true @@ -55,7 +54,7 @@ fun SocketManagerBase.addApplyFileDiffLinks( val codeblocks = findAll.filter { block -> val header = headers.lastOrNull { it.first.endInclusive < block.range.start } val filename = resolve(root, header?.second ?: "Unknown") - when { + when { root.toFile().resolve(filename).exists() -> false block.groupValues[1] == "diff" -> false else -> true @@ -111,12 +110,7 @@ fun SocketManagerBase.addApplyFileDiffLinks( """ |```diff |${ - DiffUtil.formatDiff( - DiffUtil.generateDiff( - prevCode.lines(), - codeValue.lines() - ) - ) + IterativePatchUtil.generatePatch(prevCode, codeValue) } |``` """.trimMargin(), ui = ui @@ -164,7 +158,8 @@ private fun SocketManagerBase.renderDiffBlock( diffVal: String, handle: (Map) -> Unit, ui: ApplicationInterface, - api: API? + api: API?, + watch: Boolean = false, ): String { val diffTask = ui.newTask(root = false) @@ -183,9 +178,6 @@ private fun SocketManagerBase.renderDiffBlock( val patch2TaskSB = patch2Task.add("") - - - val filepath = path(root, filename) val prevCode = load(filepath) val relativize = try { @@ -204,8 +196,7 @@ private fun SocketManagerBase.renderDiffBlock( var newCode = patch(originalCode, diffVal) - - val verifyFwdTabs = if(!newCode.isValid) displayMapInTabs( + val verifyFwdTabs = if (!newCode.isValid) displayMapInTabs( mapOf( "Code" to (prevCodeTask?.placeholder ?: ""), "Preview" to (newCodeTask?.placeholder ?: ""), @@ -255,7 +246,7 @@ private fun SocketManagerBase.renderDiffBlock( applydiffTask.error(null, e) } } - if(!newCode.isValid) { + if (!newCode.isValid) { val fixPatchLink = hrefLink("Fix Patch", classname = "href-link cmd-button") { try { val header = fixTask.header("Attempting to fix patch...") @@ -302,12 +293,7 @@ private fun SocketManagerBase.renderDiffBlock( ) val echoDiff = try { - DiffUtil.formatDiff( - DiffUtil.generateDiff( - prevCode.lines(), - newCode.newCode.lines() - ) - ) + IterativePatchUtil.generatePatch(prevCode, newCode.newCode) } catch (e: Throwable) { renderMarkdown("```\n${e.stackTraceToString()}\n```", ui = ui) } @@ -336,7 +322,7 @@ private fun SocketManagerBase.renderDiffBlock( ) answer = ui.socketManager?.addApplyFileDiffLinks(root, answer, handle, ui, api) ?: answer header?.clear() - fixTask.complete(answer) + fixTask.complete(renderMarkdown(answer)) } catch (e: Throwable) { log.error("Error in fix patch", e) } @@ -387,12 +373,7 @@ private fun SocketManagerBase.renderDiffBlock( if (!isApplied && thisFilehash != filehash) { val newCode = patch(prevCode, diffVal) val echoDiff = try { - DiffUtil.formatDiff( - DiffUtil.generateDiff( - prevCode.lines(), - newCode.newCode.lines() - ) - ) + IterativePatchUtil.generatePatch(prevCode, newCode.newCode) } catch (e: Throwable) { renderMarkdown("```\n${e.stackTraceToString()}\n```", ui = ui) } @@ -423,12 +404,7 @@ private fun SocketManagerBase.renderDiffBlock( diffVal.reverseLines() ).newCode.lines().reversed().joinToString("\n") val echoDiff2 = try { - DiffUtil.formatDiff( - DiffUtil.generateDiff( - prevCode.lines(), - newCode2.lines(), - ) - ) + IterativePatchUtil.generatePatch(prevCode, newCode2) } catch (e: Throwable) { renderMarkdown("```\n${e.stackTraceToString()}\n```", ui = ui) } @@ -458,7 +434,7 @@ private fun SocketManagerBase.renderDiffBlock( filehash = thisFilehash } } - if (!isApplied) { + if (!isApplied && watch) { scheduledThreadPoolExecutor.schedule(scheduledFn, 1000, TimeUnit.MILLISECONDS) } } catch (e: Throwable) { @@ -466,7 +442,13 @@ private fun SocketManagerBase.renderDiffBlock( } } scheduledThreadPoolExecutor.schedule(scheduledFn, 1000, TimeUnit.MILLISECONDS) - val newValue = mainTabs + "\n" + applydiffTask.placeholder + val newValue = if (newCode.isValid) { + mainTabs + "\n" + applydiffTask.placeholder + } else { + mainTabs + """ +
Warning: The patch is not valid. Please fix the patch before applying.
+ """.trimIndent() + applydiffTask.placeholder + } return newValue } @@ -476,7 +458,7 @@ private val patch = { code: String, diff: String -> val isParenthesisBalanced = FileValidationUtils.isParenthesisBalanced(code) val isQuoteBalanced = FileValidationUtils.isQuoteBalanced(code) val isSingleQuoteBalanced = FileValidationUtils.isSingleQuoteBalanced(code) - var newCode = IterativePatchUtil.patch(code, diff) + var newCode = IterativePatchUtil.applyPatch(code, diff) newCode = newCode.replace("\r", "") val isCurlyBalancedNew = FileValidationUtils.isCurlyBalanced(newCode) val isSquareBalancedNew = FileValidationUtils.isSquareBalanced(newCode) diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/IterativePatchUtil.kt b/webui/src/main/kotlin/com/simiacryptus/diff/IterativePatchUtil.kt index 6a4edbdd..45dfb2d8 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/IterativePatchUtil.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/IterativePatchUtil.kt @@ -1,21 +1,31 @@ +@file:Suppress("LoggingSimilarMessage") + package com.simiacryptus.diff import org.apache.commons.text.similarity.LevenshteinDistance import org.slf4j.LoggerFactory -import kotlin.math.floor +import kotlin.math.max +import kotlin.math.min object IterativePatchUtil { + enum class LineType { CONTEXT, ADD, DELETE } - private val log = LoggerFactory.getLogger(IterativePatchUtil::class.java) + // Tracks the nesting depth of different bracket types + data class LineMetrics( + var parenthesesDepth: Int = 0, + var squareBracketsDepth: Int = 0, + var curlyBracesDepth: Int = 0 + ) - enum class LineType { CONTEXT, ADD, DELETE } - class LineRecord( + // Represents a single line in the source or patch text + data class LineRecord( val index: Int, val line: String?, var previousLine: LineRecord? = null, var nextLine: LineRecord? = null, var matchingLine: LineRecord? = null, - var type: LineType = LineType.CONTEXT + var type: LineType = LineType.CONTEXT, + var metrics: LineMetrics = LineMetrics() ) { override fun toString(): String { val sb = StringBuilder() @@ -27,17 +37,63 @@ object IterativePatchUtil { } sb.append(" ") sb.append(line) + sb.append(" [({:${metrics.parenthesesDepth}, [:${metrics.squareBracketsDepth}, {:${metrics.curlyBracesDepth}]") return sb.toString() } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as LineRecord + + if (index != other.index) return false + if (line != other.line) return false + if (type != other.type) return false + if (metrics != other.metrics) return false + + return true + } + + override fun hashCode(): Int { + var result = index + result = 31 * result + (line?.hashCode() ?: 0) + result = 31 * result + type.hashCode() + result = 31 * result + metrics.hashCode() + return result + } + + } /** - * Normalizes a line by removing all whitespace. - * @param line The line to normalize. - * @return The normalized line. + * Generates an optimal patch by comparing two code files. + * @param oldCode The original code. + * @param newCode The new code. + * @return The generated patch as a string. */ - private fun normalizeLine(line: String): String { - return line.replace("\\s".toRegex(), "") + fun generatePatch(oldCode: String, newCode: String): String { + log.info("Starting patch generation process") + val sourceLines = parseLines(oldCode) + val newLines = parseLines(newCode) + link(sourceLines, newLines) + log.debug("Parsed and linked source lines: ${sourceLines.size}, new lines: ${newLines.size}") + val diff1 = diffLines(sourceLines, newLines) + val diff = truncateContext(diff1).toMutableList() + fixPatchLineOrder(diff) + annihilateNoopLinePairs(diff) + log.debug("Generated diff with ${diff.size} lines after processing") + val patch = StringBuilder() + // Generate the patch text + diff.forEach { line -> + when (line.type) { + LineType.CONTEXT -> patch.append(" ${line.line}\n") + LineType.ADD -> patch.append("+ ${line.line}\n") + LineType.DELETE -> patch.append("- ${line.line}\n") + } + } + log.info("Patch generation completed") + return patch.toString().trimEnd() } /** @@ -46,13 +102,164 @@ object IterativePatchUtil { * @param patch The patch to apply. * @return The text after the patch has been applied. */ - fun patch(source: String, patch: String): String { - log.info("Starting patch process") + fun applyPatch(source: String, patch: String): String { + log.info("Starting patch application process") // Parse the source and patch texts into lists of line records val sourceLines = parseLines(source) - val patchLines = parsePatchLines(patch) - log.debug("Parsed source lines: ${sourceLines.size}, patch lines: ${patchLines.size}") + var patchLines = parsePatchLines(patch) + log.debug("Parsed source lines: ${sourceLines.size}, initial patch lines: ${patchLines.size}") + link(sourceLines, patchLines) + + // Filter out empty lines in the patch + patchLines = patchLines.filter { it.line?.let { normalizeLine(it).isEmpty() } == false } + log.debug("Filtered patch lines: ${patchLines.size}") + log.info("Generating patched text") + + val result = generatePatchedText(sourceLines, patchLines) + val generatePatchedTextUsingLinks = result.joinToString("\n").trim() + log.info("Patch application completed") + + return generatePatchedTextUsingLinks + } + + private fun annihilateNoopLinePairs(diff: MutableList) { + log.debug("Starting annihilation of no-op line pairs") + val toRemove = mutableListOf>() + var i = 0 + while (i < diff.size-1) { + if (diff[i].type == LineType.DELETE) { + var j = i + 1 + while (j < diff.size && diff[j].type != LineType.CONTEXT) { + if (diff[j].type == LineType.ADD && + normalizeLine(diff[i].line ?: "") == normalizeLine(diff[j].line ?: "") + ) { + toRemove.add(Pair(i, j)) + break + } + j++ + } + } + i++ + } + // Remove the pairs in reverse order to maintain correct indices + toRemove.flatMap { listOf(it.first, it.second) }.distinct().sortedDescending().forEach { diff.removeAt(it) } + log.debug("Removed ${toRemove.size} no-op line pairs") + } + + private fun diffLines( + sourceLines: List, + newLines: List + ): MutableList { + val diff = mutableListOf() + log.debug("Starting diff generation") + var sourceIndex = 0 + var newIndex = 0 + // Generate raw patch without limited context windows + while (sourceIndex < sourceLines.size || newIndex < newLines.size) { + when { + sourceIndex >= sourceLines.size -> { + // Add remaining new lines + diff.add(newLines[newIndex].copy(type = LineType.ADD)) + newIndex++ + } + + newIndex >= newLines.size -> { + // Delete remaining source lines + diff.add(sourceLines[sourceIndex].copy(type = LineType.DELETE)) + sourceIndex++ + } + + sourceLines[sourceIndex].matchingLine == newLines[newIndex] && + normalizeLine(sourceLines[sourceIndex].line ?: "") == normalizeLine( + newLines[newIndex].line ?: "" + ) -> { + // Lines match, add as context + diff.add(sourceLines[sourceIndex].copy(type = LineType.CONTEXT)) + sourceIndex++ + newIndex++ + } + + sourceLines[sourceIndex].matchingLine == null || + normalizeLine(sourceLines[sourceIndex].line ?: "") != normalizeLine( + newLines[newIndex].line ?: "" + ) -> { + // Source line has no match, it's a deletion + diff.add(sourceLines[sourceIndex].copy(type = LineType.DELETE)) + sourceIndex++ + } + + else -> { + // New line has no match in source, it's an addition + diff.add(newLines[newIndex].copy(type = LineType.ADD)) + newIndex++ + } + } + } + log.debug("Generated diff with ${diff.size} lines") + return diff + } + + private fun truncateContext(diff: MutableList): MutableList { + val contextSize = 3 // Number of context lines before and after changes + log.debug("Truncating context with size $contextSize") + val truncatedDiff = mutableListOf() + var inChange = false + val contextBuffer = mutableListOf() + var lastChangeIndex = -1 + for (i in diff.indices) { + val line = diff[i] + when { + line.type != LineType.CONTEXT -> { + if (!inChange) { + // Start of a change, add buffered context + truncatedDiff.addAll(contextBuffer.takeLast(contextSize)) + contextBuffer.clear() + } + truncatedDiff.add(line) + inChange = true + lastChangeIndex = i + } + + inChange -> { + contextBuffer.add(line) + if (contextBuffer.size == contextSize) { + // End of a change, add buffered context + truncatedDiff.addAll(contextBuffer) + contextBuffer.clear() + inChange = false + } + } + + else -> { + contextBuffer.add(line) + if (contextBuffer.size > contextSize) { + contextBuffer.removeAt(0) + } + } + } + } + // Add trailing context after the last change + if (lastChangeIndex != -1) { + val trailingContext = diff.subList(lastChangeIndex + 1, min(diff.size, lastChangeIndex + 1 + contextSize)) + truncatedDiff.addAll(trailingContext) + } + log.debug("Truncated diff size: ${truncatedDiff.size}") + return truncatedDiff + } + + /** + * Normalizes a line by removing all whitespace. + * @param line The line to normalize. + * @return The normalized line. + */ + private fun normalizeLine(line: String): String { + return line.replace("\\s".toRegex(), "") + } + private fun link( + sourceLines: List, + patchLines: List + ) { // Step 1: Link all unique lines in the source and patch that match exactly log.info("Step 1: Linking unique matching lines") linkUniqueMatchingLines(sourceLines, patchLines) @@ -60,89 +267,208 @@ object IterativePatchUtil { // Step 2: Link all exact matches in the source and patch which are adjacent to established links log.info("Step 2: Linking adjacent matching lines") linkAdjacentMatchingLines(sourceLines) + log.info("Step 3: Performing subsequence linking") - // Step 3: Establish a distance metric for matches based on Levenshtein distance and distance to established links. - //linkByLevenshteinDistance(sourceLines, patchLines) + subsequenceLinking(sourceLines, patchLines) + } + + private fun subsequenceLinking( + sourceLines: List, + patchLines: List, + depth: Int = 0 + ) { + log.debug("Subsequence linking at depth $depth") + if (depth > 10 || sourceLines.isEmpty() || patchLines.isEmpty()) { + return // Base case: prevent excessive recursion + } + val sourceSegment = sourceLines.filter { it.matchingLine == null } + val patchSegment = patchLines.filter { it.matchingLine == null } + if (sourceSegment.isNotEmpty() && patchSegment.isNotEmpty()) { + var matchedLines = linkUniqueMatchingLines(sourceSegment, patchSegment) + matchedLines += linkAdjacentMatchingLines(sourceSegment) + if (matchedLines == 0) { + matchedLines += matchFirstBrackets(sourceSegment, patchSegment) + } + if (matchedLines > 0) { + subsequenceLinking(sourceSegment, patchSegment, depth + 1) + } + log.debug("Matched $matchedLines lines in subsequence linking at depth $depth") - // Generate the patched text using the established links - log.info("Generating patched text using established links") - log.info("Patch process completed") - return generatePatchedTextUsingLinks(sourceLines, patchLines) + } } - /** - * Generates the final patched text using the links established between the source and patch lines. - * @param sourceLines The source lines with established links. - * @param patchLines The patch lines with established links. - * @return The final patched text. - */ - private fun generatePatchedTextUsingLinks(sourceLines: List, patchLines: List): String { - val patchedTextBuilder = StringBuilder() - val sourceLineBuffer = sourceLines.toMutableList() + private fun generatePatchedText( + sourceLines: List, + patchLines: List, + ): List { log.debug("Starting to generate patched text") + val patchedText: MutableList = mutableListOf() + val usedPatchLines = mutableSetOf() + var sourceIndex = 0 + var lastMatchedPatchIndex = -1 + while (sourceIndex < sourceLines.size) { + val codeLine = sourceLines[sourceIndex] + when { + codeLine.matchingLine?.type == LineType.DELETE -> { + val patchLine = codeLine.matchingLine!! + var patchIndex = patchLines.indexOf(patchLine) + log.debug("Deleting line: {}", codeLine) + updateContext(lastMatchedPatchIndex, patchIndex, patchLines, usedPatchLines, patchedText) + patchIndex = checkBeforeForInserts(sourceIndex, sourceLines, usedPatchLines, patchedText) - // Add any leading lines from the source that are not in the patch - while (sourceLineBuffer.firstOrNull()?.matchingLine == null) { - val line = sourceLineBuffer.removeFirstOrNull() ?: break - patchedTextBuilder.appendLine(line.line) - log.debug("Added leading source line: ${line}") - } - log.debug("Added ${patchedTextBuilder.lines().size} leading lines from source") - - // Add any leading 'add' lines from the patch - val patchLines = patchLines.toMutableList() - while (patchLines.firstOrNull()?.type == LineType.ADD) { - val line = patchLines.removeFirstOrNull() ?: break - patchedTextBuilder.appendLine(line.line) - log.debug("Added leading patch line: ${line}") - } + // Delete the line -- do not add to patched text - log.debug("Added ${patchedTextBuilder.lines().size} leading 'add' lines from patch") + patchIndex = patchLines.indexOf(patchLine) + patchIndex = checkAfterForInserts(patchIndex, patchLines, usedPatchLines, patchedText) - // Process the rest of the lines - while (sourceLineBuffer.isNotEmpty()) { - // Copy all lines until the next matched line - val codeLine = sourceLineBuffer.removeFirst() - when { - codeLine.matchingLine == null -> { - // If the line is not matched and is adjacent to a non-matched line, add it as a context line - log.debug("Processing unmatched line: ${codeLine}") - if (codeLine.nextLine?.matchingLine == null || codeLine.previousLine?.matchingLine == null) { - patchedTextBuilder.appendLine(codeLine.line) - log.debug("Added unmatched line: ${codeLine}") - } + usedPatchLines.add(patchLine) + lastMatchedPatchIndex = patchIndex + sourceIndex++ } - codeLine.matchingLine!!.type == LineType.DELETE -> log.debug("Skipped deleted line: ${codeLine}") - codeLine.matchingLine!!.type == LineType.CONTEXT -> { - patchedTextBuilder.appendLine(codeLine.line) - log.debug("Added context line: ${codeLine}") + codeLine.matchingLine != null -> { + val patchLine = codeLine.matchingLine!! + var patchIndex = patchLines.indexOf(patchLine) + log.debug("Patching line: {} <-> {}", codeLine, patchLine) + // Add context lines between last match and current match + updateContext(lastMatchedPatchIndex, patchIndex, patchLines, usedPatchLines, patchedText) + patchIndex = checkBeforeForInserts(patchIndex, patchLines, usedPatchLines, patchedText) + + patchedText.add(patchLine.line ?: "") // Add the patched line + + patchIndex = patchLines.indexOf(patchLine) + patchIndex = checkAfterForInserts(patchIndex, patchLines, usedPatchLines, patchedText) + + usedPatchLines.add(patchLine) + lastMatchedPatchIndex = patchIndex + sourceIndex++ } - codeLine.matchingLine!!.type == LineType.ADD -> { - patchedTextBuilder.appendLine(codeLine.line) - log.debug("Added modified line: ${codeLine}") + + else -> { + // Check if this line is a context line in the patch + val contextPatchLine = patchLines.find { it.type == LineType.CONTEXT && it.line == codeLine.line } + if (contextPatchLine != null) { + log.debug("Added context line: {}", codeLine) + patchedText.add(contextPatchLine.line ?: "") + usedPatchLines.add(contextPatchLine) + } else { + log.debug("Added unmatched source line: {}", codeLine) + patchedText.add(codeLine.line ?: "") + } + sourceIndex++ } + } + } + // Add remaining context lines after the last match + if (lastMatchedPatchIndex != -1) { + for (i in lastMatchedPatchIndex + 1 until patchLines.size) { + val contextLine = patchLines[i] + if (contextLine.type == LineType.CONTEXT && !usedPatchLines.contains(contextLine)) { + patchedText.add(contextLine.line ?: "") + usedPatchLines.add(contextLine) + } + } + } + // Add any remaining unused ADD lines from the patch + patchLines.filter { it.type == LineType.ADD && !usedPatchLines.contains(it) }.forEach { line -> + log.debug("Added remaining patch line: {}", line) + patchedText.add(line.line ?: "") + } + log.debug("Generated patched text with ${patchedText.size} lines") + return patchedText + } - // Add lines marked as ADD in the patch following the current matched line - var nextPatchLine = codeLine.matchingLine?.nextLine - while (nextPatchLine != null && nextPatchLine.matchingLine == null) { - when (nextPatchLine.type) { - LineType.ADD -> { - patchedTextBuilder.appendLine(nextPatchLine.line) - log.debug("Added new line from patch: ${nextPatchLine}") - } - LineType.CONTEXT -> { - patchedTextBuilder.appendLine(nextPatchLine.line) - log.debug("Added context line from patch: ${nextPatchLine}") - } - LineType.DELETE -> log.debug("Skipped deleted line from patch: ${nextPatchLine}") + private fun updateContext( + lastMatchedPatchIndex: Int, + patchIndex: Int, + patchLines: List, + usedPatchLines: MutableSet, + patchedText: MutableList + ) { + if (lastMatchedPatchIndex != -1) { + for (i in lastMatchedPatchIndex + 1 until patchIndex) { + val contextLine = patchLines[i] + if (contextLine.type == LineType.CONTEXT && !usedPatchLines.contains(contextLine)) { + patchedText.add(contextLine.line ?: "") + usedPatchLines.add(contextLine) } - nextPatchLine = nextPatchLine.nextLine } } - log.debug("Finished generating patched text") - return patchedTextBuilder.toString().trimEnd() + } + + private fun checkAfterForInserts( + patchIndex: Int, + patchLines: List, + usedPatchLines: MutableSet, + patchedText: MutableList + ): Int { + var patchIndex1 = patchIndex + while (patchIndex1 < patchLines.size - 1) { + val nextPatchLine = patchLines[++patchIndex1] + if (nextPatchLine.type == LineType.ADD && !usedPatchLines.contains(nextPatchLine)) { + log.debug("Added unmatched patch line: {}", nextPatchLine) + patchedText.add(nextPatchLine.line ?: "") + usedPatchLines.add(nextPatchLine) + } else { + break + } + } + return patchIndex1 + } + + private fun checkBeforeForInserts( + patchIndex: Int, + patchLines: List, + usedPatchLines: MutableSet, + patchedText: MutableList + ): Int { + var patchIndex1 = patchIndex + while (patchIndex1 > 0) { + val prevPatchLine = patchLines[--patchIndex1] + if (prevPatchLine.type == LineType.ADD && !usedPatchLines.contains(prevPatchLine)) { + log.debug("Added unmatched patch line: {}", prevPatchLine) + patchedText.add(prevPatchLine.line ?: "") + usedPatchLines.add(prevPatchLine) + } else { + break + } + } + return patchIndex1 + } + + private fun matchFirstBrackets(sourceLines: List, patchLines: List): Int { + log.debug("Starting to match first brackets") + log.debug("Starting to link unique matching lines") + // Group source lines by their normalized content + val sourceLineMap = sourceLines.filter { + it.line?.lineMetrics() != LineMetrics() + }.groupBy { normalizeLine(it.line!!) } + // Group patch lines by their normalized content, excluding ADD lines + val patchLineMap = patchLines.filter { + it.line?.lineMetrics() != LineMetrics() + }.filter { + when (it.type) { + LineType.ADD -> false // ADD lines are not matched to source lines + else -> true + } + }.groupBy { normalizeLine(it.line!!) } + log.debug("Created source and patch line maps") + + // Find intersecting keys (matching lines) and link them + val matched = sourceLineMap.keys.intersect(patchLineMap.keys) + matched.forEach { key -> + val sourceGroup = sourceLineMap[key]!! + val patchGroup = patchLineMap[key]!! + for (i in 0 until min(sourceGroup.size, patchGroup.size)) { + sourceGroup[i].matchingLine = patchGroup[i] + patchGroup[i].matchingLine = sourceGroup[i] + log.debug("Linked matching lines: Source[${sourceGroup[i].index}]: ${sourceGroup[i].line} <-> Patch[${patchGroup[i].index}]: ${patchGroup[i].line}") + } + } + val matchedCount = matched.sumOf { sourceLineMap[it]!!.size } + log.debug("Finished matching first brackets. Matched $matchedCount lines") + return matched.sumOf { sourceLineMap[it]!!.size } } /** @@ -150,9 +476,11 @@ object IterativePatchUtil { * @param sourceLines The source lines. * @param patchLines The patch lines. */ - private fun linkUniqueMatchingLines(sourceLines: List, patchLines: List) { - log.debug("Starting to link unique matching lines") + private fun linkUniqueMatchingLines(sourceLines: List, patchLines: List): Int { + log.debug("Starting to link unique matching lines. Source lines: ${sourceLines.size}, Patch lines: ${patchLines.size}") + // Group source lines by their normalized content val sourceLineMap = sourceLines.groupBy { normalizeLine(it.line!!) } + // Group patch lines by their normalized content, excluding ADD lines val patchLineMap = patchLines.filter { when (it.type) { LineType.ADD -> false // ADD lines are not matched to source lines @@ -161,170 +489,102 @@ object IterativePatchUtil { }.groupBy { normalizeLine(it.line!!) } log.debug("Created source and patch line maps") - sourceLineMap.keys.intersect(patchLineMap.keys).forEach { key -> - val sourceLine = sourceLineMap[key]?.singleOrNull() - val patchLine = patchLineMap[key]?.singleOrNull() - if (sourceLine != null && patchLine != null) { - sourceLine.matchingLine = patchLine - patchLine.matchingLine = sourceLine - log.debug("Linked unique matching lines: Source[${sourceLine.index}]: ${sourceLine.line} <-> Patch[${patchLine.index}]: ${patchLine.line}") + // Find intersecting keys (matching lines) and link them + val matched = sourceLineMap.keys.intersect(patchLineMap.keys).filter { + sourceLineMap[it]?.size == patchLineMap[it]?.size + } + matched.forEach { key -> + val sourceGroup = sourceLineMap[key]!! + val patchGroup = patchLineMap[key]!! + for (i in sourceGroup.indices) { + sourceGroup[i].matchingLine = patchGroup[i] + patchGroup[i].matchingLine = sourceGroup[i] + log.debug("Linked unique matching lines: Source[${sourceGroup[i].index}]: ${sourceGroup[i].line} <-> Patch[${patchGroup[i].index}]: ${patchGroup[i].line}") } } - log.debug("Finished linking unique matching lines") + val matchedCount = matched.sumOf { sourceLineMap[it]!!.size } + log.debug("Finished linking unique matching lines. Matched $matchedCount lines") + return matched.sumOf { sourceLineMap[it]!!.size } } /** * Links lines that are adjacent to already linked lines and match exactly. * @param sourceLines The source lines with some established links. */ - private fun linkAdjacentMatchingLines(sourceLines: List) { - log.debug("Starting to link adjacent matching lines") + private fun linkAdjacentMatchingLines(sourceLines: List): Int { + log.debug("Starting to link adjacent matching lines. Source lines: ${sourceLines.size}") var foundMatch = true + var matchedLines = 0 + val levenshteinDistance = LevenshteinDistance() + // Continue linking until no more matches are found while (foundMatch) { log.debug("Starting new iteration to find adjacent matches") foundMatch = false for (sourceLine in sourceLines) { val patchLine = sourceLine.matchingLine ?: continue // Skip if there's no matching line - // Check the previous line for a potential match - if (sourceLine.previousLine != null && patchLine.previousLine != null) { - val sourcePrev = sourceLine.previousLine!! - var patchPrev = patchLine.previousLine!! - while (patchPrev.type == LineType.ADD && patchPrev.previousLine != null) { - patchPrev = patchPrev.previousLine!! - } - if (sourcePrev.matchingLine == null && patchPrev.matchingLine == null) { // Skip if there's already a match - if (normalizeLine(sourcePrev.line!!) == normalizeLine(patchPrev.line!!)) { // Check if the lines match exactly - sourcePrev.matchingLine = patchPrev - patchPrev.matchingLine = sourcePrev - foundMatch = true - log.debug("Linked adjacent previous lines: Source[${sourcePrev.index}]: ${sourcePrev.line} <-> Patch[${patchPrev.index}]: ${patchPrev.line}") - } - } + var patchPrev = patchLine.previousLine ?: continue + while (patchPrev.previousLine != null && + (patchPrev.type == LineType.ADD || normalizeLine(patchPrev.line ?: "").isEmpty()) + ) { + patchPrev = patchPrev.previousLine!! } - // Check the next line for a potential match - if (sourceLine.nextLine != null && patchLine.nextLine != null) { - val sourceNext = sourceLine.nextLine!! - var patchNext = patchLine.nextLine!! - while (patchNext.type == LineType.ADD && patchNext.nextLine != null) { - patchNext = patchNext.nextLine!! - } - if (sourceNext.matchingLine == null && patchNext.matchingLine == null) { - if (normalizeLine(sourceNext.line!!) == normalizeLine(patchNext.line!!)) { - sourceNext.matchingLine = patchNext - patchNext.matchingLine = sourceNext - foundMatch = true - log.debug("Linked adjacent next lines: Source[${sourceNext.index}]: ${sourceNext.line} <-> Patch[${patchNext.index}]: ${patchNext.line}") - } + var sourcePrev = sourceLine.previousLine ?: continue + while (sourcePrev.previousLine != null && (normalizeLine(sourcePrev.line ?: "").isEmpty())) { + sourcePrev = sourcePrev.previousLine!! + } + + if (sourcePrev.matchingLine == null && patchPrev.matchingLine == null) { // Skip if there's already a match + if (isMatch(sourcePrev, patchPrev, levenshteinDistance)) { // Check if the lines match exactly + sourcePrev.matchingLine = patchPrev + patchPrev.matchingLine = sourcePrev + foundMatch = true + matchedLines++ + log.debug("Linked adjacent previous lines: Source[${sourcePrev.index}]: ${sourcePrev.line} <-> Patch[${patchPrev.index}]: ${patchPrev.line}") } } - } - } - log.debug("Finished linking adjacent matching lines") - } - // ... (other functions remain unchanged) - - /** - * Establishes links between source and patch lines based on the Levenshtein distance and proximity to already established links. - * @param sourceLines The source lines. - * @param patchLines The patch lines. - */ - private fun linkByLevenshteinDistance(sourceLines: List, patchLines: List) { - log.debug("Starting to link by Levenshtein distance") - val levenshteinDistance = LevenshteinDistance() - val maxProximity = (sourceLines.size + patchLines.size) / 10 // Increase max distance to allow more flexibility - - log.debug("Max proximity set to: $maxProximity") - // Iterate over source lines to find potential matches in the patch lines - for (sourceLine in sourceLines) { - if (sourceLine.matchingLine != null) continue // Skip lines that already have matches - var bestMatch: LineRecord? = null - var bestDistance = Int.MAX_VALUE - var bestProximity = Int.MAX_VALUE - log.trace("Processing source line: ${sourceLine.line}") - - for (patchLine in patchLines.filter { - when (it.type) { - LineType.ADD -> false // ADD lines are not matched to source lines - else -> true + var patchNext = patchLine.nextLine ?: continue + while (patchNext.nextLine != null && + (patchNext.type == LineType.ADD || normalizeLine(patchNext.line ?: "").isEmpty()) + ) { + patchNext = patchNext.nextLine!! } - }) { - if (patchLine.matchingLine != null) continue // Skip lines that already have matches - val maxDistance = minOf(bestDistance, floor(patchLine.line!!.length.toDouble() / 2).toInt()) - // Calculate the Levenshtein distance between unmatched source and patch lines - val distance = - levenshteinDistance.apply(normalizeLine(sourceLine.line!!), normalizeLine(patchLine.line!!)) - if (distance <= maxDistance) { - // Consider proximity to established links as a secondary factor - val proximity = calculateProximityDistance(sourceLine, patchLine) - if (proximity > maxProximity) continue - if (distance < bestDistance || (distance == bestDistance && proximity < bestProximity)) { - bestMatch = patchLine - bestDistance = distance - bestProximity = proximity - log.trace("Found potential match: ${patchLine.line}, distance: $distance, proximity: $proximity") - } + + var sourceNext = sourceLine.nextLine ?: continue + while (sourceNext.nextLine != null && (normalizeLine(sourceNext.line ?: "").isEmpty())) { + sourceNext = sourceNext.nextLine!! } - // Establish the best match found, if any - if (bestMatch != null) { - sourceLine.matchingLine = bestMatch - log.debug("Linked by Levenshtein distance: ${sourceLine.line} <-> ${bestMatch.line}") - bestMatch.matchingLine = sourceLine + if (sourceNext.matchingLine == null && patchNext.matchingLine == null) { + if (isMatch(sourceNext, patchNext, levenshteinDistance)) { + sourceNext.matchingLine = patchNext + patchNext.matchingLine = sourceNext + foundMatch = true + matchedLines++ + log.debug("Linked adjacent next lines: Source[${sourceNext.index}]: ${sourceNext.line} <-> Patch[${patchNext.index}]: ${patchNext.line}") + } } } } + log.debug("Finished linking adjacent matching lines. Matched $matchedLines lines") + return matchedLines } - /** - * Calculates the proximity distance between a source line and a patch line based on their distance to the nearest established link. - * @param sourceLine The source line. - * @param patchLine The patch line. - * @return The proximity distance. - */ - private fun calculateProximityDistance(sourceLine: LineRecord, patchLine: LineRecord): Int { - log.trace("Calculating proximity distance for source line: ${sourceLine.line} and patch line: ${patchLine.line}") - // Find the nearest established link in both directions for source and patch lines - var sourceDistancePrev = 0 - var sourceDistanceNext = 0 - var patchDistancePrev = 0 - var patchDistanceNext = 0 - - var currentSourceLine = sourceLine.previousLine - while (currentSourceLine != null) { - if (currentSourceLine.matchingLine != null) break - sourceDistancePrev++ - currentSourceLine = currentSourceLine.previousLine - } - - currentSourceLine = sourceLine.nextLine - while (currentSourceLine != null) { - if (currentSourceLine.matchingLine != null) break - sourceDistanceNext++ - currentSourceLine = currentSourceLine.nextLine + private fun isMatch( + sourcePrev: LineRecord, + patchPrev: LineRecord, + levenshteinDistance: LevenshteinDistance + ): Boolean { + var isMatch = normalizeLine(sourcePrev.line!!) == normalizeLine(patchPrev.line!!) + val length = max(sourcePrev.line!!.length, patchPrev.line!!.length) + if (!isMatch && length > 5) { // Check if the lines are similar using Levenshtein distance + val distance = levenshteinDistance.apply(sourcePrev.line, patchPrev.line) + log.debug("Levenshtein distance: $distance") + isMatch = distance <= (length / 3) } - - var currentPatchLine = patchLine.previousLine - while (currentPatchLine != null) { - if (currentPatchLine.matchingLine != null) break - patchDistancePrev++ - currentPatchLine = currentPatchLine.previousLine - } - - currentPatchLine = patchLine.nextLine - while (currentPatchLine != null) { - if (currentPatchLine.matchingLine != null) break - patchDistanceNext++ - currentPatchLine = currentPatchLine.nextLine - } - - // Calculate the total proximity distance as the sum of minimum distances in each direction - val proximityDistance = - minOf(sourceDistancePrev, patchDistancePrev) + minOf(sourceDistanceNext, patchDistanceNext) - log.trace("Calculated proximity distance: $proximityDistance") - return proximityDistance + return isMatch } /** @@ -332,9 +592,12 @@ object IterativePatchUtil { * @return The list of line records. */ private fun parseLines(text: String): List { - log.debug("Parsing source lines") + log.debug("Starting to parse lines") + // Create LineRecords for each line and set links between them val lines = setLinks(text.lines().mapIndexed { index, line -> LineRecord(index, line) }) - log.debug("Parsed ${lines.size} source lines") + // Calculate bracket metrics for each line + calculateLineMetrics(lines) + log.debug("Finished parsing ${lines.size} lines") return lines } @@ -343,12 +606,12 @@ object IterativePatchUtil { * @return The list with links set. */ private fun setLinks(list: List): List { - log.debug("Setting links for ${list.size} lines") - for (i in 0 until list.size) { + log.debug("Starting to set links for ${list.size} lines") + for (i in list.indices) { list[i].previousLine = if (i > 0) list[i - 1] else null list[i].nextLine = if (i < list.size - 1) list[i + 1] else null } - log.debug("Finished setting links") + log.debug("Finished setting links for ${list.size} lines") return list } @@ -358,7 +621,7 @@ object IterativePatchUtil { * @return The list of line records with types set. */ private fun parsePatchLines(text: String): List { - log.debug("Parsing patch lines") + log.debug("Starting to parse patch lines") val patchLines = setLinks(text.lines().mapIndexed { index, line -> LineRecord( index = index, @@ -371,15 +634,98 @@ object IterativePatchUtil { it.trimStart().startsWith("-") -> it.trimStart().substring(1) else -> it } - }, type = when { + }, + type = when { line.startsWith("+") -> LineType.ADD line.startsWith("-") -> LineType.DELETE else -> LineType.CONTEXT } ) - }.filter { it.line != null }) - log.debug("Parsed ${patchLines.size} patch lines") + }.filter { it.line != null }).toMutableList() + + fixPatchLineOrder(patchLines) + + calculateLineMetrics(patchLines) + log.debug("Finished parsing ${patchLines.size} patch lines") return patchLines } + private fun fixPatchLineOrder(patchLines: MutableList) { + log.debug("Starting to fix patch line order") + // Fixup: Iterate over the patch lines and look for adjacent ADD and DELETE lines; the DELETE should come first... if needed, swap them + var swapped: Boolean + do { + swapped = false + for (i in 0 until patchLines.size - 1) { + if (patchLines[i].type == LineType.ADD && patchLines[i + 1].type == LineType.DELETE) { + swapped = true + val deleteLine = patchLines[i] + val addLine = patchLines[i + 1] + // Swap records and update pointers + deleteLine.nextLine = addLine + addLine.previousLine = deleteLine + deleteLine.previousLine = addLine.previousLine + addLine.nextLine = deleteLine.nextLine + patchLines[i] = addLine + patchLines[i + 1] = deleteLine + } + } + } while (swapped) + log.debug("Finished fixing patch line order") + } + + /** + * Calculates the metrics for each line, including bracket nesting depth. + * @param lines The list of line records to process. + */ + private fun calculateLineMetrics(lines: List) { + log.debug("Starting to calculate line metrics for ${lines.size} lines") + var parenthesesDepth = 0 + var squareBracketsDepth = 0 + var curlyBracesDepth = 0 + + lines.forEach { lineRecord -> + lineRecord.line?.forEach { char -> + when (char) { + '(' -> parenthesesDepth++ + ')' -> parenthesesDepth = maxOf(0, parenthesesDepth - 1) + '[' -> squareBracketsDepth++ + ']' -> squareBracketsDepth = maxOf(0, squareBracketsDepth - 1) + '{' -> curlyBracesDepth++ + '}' -> curlyBracesDepth = maxOf(0, curlyBracesDepth - 1) + } + } + lineRecord.metrics = LineMetrics( + parenthesesDepth = parenthesesDepth, + squareBracketsDepth = squareBracketsDepth, + curlyBracesDepth = curlyBracesDepth + ) + } + log.debug("Finished calculating line metrics") + } + + fun String.lineMetrics(): LineMetrics { + var parenthesesDepth = 0 + var squareBracketsDepth = 0 + var curlyBracesDepth = 0 + + this.forEach { char -> + when (char) { + '(' -> parenthesesDepth++ + ')' -> parenthesesDepth = maxOf(0, parenthesesDepth - 1) + '[' -> squareBracketsDepth++ + ']' -> squareBracketsDepth = maxOf(0, squareBracketsDepth - 1) + '{' -> curlyBracesDepth++ + '}' -> curlyBracesDepth = maxOf(0, curlyBracesDepth - 1) + } + } + return LineMetrics( + parenthesesDepth = parenthesesDepth, + squareBracketsDepth = squareBracketsDepth, + curlyBracesDepth = curlyBracesDepth + ) + } + + private val log = LoggerFactory.getLogger(IterativePatchUtil::class.java) + } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt index ab0a3c2b..e60b8dad 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt @@ -81,7 +81,7 @@ abstract class ApplicationDirectory( log.info("Starting application with args: ${args.joinToString(", ")}") setupPlatform() init(args.contains("--server")) - ClientUtil.keyTxt = run { + if(ClientUtil.keyTxt.isEmpty()) ClientUtil.keyTxt = run { try { val encryptedData = javaClass.classLoader.getResourceAsStream("openai.key.json.kms")?.readAllBytes() ?: throw RuntimeException("Unable to load resource: ${"openai.key.json.kms"}") diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SocketManagerBase.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SocketManagerBase.kt index e768e98f..e8be2fc2 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SocketManagerBase.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/session/SocketManagerBase.kt @@ -3,13 +3,14 @@ package com.simiacryptus.skyenet.webui.session import com.simiacryptus.skyenet.core.platform.* import com.simiacryptus.skyenet.core.platform.ApplicationServices.clientManager import com.simiacryptus.skyenet.core.platform.AuthorizationInterface.OperationType -import com.simiacryptus.skyenet.webui.chat.ChatServer import com.simiacryptus.skyenet.webui.chat.ChatSocket import com.simiacryptus.skyenet.webui.util.MarkdownUtil import org.slf4j.LoggerFactory import java.net.URLDecoder import java.util.* +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentLinkedDeque +import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger import java.util.function.Consumer @@ -22,18 +23,15 @@ abstract class SocketManagerBase( ) ?: LinkedHashMap(), private val applicationClass: Class<*>, ) : SocketManager { - private val sockets: MutableMap = mutableMapOf() - private val sendQueues: MutableMap> = mutableMapOf() + private val sockets: MutableMap = ConcurrentHashMap() + private val sendQueues: MutableMap> = ConcurrentHashMap() private val messageVersions = HashMap() val pool get() = clientManager.getPool(session, owner) val scheduledThreadPoolExecutor get() = clientManager.getScheduledPool(session, owner, dataStorage)!! - val sendQueue = ConcurrentLinkedDeque() override fun removeSocket(socket: ChatSocket) { - synchronized(sockets) { - log.debug("Removing socket: {}", socket) - sockets.remove(socket)?.close() - } + log.debug("Removing socket: {}", socket) + sockets.remove(socket)?.close() } override fun addSocket(socket: ChatSocket, session: org.eclipse.jetty.websocket.api.Session) { @@ -45,38 +43,39 @@ abstract class SocketManagerBase( operationType = OperationType.Read ) ) throw IllegalArgumentException("Unauthorized") - synchronized(sockets) { - sockets[socket] = session - } + sockets[socket] = session } private fun publish( out: String, ) { - log.debug("Publishing message: {}", out) - synchronized(sockets) { - sockets.keys.forEach { chatSocket -> - try { - log.debug("Queueing message for socket: {}", chatSocket) - sendQueues.computeIfAbsent(chatSocket) { ConcurrentLinkedDeque() }.add(out) - } catch (e: Exception) { - log.info("Error sending message", e) - } + sockets.keys.toTypedArray().forEach { chatSocket -> + try { + val deque = sendQueues.computeIfAbsent(chatSocket) { ConcurrentLinkedDeque() } + deque.add(out) pool.submit { try { - val deque = sendQueues[chatSocket]!! - synchronized(deque) { - while (true) { - val msg = deque.poll() ?: break - log.debug("Sending message: {} to socket: {}", msg, chatSocket) - chatSocket.remote.sendString(msg) + if (deque.isEmpty()) return@submit + ioPool.submit { + try { + while (deque.isNotEmpty()) { + val msg = deque.poll() ?: break + log.info("Sending message: {} to socket: {}", msg, chatSocket) + synchronized(chatSocket) { + chatSocket.remote.sendString(msg) + } + } + chatSocket.remote.flush() + } catch (e: Exception) { + log.info("Error sending message", e) } - chatSocket.remote.flush() } } catch (e: Exception) { log.info("Error sending message", e) } } + } catch (e: Exception) { + log.info("Error sending message", e) } } } @@ -120,39 +119,28 @@ abstract class SocketManagerBase( log.debug("Sending message: {}", out) val split = out.split(',', ignoreCase = false, limit = 2) val messageID = split[0] - val newValue = split[1] + var newValue = split[1] + if (newValue == "null") { + newValue = "" + } if (setMessage(messageID, newValue) < 0) { log.debug("Skipping duplicate message - Key: {}, Value: {} bytes", messageID, newValue.length) return } - if (sendQueue.contains(messageID)) { - log.debug("Skipping already queued message - Key: {}, Value: {} bytes", messageID, newValue.length) - return - } - if (0 == out.length) { + if (out.isEmpty()) { log.debug("Skipping empty message - Key: {}, Value: {} bytes", messageID, newValue.length) return } - log.debug("Queue Send Msg: {} - {} - {} bytes", session, messageID, out.length) - sendQueue.add(messageID) - scheduledThreadPoolExecutor.schedule( - { - try { - while (sendQueue.isNotEmpty()) { - val messageID = sendQueue.poll() ?: return@schedule - val ver = messageVersions[messageID]?.get() - val v = messageStates[messageID] - log.debug("Wire Send Msg: {} - {} - {} - {} bytes", session, messageID, ver, v?.length) - publish(messageID + "," + ver + "," + v) - } - } catch (e: Exception) { - log.debug("$session - $out", e) - } - }, - 50, java.util.concurrent.TimeUnit.MILLISECONDS - ) + try { + val ver = messageVersions[messageID]?.get() + val v = messageStates[messageID] + log.info("Publish Msg: {} - {} - {} - {} bytes", session, messageID, ver, v?.length) + publish("$messageID,$ver,$v") + } catch (e: Exception) { + log.info("$session - $out", e) + } } catch (e: Exception) { - log.debug("$session - $out", e) + log.info("$session - $out", e) } } @@ -261,8 +249,9 @@ abstract class SocketManagerBase( ) companion object { - private val log = LoggerFactory.getLogger(ChatServer::class.java) + private val log = LoggerFactory.getLogger(SocketManagerBase::class.java) + private val ioPool = Executors.newCachedThreadPool() private val range1 = ('a'..'y').toList().toTypedArray() private val range2 = range1 + 'z' fun randomID(root: Boolean = true): String { diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/FilePatchTestApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/FilePatchTestApp.kt new file mode 100644 index 00000000..b2f55d6f --- /dev/null +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/test/FilePatchTestApp.kt @@ -0,0 +1,66 @@ +package com.simiacryptus.skyenet.webui.test + +import com.simiacryptus.diff.addApplyFileDiffLinks +import com.simiacryptus.jopenai.API +import com.simiacryptus.jopenai.ApiModel +import com.simiacryptus.jopenai.OpenAIClient +import com.simiacryptus.skyenet.core.actors.CodingActor +import com.simiacryptus.skyenet.core.actors.CodingActor.Companion.indent +import com.simiacryptus.skyenet.core.platform.ApplicationServices +import com.simiacryptus.skyenet.core.platform.AuthorizationInterface.OperationType +import com.simiacryptus.skyenet.core.platform.ClientManager +import com.simiacryptus.skyenet.core.platform.Session +import com.simiacryptus.skyenet.core.platform.User +import com.simiacryptus.skyenet.webui.application.ApplicationInterface +import com.simiacryptus.skyenet.webui.application.ApplicationServer +import com.simiacryptus.skyenet.webui.application.ApplicationSocketManager +import com.simiacryptus.skyenet.webui.session.SocketManager +import com.simiacryptus.skyenet.webui.util.MarkdownUtil.renderMarkdown +import org.slf4j.LoggerFactory +import java.awt.Desktop +import java.nio.file.Files +import java.util.* + +open class FilePatchTestApp( + applicationName: String = "FilePatchTestApp", + val api: API = OpenAIClient() +) : ApplicationServer( + applicationName = applicationName, + path = "/codingActorTest", +) { + override fun newSession(user: User?, session: Session): SocketManager { + val socketManager = super.newSession(user, session) + val ui = (socketManager as ApplicationSocketManager).applicationInterface + val task = ui.newTask(true) + + val source = """ + |fun main(args: Array) { + | println(${'"'}"" + | Hello, World! + | ${'"'}"") + |} + """.trimMargin() + val sourceFile = Files.createTempFile("source", ".txt").toFile() + sourceFile.writeText(source) + sourceFile.deleteOnExit() + //Desktop.getDesktop().open(sourceFile) + + val patch = """ + |# ${sourceFile.name} + | + |```diff + |-Hello, World! + |+Goodbye, World! + |``` + """.trimMargin() + val newPatch = socketManager.addApplyFileDiffLinks(sourceFile.toPath().parent, patch, {}, ui, api) + task.complete(renderMarkdown(newPatch, ui = ui)) + + return socketManager + } + + companion object { + private val log = LoggerFactory.getLogger(FilePatchTestApp::class.java) + } + +} \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/EncryptFiles.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/EncryptFiles.kt index c07fa65b..10561d55 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/EncryptFiles.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/util/EncryptFiles.kt @@ -1,6 +1,7 @@ package com.simiacryptus.skyenet.webui.util import com.simiacryptus.skyenet.core.platform.ApplicationServices +import java.io.File import java.nio.file.Files import java.nio.file.Paths @@ -8,8 +9,9 @@ object EncryptFiles { @JvmStatic fun main(args: Array) { - "".encrypt("arn:aws:kms:us-east-1:470240306861:key/a1340b89-64e6-480c-a44c-e7bc0c70dcb1") - .write("""C:\Users\andre\code\SkyenetApps\src\main\resources\patreon.json.kms""") + File("""C:\Users\andre\code\SkyeNet\webui\src\test\resources\client_secret_google_oauth.json""") + .readText().encrypt("arn:aws:kms:us-east-1:470240306861:key/a1340b89-64e6-480c-a44c-e7bc0c70dcb1") + .write("""C:\Users\andre\code\SkyeNet\webui\src\test\resources\client_secret_google_oauth.json.kms""") } } @@ -18,5 +20,5 @@ fun String.write(outpath: String) { Files.write(Paths.get(outpath), toByteArray()) } -fun String.encrypt(keyId: String) = ApplicationServices.cloud!!.encrypt(encodeToByteArray(), keyId) +fun String.encrypt(keyId: String) = ApplicationServices.cloud?.encrypt(encodeToByteArray(), keyId) ?: throw RuntimeException("Unable to encrypt data") \ No newline at end of file diff --git a/webui/src/main/resources/application/chat.js b/webui/src/main/resources/application/chat.js index 120f0550..37d54fd0 100644 --- a/webui/src/main/resources/application/chat.js +++ b/webui/src/main/resources/application/chat.js @@ -1,41 +1,44 @@ - let socket; +let reconnectAttempts = 0; +const MAX_RECONNECT_DELAY = 30000; // Maximum delay of 30 seconds -function send(message) { +export function send(message) { console.log('Sending message:', message); if (socket.readyState !== 1) { - throw new Error('WebSocket is not open'); + console.error('WebSocket is not open. Message not sent:', message); + return false; } socket.send(message); + return true; } -function connect(sessionId, customReceiveFunction) { +export function connect(sessionId, customReceiveFunction) { + const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; const host = window.location.hostname; const port = window.location.port; - let path = window.location.pathname; - let strings = path.split('/'); - if(strings.length >= 2 && strings[1] !== '' && strings[1] !== 'index.html') { - path = '/' + strings[1] + '/'; - } else { - path = '/'; - } + const path = getWebSocketPath(); socket = new WebSocket(`${protocol}//${host}:${port}${path}ws?sessionId=${sessionId}`); socket.addEventListener('open', (event) => { console.log('WebSocket connected:', event); showDisconnectedOverlay(false); + reconnectAttempts = 0; + }); + socket.addEventListener('message', (event) => { + if (customReceiveFunction) { + customReceiveFunction(event); + } else { + onWebSocketText(event); + } }); - socket.addEventListener('message', customReceiveFunction || onWebSocketText); socket.addEventListener('close', (event) => { console.log('WebSocket closed:', event); showDisconnectedOverlay(true); - setTimeout(() => { - connect(getSessionId(), customReceiveFunction); - }, 3000); + reconnect(sessionId, customReceiveFunction); }); socket.addEventListener('error', (event) => { @@ -43,9 +46,43 @@ function connect(sessionId, customReceiveFunction) { }); } +function getWebSocketPath() { + const path = window.location.pathname; + const strings = path.split('/'); + return (strings.length >= 2 && strings[1] !== '' && strings[1] !== 'index.html') + ? '/' + strings[1] + '/' + : '/'; +} + +function reconnect(sessionId, customReceiveFunction) { + const delay = Math.min(1000 * Math.pow(2, reconnectAttempts), MAX_RECONNECT_DELAY); + console.log(`Attempting to reconnect in ${delay}ms...`); + setTimeout(() => { + connect(sessionId, customReceiveFunction); + reconnectAttempts++; + }, delay); +} + function showDisconnectedOverlay(show) { - const elements = document.getElementsByClassName('ws-control'); - for (let i = 0; i < elements.length; i++) { - elements[i].disabled = show; - } + document.querySelectorAll('.ws-control').forEach(element => { + element.disabled = show; + }); +} + +// Implement a message queue to handle potential disconnections +let messageQueue = []; + +export function queueMessage(message) { + + messageQueue.push(message); + processMessageQueue(); } + +function processMessageQueue() { + if (socket.readyState === WebSocket.OPEN) { + while (messageQueue.length > 0) { + const message = messageQueue.shift(); + send(message); + } + } +} \ No newline at end of file diff --git a/webui/src/main/resources/application/functions.js b/webui/src/main/resources/application/functions.js index 034f5664..ee238f8e 100644 --- a/webui/src/main/resources/application/functions.js +++ b/webui/src/main/resources/application/functions.js @@ -223,13 +223,15 @@ function applyToAllSvg() { } function substituteMessages(outerMessageId, messageDiv) { - Object.entries(messageMap).forEach(([innerMessageId, content]) => { - if (outerMessageId !== innerMessageId && messageDiv) messageDiv.querySelectorAll('[id="' + innerMessageId + '"]').forEach((element) => { - if (element.innerHTML !== content) { - //console.log("Substituting message with id " + innerMessageId + " and content " + content); - element.innerHTML = content; - substituteMessages(innerMessageId, element); - } + Object.entries(window.messageMap) + .filter(([innerMessageId, content]) => innerMessageId.startsWith("z")) + .forEach(([innerMessageId, content]) => { + if (outerMessageId !== innerMessageId && messageDiv) messageDiv.querySelectorAll('[id="' + innerMessageId + '"]').forEach((element) => { + if (element.innerHTML !== content) { + //console.log("Substituting message with id " + innerMessageId + " and content " + content); + element.innerHTML = content; + substituteMessages(innerMessageId, element); + } + }); }); - }); -} +} \ No newline at end of file diff --git a/webui/src/main/resources/application/index.html b/webui/src/main/resources/application/index.html index 83bdfb9e..088e3fb3 100644 --- a/webui/src/main/resources/application/index.html +++ b/webui/src/main/resources/application/index.html @@ -3,8 +3,8 @@ WebSocket Client - - + + @@ -14,14 +14,14 @@ rel="stylesheet"/> - + - - + + @@ -99,13 +99,13 @@
- +
-