From b4feb0351d98829fd695f5dd3d860f78fd75c606 Mon Sep 17 00:00:00 2001 From: Andrew Charneski Date: Sun, 3 Nov 2024 19:49:08 -0500 Subject: [PATCH] 1.2.16 (#117) * 1.2.16 * 1.2.16 * 1.2.16 --- INTERPRETER_MODULES_DOCUMENTATION.md | 5 +- build.gradle.kts | 6 +- core/build.gradle.kts | 246 +- .../skyenet/core/OutputInterceptor.kt | 128 +- .../skyenet/core/actors/ActorSystem.kt | 18 +- .../skyenet/core/actors/BaseActor.kt | 34 +- .../skyenet/core/actors/CodingActor.kt | 782 +-- .../skyenet/core/actors/ImageActor.kt | 146 +- .../skyenet/core/actors/ImageResponse.kt | 4 +- .../skyenet/core/actors/ParsedActor.kt | 228 +- .../skyenet/core/actors/ParsedResponse.kt | 30 +- .../skyenet/core/actors/SimpleActor.kt | 54 +- .../skyenet/core/actors/SpeechResponse.kt | 2 +- .../skyenet/core/actors/TextToSpeechActor.kt | 88 +- .../core/platform/ApplicationServices.kt | 90 +- .../skyenet/core/platform/AwsPlatform.kt | 170 +- .../skyenet/core/platform/ClientManager.kt | 300 +- .../skyenet/core/platform/Session.kt | 84 +- .../platform/file/AuthenticationManager.kt | 20 +- .../platform/file/AuthorizationManager.kt | 134 +- .../skyenet/core/platform/file/DataStorage.kt | 330 +- .../core/platform/file/MetadataStorage.kt | 238 +- .../core/platform/file/UsageManager.kt | 352 +- .../core/platform/file/UserSettingsManager.kt | 60 +- .../core/platform/hsql/HSQLMetadataStorage.kt | 258 +- .../core/platform/hsql/HSQLUsageManager.kt | 250 +- .../model/ApplicationServicesConfig.kt | 20 +- .../platform/model/AuthenticationInterface.kt | 14 +- .../platform/model/AuthorizationInterface.kt | 30 +- .../platform/model/CloudPlatformInterface.kt | 26 +- .../model/MetadataStorageInterface.kt | 16 +- .../core/platform/model/StorageInterface.kt | 174 +- .../core/platform/model/UsageInterface.kt | 86 +- .../skyenet/core/platform/model/User.kt | 30 +- .../platform/model/UserSettingsInterface.kt | 12 +- .../test/AuthenticationInterfaceTest.kt | 74 +- .../test/AuthorizationInterfaceTest.kt | 22 +- .../test/MetadataStorageInterfaceTest.kt | 310 +- .../platform/test/StorageInterfaceTest.kt | 416 +- .../skyenet/core/platform/test/UsageTest.kt | 280 +- .../core/platform/test/UserSettingsTest.kt | 104 +- .../core/util/ClasspathRelationships.kt | 636 +-- .../skyenet/core/util/CommonRoot.kt | 36 +- .../simiacryptus/skyenet/core/util/Ears.kt | 182 +- .../skyenet/core/util/FunctionWrapper.kt | 294 +- .../skyenet/core/util/GetModuleRootForFile.kt | 20 +- .../skyenet/core/util/LoggingInterceptor.kt | 106 +- .../skyenet/core/util/MultiExeption.kt | 2 +- .../skyenet/core/util/RuleTreeBuilder.kt | 214 +- .../skyenet/core/util/Selenium.kt | 10 +- .../skyenet/core/util/StringSplitter.kt | 52 +- .../skyenet/interpreter/Interpreter.kt | 64 +- .../interpreter/InterpreterTestBase.kt | 156 +- .../skyenet/core/util/RuleTreeBuilderTest.kt | 74 +- gradle.properties | 2 +- groovy/build.gradle.kts | 204 +- .../skyenet/groovy/GroovyInterpreter.kt | 52 +- .../skyenet/groovy/GroovyInterpreterTest.kt | 4 +- kotlin/build.gradle.kts | 220 +- .../skyenet/kotlin/KotlinInterpreter.kt | 218 +- .../skyenet/kotlin/KotlinInterpreterTest.kt | 58 +- scala/build.gradle.kts | 192 +- webui/build.gradle.kts | 308 +- .../simiacryptus/diff/AddApplyDiffLinks.kt | 180 +- .../diff/AddApplyFileDiffLinks.kt | 986 ++-- .../diff/AddShellExecutionLinks.kt | 56 +- .../com/simiacryptus/diff/ApxPatchUtil.kt | 206 +- .../com/simiacryptus/diff/DiffMatchPatch.kt | 4546 ++++++++--------- .../kotlin/com/simiacryptus/diff/DiffUtil.kt | 266 +- .../simiacryptus/diff/FileValidationUtils.kt | 297 +- .../simiacryptus/diff/IterativePatchUtil.kt | 1398 ++--- .../com/simiacryptus/diff/PatchResult.kt | 4 +- .../com/simiacryptus/skyenet/AgentPatterns.kt | 72 +- .../com/simiacryptus/skyenet/Discussable.kt | 380 +- .../com/simiacryptus/skyenet/Retryable.kt | 52 +- .../com/simiacryptus/skyenet/TabbedDisplay.kt | 136 +- .../skyenet/apps/code/CodingAgent.kt | 520 +- .../skyenet/apps/code/ShellToolAgent.kt | 848 +-- .../skyenet/apps/general/AutoPlanChatApp.kt | 826 +-- .../skyenet/apps/general/CmdPatchApp.kt | 250 +- .../skyenet/apps/general/CommandPatchApp.kt | 114 +- .../skyenet/apps/general/PatchApp.kt | 474 +- .../skyenet/apps/general/PlanAheadApp.kt | 158 +- .../skyenet/apps/general/PlanChatApp.kt | 260 +- .../skyenet/apps/general/StressTestApp.kt | 120 +- .../skyenet/apps/general/WebDevApp.kt | 880 ++-- .../skyenet/apps/meta/ActorDesigner.kt | 34 +- .../skyenet/apps/meta/CodingActorDesigner.kt | 26 +- .../skyenet/apps/meta/DetailDesigner.kt | 68 +- .../skyenet/apps/meta/FlowStepDesigner.kt | 44 +- .../skyenet/apps/meta/HighLevelDesigner.kt | 10 +- .../skyenet/apps/meta/ImageActorDesigner.kt | 26 +- .../skyenet/apps/meta/MetaAgentApp.kt | 979 ++-- .../skyenet/apps/meta/ParsedActorDesigner.kt | 26 +- .../skyenet/apps/meta/SimpleActorDesigner.kt | 26 +- .../skyenet/apps/parse/CodeParsingModel.kt | 118 +- .../skyenet/apps/parse/DocumentParserApp.kt | 436 +- .../apps/parse/DocumentParsingModel.kt | 208 +- .../skyenet/apps/parse/DocumentRecord.kt | 163 +- .../skyenet/apps/parse/PDFReader.kt | 30 +- .../skyenet/apps/parse/ParsingModel.kt | 34 +- .../skyenet/apps/parse/TextReader.kt | 70 +- .../skyenet/apps/plan/AbstractTask.kt | 86 +- .../skyenet/apps/plan/CommandAutoFixTask.kt | 230 +- .../skyenet/apps/plan/ForeachTask.kt | 116 +- .../skyenet/apps/plan/GoogleSearchTask.kt | 13 +- .../skyenet/apps/plan/PlanCoordinator.kt | 464 +- .../skyenet/apps/plan/PlanProcessingState.kt | 16 +- .../skyenet/apps/plan/PlanSettings.kt | 241 +- .../skyenet/apps/plan/PlanUtil.kt | 288 +- .../simiacryptus/skyenet/apps/plan/Planner.kt | 190 +- .../skyenet/apps/plan/PlanningTask.kt | 232 +- .../skyenet/apps/plan/RunShellCommandTask.kt | 194 +- .../skyenet/apps/plan/SearchTask.kt | 133 + .../apps/plan/TaskBreakdownWithPrompt.kt | 6 +- .../skyenet/apps/plan/TaskType.kt | 226 +- .../apps/plan/file/AbstractAnalysisTask.kt | 120 +- .../apps/plan/file/AbstractFileTask.kt | 92 +- .../apps/plan/file/CodeOptimizationTask.kt | 62 +- .../skyenet/apps/plan/file/CodeReviewTask.kt | 68 +- .../apps/plan/file/DocumentationTask.kt | 200 +- .../apps/plan/file/FileModificationTask.kt | 196 +- .../skyenet/apps/plan/file/InquiryTask.kt | 228 +- .../apps/plan/file/PerformanceAnalysisTask.kt | 68 +- .../skyenet/apps/plan/file/RefactorTask.kt | 58 +- .../skyenet/apps/plan/file/SearchTask.kt | 134 - .../apps/plan/file/SecurityAuditTask.kt | 60 +- .../apps/plan/file/TestGenerationTask.kt | 58 +- .../plan/knowledge/EmbeddingSearchTask.kt | 352 +- .../plan/knowledge/KnowledgeIndexingTask.kt | 110 + .../plan/knowledge/WebSearchAndIndexTask.kt | 155 + .../skyenet/interpreter/ProcessInterpreter.kt | 78 +- .../simiacryptus/skyenet/util/EncryptFiles.kt | 16 +- .../simiacryptus/skyenet/util/MarkdownUtil.kt | 318 +- .../com/simiacryptus/skyenet/util/OpenAPI.kt | 118 +- .../simiacryptus/skyenet/util/Selenium2S3.kt | 822 +-- .../skyenet/util/TensorflowProjector.kt | 151 +- .../skyenet/webui/application/AppInfoData.kt | 16 +- .../webui/application/ApplicationDirectory.kt | 386 +- .../webui/application/ApplicationInterface.kt | 68 +- .../webui/application/ApplicationServer.kt | 318 +- .../application/ApplicationSocketManager.kt | 62 +- .../skyenet/webui/chat/ChatServer.kt | 106 +- .../skyenet/webui/chat/ChatSocket.kt | 46 +- .../skyenet/webui/chat/ChatSocketManager.kt | 124 +- .../skyenet/webui/servlet/ApiKeyServlet.kt | 324 +- .../skyenet/webui/servlet/AppInfoServlet.kt | 12 +- .../webui/servlet/CancelThreadsServlet.kt | 94 +- .../skyenet/webui/servlet/CorsFilter.kt | 56 +- .../webui/servlet/DeleteSessionServlet.kt | 66 +- .../skyenet/webui/servlet/FileServlet.kt | 366 +- .../skyenet/webui/servlet/LogoutServlet.kt | 20 +- .../webui/servlet/NewSessionServlet.kt | 12 +- .../skyenet/webui/servlet/OAuthBase.kt | 2 +- .../skyenet/webui/servlet/OAuthGoogle.kt | 174 +- .../skyenet/webui/servlet/ProxyHttpServlet.kt | 344 +- .../webui/servlet/SessionFileServlet.kt | 24 +- .../skyenet/webui/servlet/SessionIdFilter.kt | 44 +- .../webui/servlet/SessionListServlet.kt | 40 +- .../webui/servlet/SessionSettingsServlet.kt | 88 +- .../webui/servlet/SessionShareServlet.kt | 224 +- .../webui/servlet/SessionThreadsServlet.kt | 52 +- .../skyenet/webui/servlet/UsageServlet.kt | 60 +- .../skyenet/webui/servlet/UserInfoServlet.kt | 18 +- .../webui/servlet/UserSettingsServlet.kt | 132 +- .../skyenet/webui/servlet/WelcomeServlet.kt | 164 +- .../skyenet/webui/servlet/ZipServlet.kt | 64 +- .../skyenet/webui/session/SessionTask.kt | 324 +- .../skyenet/webui/session/SocketManager.kt | 8 +- .../webui/session/SocketManagerBase.kt | 447 +- .../skyenet/webui/test/CodingActorTestApp.kt | 86 +- .../skyenet/webui/test/FilePatchTestApp.kt | 50 +- .../skyenet/webui/test/ImageActorTestApp.kt | 78 +- .../skyenet/webui/test/ParsedActorTestApp.kt | 58 +- .../skyenet/webui/test/SimpleActorTestApp.kt | 70 +- .../com/simiacryptus/diff/ApxPatchUtilTest.kt | 32 +- .../com/simiacryptus/diff/DiffUtilTest.kt | 100 +- .../diff/IterativePatchUtilTest.kt | 346 +- .../skyenet/webui/ActorTestAppServer.kt | 224 +- .../skyenet/webui/util/MarkdownUtilTest.kt | 46 +- 180 files changed, 17837 insertions(+), 17539 deletions(-) create mode 100644 webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/SearchTask.kt delete mode 100644 webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/SearchTask.kt create mode 100644 webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/KnowledgeIndexingTask.kt create mode 100644 webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/WebSearchAndIndexTask.kt diff --git a/INTERPRETER_MODULES_DOCUMENTATION.md b/INTERPRETER_MODULES_DOCUMENTATION.md index 67524003..c2f419b9 100644 --- a/INTERPRETER_MODULES_DOCUMENTATION.md +++ b/INTERPRETER_MODULES_DOCUMENTATION.md @@ -1,6 +1,7 @@ ### Common Characteristics All interpreter modules share the following characteristics: + 1. They implement the `Interpreter` interface, ensuring a consistent API across different language interpreters. 2. They support the addition of predefined variables, allowing for context to be passed into the executed code. 3. They provide methods for running code, validating syntax, and retrieving language-specific information. @@ -9,8 +10,10 @@ All interpreter modules share the following characteristics: ### Integration with SkyeNet These interpreter modules play a crucial role in SkyeNet's multi-language support feature. They allow the AI-powered system to: + 1. Execute code snippets in different languages as part of task processing. 2. Validate code syntax before execution, enhancing error handling and user feedback. 3. Integrate language-specific features and libraries into the SkyeNet workflow. -By providing a unified interface through the `Interpreter` interface, SkyeNet can seamlessly work with multiple programming languages, expanding its capabilities and flexibility in handling diverse coding tasks and applications. \ No newline at end of file +By providing a unified interface through the `Interpreter` interface, SkyeNet can seamlessly work with multiple programming languages, expanding its capabilities and flexibility in +handling diverse coding tasks and applications. \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index 84c66e27..c450045a 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -3,7 +3,7 @@ group = properties("libraryGroup") version = properties("libraryVersion") tasks { - wrapper { - gradleVersion = properties("gradleVersion") - } + wrapper { + gradleVersion = properties("gradleVersion") + } } diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 44df008a..76679cf2 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -6,24 +6,24 @@ group = properties("libraryGroup") version = properties("libraryVersion") plugins { - java - `java-library` - id("org.jetbrains.kotlin.jvm") version "2.0.20" - `maven-publish` - id("signing") + java + `java-library` + id("org.jetbrains.kotlin.jvm") version "2.0.20" + `maven-publish` + id("signing") } repositories { - mavenCentral { - metadataSources { - mavenPom() - artifact() - } + mavenCentral { + metadataSources { + mavenPom() + artifact() } + } } kotlin { - jvmToolchain(11) + jvmToolchain(11) } val junit_version = "5.10.1" @@ -33,151 +33,151 @@ val hsqldb_version = "2.7.2" dependencies { - implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.1.12") - implementation(group = "org.hsqldb", name = "hsqldb", version = hsqldb_version) + implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.1.12") + implementation(group = "org.hsqldb", name = "hsqldb", version = hsqldb_version) - implementation("org.apache.commons:commons-text:1.11.0") + implementation("org.apache.commons:commons-text:1.11.0") - implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.16") - implementation(group = "commons-io", name = "commons-io", version = "2.15.0") - implementation(group = "com.google.guava", name = "guava", version = "32.1.3-jre") - implementation(group = "com.google.code.gson", name = "gson", version = "2.10.1") - implementation(group = "org.apache.httpcomponents.client5", name = "httpclient5", version = "5.2.3") - implementation("org.eclipse.jetty.toolchain:jetty-jakarta-servlet-api:5.0.2") + implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.16") + implementation(group = "commons-io", name = "commons-io", version = "2.15.0") + implementation(group = "com.google.guava", name = "guava", version = "32.1.3-jre") + implementation(group = "com.google.code.gson", name = "gson", version = "2.10.1") + implementation(group = "org.apache.httpcomponents.client5", name = "httpclient5", version = "5.2.3") + implementation("org.eclipse.jetty.toolchain:jetty-jakarta-servlet-api:5.0.2") - implementation(group = "com.fasterxml.jackson.core", name = "jackson-databind", version = jackson_version) - implementation(group = "com.fasterxml.jackson.core", name = "jackson-annotations", version = jackson_version) - implementation(group = "com.fasterxml.jackson.module", name = "jackson-module-kotlin", version = jackson_version) + implementation(group = "com.fasterxml.jackson.core", name = "jackson-databind", version = jackson_version) + implementation(group = "com.fasterxml.jackson.core", name = "jackson-annotations", version = jackson_version) + implementation(group = "com.fasterxml.jackson.module", name = "jackson-module-kotlin", version = jackson_version) - compileOnly("org.ow2.asm:asm:9.6") - compileOnly(kotlin("stdlib")) - compileOnly(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.8.0-RC") + compileOnly("org.ow2.asm:asm:9.6") + compileOnly(kotlin("stdlib")) + compileOnly(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.8.0-RC") - testImplementation(kotlin("stdlib")) - testImplementation(kotlin("script-runtime")) + testImplementation(kotlin("stdlib")) + testImplementation(kotlin("script-runtime")) - testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = junit_version) - testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = junit_version) - compileOnly(group = "org.junit.jupiter", name = "junit-jupiter-api", version = junit_version) - compileOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = junit_version) + testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = junit_version) + testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = junit_version) + compileOnly(group = "org.junit.jupiter", name = "junit-jupiter-api", version = junit_version) + compileOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = junit_version) - compileOnly(platform("software.amazon.awssdk:bom:2.21.29")) - compileOnly(group = "software.amazon.awssdk", name = "aws-sdk-java", version = "2.27.23") - testImplementation(platform("software.amazon.awssdk:bom:2.21.29")) - testImplementation(group = "software.amazon.awssdk", name = "aws-sdk-java", version = "2.27.23") + compileOnly(platform("software.amazon.awssdk:bom:2.21.29")) + compileOnly(group = "software.amazon.awssdk", name = "aws-sdk-java", version = "2.27.23") + testImplementation(platform("software.amazon.awssdk:bom:2.21.29")) + testImplementation(group = "software.amazon.awssdk", name = "aws-sdk-java", version = "2.27.23") - compileOnly(group = "ch.qos.logback", name = "logback-classic", version = logback_version) - compileOnly(group = "ch.qos.logback", name = "logback-core", version = logback_version) - testImplementation(group = "ch.qos.logback", name = "logback-classic", version = logback_version) - testImplementation(group = "ch.qos.logback", name = "logback-core", version = logback_version) + compileOnly(group = "ch.qos.logback", name = "logback-classic", version = logback_version) + compileOnly(group = "ch.qos.logback", name = "logback-core", version = logback_version) + testImplementation(group = "ch.qos.logback", name = "logback-classic", version = logback_version) + testImplementation(group = "ch.qos.logback", name = "logback-core", version = logback_version) - testImplementation(group = "org.mockito", name = "mockito-core", version = "5.7.0") + testImplementation(group = "org.mockito", name = "mockito-core", version = "5.7.0") } tasks { - compileKotlin { - compilerOptions { - javaParameters.set(true) - } + compileKotlin { + compilerOptions { + javaParameters.set(true) } - compileTestKotlin { - compilerOptions { - javaParameters.set(true) - } + } + compileTestKotlin { + compilerOptions { + javaParameters.set(true) } - test { - useJUnitPlatform() - systemProperty("surefire.useManifestOnlyJar", "false") - testLogging { - events(TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) - exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL - } - jvmArgs( - "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", - "--add-opens", "java.base/java.util=ALL-UNNAMED", - "--add-opens", "java.base/java.lang=ALL-UNNAMED" - ) + } + test { + useJUnitPlatform() + systemProperty("surefire.useManifestOnlyJar", "false") + testLogging { + events(TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) + exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL } + jvmArgs( + "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", + "--add-opens", "java.base/java.util=ALL-UNNAMED", + "--add-opens", "java.base/java.lang=ALL-UNNAMED" + ) + } } val javadocJar by tasks.registering(Jar::class) { - archiveClassifier.set("javadoc") - from(tasks.javadoc) + archiveClassifier.set("javadoc") + from(tasks.javadoc) } val sourcesJar by tasks.registering(Jar::class) { - archiveClassifier.set("sources") - from(sourceSets.main.get().allSource) + archiveClassifier.set("sources") + from(sourceSets.main.get().allSource) } publishing { - publications { - create("mavenJava") { - artifactId = "core" - from(components["java"]) - artifact(sourcesJar.get()) - artifact(javadocJar.get()) - versionMapping { - usage("java-api") { - fromResolutionOf("runtimeClasspath") - } - usage("java-runtime") { - fromResolutionResult() - } - } - pom { - name.set("SkyeNet Core Components") - description.set("A very helpful puppy") - url.set("https://github.com/SimiaCryptus/SkyeNet") - licenses { - license { - name.set("The Apache License, Version 2.0") - url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") - } - } - developers { - developer { - id.set("acharneski") - name.set("Andrew Charneski") - email.set("acharneski@gmail.com") - } - } - scm { - connection.set("scm:git:git://git@github.com/SimiaCryptus/SkyeNet.git") - developerConnection.set("scm:git:ssh://git@github.com/SimiaCryptus/SkyeNet.git") - url.set("https://github.com/SimiaCryptus/SkyeNet") - } - } + publications { + create("mavenJava") { + artifactId = "core" + from(components["java"]) + artifact(sourcesJar.get()) + artifact(javadocJar.get()) + versionMapping { + usage("java-api") { + fromResolutionOf("runtimeClasspath") } - } - repositories { - maven { - val releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" - val snapshotsRepoUrl = "https://oss.sonatype.org/mask/repositories/snapshots" - url = URI(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) - credentials { - username = System.getenv("OSSRH_USERNAME") ?: System.getProperty("ossrhUsername") - ?: properties("ossrhUsername") - password = System.getenv("OSSRH_PASSWORD") ?: System.getProperty("ossrhPassword") - ?: properties("ossrhPassword") - } + usage("java-runtime") { + fromResolutionResult() } - } - if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) afterEvaluate { - signing { - sign(publications["mavenJava"]) + } + pom { + name.set("SkyeNet Core Components") + description.set("A very helpful puppy") + url.set("https://github.com/SimiaCryptus/SkyeNet") + licenses { + license { + name.set("The Apache License, Version 2.0") + url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") + } } + developers { + developer { + id.set("acharneski") + name.set("Andrew Charneski") + email.set("acharneski@gmail.com") + } + } + scm { + connection.set("scm:git:git://git@github.com/SimiaCryptus/SkyeNet.git") + developerConnection.set("scm:git:ssh://git@github.com/SimiaCryptus/SkyeNet.git") + url.set("https://github.com/SimiaCryptus/SkyeNet") + } + } + } + } + repositories { + maven { + val releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" + val snapshotsRepoUrl = "https://oss.sonatype.org/mask/repositories/snapshots" + url = URI(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) + credentials { + username = System.getenv("OSSRH_USERNAME") ?: System.getProperty("ossrhUsername") + ?: properties("ossrhUsername") + password = System.getenv("OSSRH_PASSWORD") ?: System.getProperty("ossrhPassword") + ?: properties("ossrhPassword") + } } + } + if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) afterEvaluate { + signing { + sign(publications["mavenJava"]) + } + } } if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) { - apply() - configure { - useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) - sign(configurations.archives.get()) - } + apply() + configure { + useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) + sign(configurations.archives.get()) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/OutputInterceptor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/OutputInterceptor.kt index efe85ca1..4c733d33 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/OutputInterceptor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/OutputInterceptor.kt @@ -7,86 +7,86 @@ import java.util.* import java.util.concurrent.atomic.AtomicBoolean object OutputInterceptor { - private val originalOut: PrintStream = System.out - private val originalErr: PrintStream = System.err - private val isSetup = AtomicBoolean(false) - private val globalStreamLock = Any() + private val originalOut: PrintStream = System.out + private val originalErr: PrintStream = System.err + private val isSetup = AtomicBoolean(false) + private val globalStreamLock = Any() - fun setupInterceptor() { - if (isSetup.getAndSet(true)) return - System.setOut(PrintStream(OutputStreamRouter(originalOut))) - System.setErr(PrintStream(OutputStreamRouter(originalErr))) - } + fun setupInterceptor() { + if (isSetup.getAndSet(true)) return + System.setOut(PrintStream(OutputStreamRouter(originalOut))) + System.setErr(PrintStream(OutputStreamRouter(originalErr))) + } - private val globalStream = ByteArrayOutputStream() + private val globalStream = ByteArrayOutputStream() - private val threadLocalBuffer = WeakHashMap() + private val threadLocalBuffer = WeakHashMap() - private fun getThreadOutputStream(): ByteArrayOutputStream { - val currentThread = Thread.currentThread() - synchronized(threadLocalBuffer) { - return threadLocalBuffer.getOrPut(currentThread) { ByteArrayOutputStream() } - } + private fun getThreadOutputStream(): ByteArrayOutputStream { + val currentThread = Thread.currentThread() + synchronized(threadLocalBuffer) { + return threadLocalBuffer.getOrPut(currentThread) { ByteArrayOutputStream() } } + } - fun getThreadOutput(): String { - val outputStream = getThreadOutputStream() - try { - outputStream.flush() - } catch (e: IOException) { - throw RuntimeException(e) - } - return outputStream.toString() + fun getThreadOutput(): String { + val outputStream = getThreadOutputStream() + try { + outputStream.flush() + } catch (e: IOException) { + throw RuntimeException(e) } + return outputStream.toString() + } - fun clearThreadOutput() { - getThreadOutputStream().reset() - } + fun clearThreadOutput() { + getThreadOutputStream().reset() + } - fun getGlobalOutput(): String { - synchronized(globalStreamLock) { - return globalStream.toString() - } + fun getGlobalOutput(): String { + synchronized(globalStreamLock) { + return globalStream.toString() } + } - fun clearGlobalOutput() { - synchronized(globalStreamLock) { - globalStream.reset() - } + fun clearGlobalOutput() { + synchronized(globalStreamLock) { + globalStream.reset() } + } - private class OutputStreamRouter(private val originalStream: PrintStream) : ByteArrayOutputStream() { - private val maxGlobalBuffer = 8 * 1024 * 1024 - private val maxThreadBuffer = 1024 * 1024 + private class OutputStreamRouter(private val originalStream: PrintStream) : ByteArrayOutputStream() { + private val maxGlobalBuffer = 8 * 1024 * 1024 + private val maxThreadBuffer = 1024 * 1024 - override fun write(b: Int) { - originalStream.write(b) - synchronized(globalStreamLock) { - if (globalStream.size() > maxGlobalBuffer) { - globalStream.reset() - } - globalStream.write(b) - } - val threadOutputStream = getThreadOutputStream() - if (threadOutputStream.size() > maxThreadBuffer) { - threadOutputStream.reset() - } - threadOutputStream.write(b) + override fun write(b: Int) { + originalStream.write(b) + synchronized(globalStreamLock) { + if (globalStream.size() > maxGlobalBuffer) { + globalStream.reset() } + globalStream.write(b) + } + val threadOutputStream = getThreadOutputStream() + if (threadOutputStream.size() > maxThreadBuffer) { + threadOutputStream.reset() + } + threadOutputStream.write(b) + } - override fun write(b: ByteArray, off: Int, len: Int) { - originalStream.write(b, off, len) - synchronized(globalStreamLock) { - if (globalStream.size() > maxGlobalBuffer) { - globalStream.reset() - } - globalStream.write(b, off, len) - } - val threadOutputStream = getThreadOutputStream() - if (threadOutputStream.size() > maxThreadBuffer) { - threadOutputStream.reset() - } - threadOutputStream.write(b, off, len) + override fun write(b: ByteArray, off: Int, len: Int) { + originalStream.write(b, off, len) + synchronized(globalStreamLock) { + if (globalStream.size() > maxGlobalBuffer) { + globalStream.reset() } + globalStream.write(b, off, len) + } + val threadOutputStream = getThreadOutputStream() + if (threadOutputStream.size() > maxThreadBuffer) { + threadOutputStream.reset() + } + threadOutputStream.write(b, off, len) } + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ActorSystem.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ActorSystem.kt index 102ed9bc..a29b4895 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ActorSystem.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ActorSystem.kt @@ -6,20 +6,20 @@ import com.simiacryptus.skyenet.core.platform.model.StorageInterface import com.simiacryptus.skyenet.core.platform.model.User open class ActorSystem>( - val actors: Map>, - dataStorage: StorageInterface, - user: User?, - session: Session + val actors: Map>, + dataStorage: StorageInterface, + user: User?, + session: Session ) : PoolSystem(dataStorage, user, session) { - fun getActor(actor: T) = actors.get(actor.name)!! + fun getActor(actor: T) = actors.get(actor.name)!! } open class PoolSystem( - val dataStorage: StorageInterface, - val user: User?, - val session: Session + val dataStorage: StorageInterface, + val user: User?, + val session: Session ) { - protected val pool by lazy { ApplicationServices.clientManager.getPool(session, user) } + protected val pool by lazy { ApplicationServices.clientManager.getPool(session, user) } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/BaseActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/BaseActor.kt index be4c2845..d8ff3495 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/BaseActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/BaseActor.kt @@ -8,24 +8,24 @@ import com.simiacryptus.jopenai.models.OpenAIModel import com.simiacryptus.jopenai.models.TextModel abstract class BaseActor( - open val prompt: String, - val name: String? = null, - val model: TextModel, - val temperature: Double = 0.3, + open val prompt: String, + val name: String? = null, + val model: TextModel, + val temperature: Double = 0.3, ) { - abstract fun respond(input: I, api: API, vararg messages: ApiModel.ChatMessage): R - open fun response(vararg input: ApiModel.ChatMessage, model: OpenAIModel = this.model, api: API) = - (api as ChatClient).chat( - ApiModel.ChatRequest( - messages = ArrayList(input.toList()), - temperature = temperature, - model = this.model.modelName, - ), - model = this.model - ) + abstract fun respond(input: I, api: API, vararg messages: ApiModel.ChatMessage): R + open fun response(vararg input: ApiModel.ChatMessage, model: OpenAIModel = this.model, api: API) = + (api as ChatClient).chat( + ApiModel.ChatRequest( + messages = ArrayList(input.toList()), + temperature = temperature, + model = this.model.modelName, + ), + model = this.model + ) - open fun answer(input: I, api: API): R = respond(input = input, api = api, *chatMessages(input)) + open fun answer(input: I, api: API): R = respond(input = input, api = api, *chatMessages(input)) - abstract fun chatMessages(questions: I): Array - abstract fun withModel(model: ChatModel): BaseActor + abstract fun chatMessages(questions: I): Array + abstract fun withModel(model: ChatModel): BaseActor } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/CodingActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/CodingActor.kt index 75db0881..84cead90 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/CodingActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/CodingActor.kt @@ -20,60 +20,60 @@ typealias CodeInterceptor = (String) -> String open class CodingActor( - val interpreterClass: KClass, - val symbols: Map = mapOf(), - val describer: TypeDescriber = AbbrevWhitelistTSDescriber( - "com.simiacryptus", - "com.github.simiacryptus" - ), - name: String? = interpreterClass.simpleName, - val details: String? = null, - model: TextModel = OpenAIModels.GPT4o, - val fallbackModel: ChatModel = OpenAIModels.GPT4o, - temperature: Double = 0.1, - val runtimeSymbols: Map = mapOf(), - var codeInterceptor: CodeInterceptor = { it } + val interpreterClass: KClass, + val symbols: Map = mapOf(), + val describer: TypeDescriber = AbbrevWhitelistTSDescriber( + "com.simiacryptus", + "com.github.simiacryptus" + ), + name: String? = interpreterClass.simpleName, + val details: String? = null, + model: TextModel = OpenAIModels.GPT4o, + val fallbackModel: ChatModel = OpenAIModels.GPT4o, + temperature: Double = 0.1, + val runtimeSymbols: Map = mapOf(), + var codeInterceptor: CodeInterceptor = { it } ) : BaseActor( - prompt = "", - name = name, - model = model, - temperature = temperature, + prompt = "", + name = name, + model = model, + temperature = temperature, ) { - val interpreter: Interpreter - get() = interpreterClass.java.getConstructor(Map::class.java).newInstance(symbols + runtimeSymbols) - - data class CodeRequest( - val messages: List>, - val codePrefix: String = "", - val autoEvaluate: Boolean = false, - val fixIterations: Int = 1, - val fixRetries: Int = 1, - ) - - interface CodeResult { - enum class Status { - Coding, Correcting, Success, Failure - } - - val code: String - val status: Status - val result: ExecutionResult - val renderedResponse: String? + val interpreter: Interpreter + get() = interpreterClass.java.getConstructor(Map::class.java).newInstance(symbols + runtimeSymbols) + + data class CodeRequest( + val messages: List>, + val codePrefix: String = "", + val autoEvaluate: Boolean = false, + val fixIterations: Int = 1, + val fixRetries: Int = 1, + ) + + interface CodeResult { + enum class Status { + Coding, Correcting, Success, Failure } - data class ExecutionResult( - val resultValue: String, - val resultOutput: String - ) - - var evalFormat = true - override val prompt: String - get() { - val formatInstructions = - if (evalFormat) """Code should be structured as appropriately parameterized function(s) + val code: String + val status: Status + val result: ExecutionResult + val renderedResponse: String? + } + + data class ExecutionResult( + val resultValue: String, + val resultOutput: String + ) + + var evalFormat = true + override val prompt: String + get() { + val formatInstructions = + if (evalFormat) """Code should be structured as appropriately parameterized function(s) with the final line invoking the function with the appropriate request parameters.""" else "" - return if (symbols.isNotEmpty()) { - """ + return if (symbols.isNotEmpty()) { + """ You are a coding assistant allows users actions to be enacted using $language and the script context. Your role is to translate natural language instructions into code as well as interpret the results and converse with the user. Use $TT code blocks labeled with $language where appropriate. (i.e. ${TT}$language) @@ -91,7 +91,7 @@ They are already defined for you. ${details ?: ""} """.trim() - } else """ + } else """ You are a coding assistant allowing users actions to be enacted using $language and the script context. Your role is to translate natural language instructions into code as well as interpret the results and converse with the user. Use $TT code blocks labeled with $language where appropriate. (i.e. ${TT}$language) @@ -100,244 +100,244 @@ $formatInstructions ${details ?: ""} """.trim() - } + } - open val apiDescription: String - get() = this.symbols.map { (name, utilityObj) -> - val describe = this.describer.describe(utilityObj.javaClass) - log.info("Describing $name (${utilityObj.javaClass}) in ${describe.length} characters") - """ + open val apiDescription: String + get() = this.symbols.map { (name, utilityObj) -> + val describe = this.describer.describe(utilityObj.javaClass) + log.info("Describing $name (${utilityObj.javaClass}) in ${describe.length} characters") + """ $name: ${describe.indent(" ")} """.trimMargin().trim() - }.joinToString("\n") - - - val language: String by lazy { interpreter.getLanguage() } - - override fun chatMessages(questions: CodeRequest): Array { - var chatMessages = arrayOf( - ChatMessage( - role = Role.system, - content = prompt.toContentList() - ), - ) + questions.messages.map { - ChatMessage( - role = it.second, - content = it.first.toContentList() - ) - } - if (questions.codePrefix.isNotBlank()) { - chatMessages = (chatMessages.dropLast(1) + listOf( - ChatMessage(Role.assistant, "Code Prefix:\n$TT\n${questions.codePrefix}\n${TT}".toContentList()) - ) + chatMessages.last()).toTypedArray() - } - return chatMessages - + }.joinToString("\n") + + + val language: String by lazy { interpreter.getLanguage() } + + override fun chatMessages(questions: CodeRequest): Array { + var chatMessages = arrayOf( + ChatMessage( + role = Role.system, + content = prompt.toContentList() + ), + ) + questions.messages.map { + ChatMessage( + role = it.second, + content = it.first.toContentList() + ) } - - override fun respond( - input: CodeRequest, - api: API, - vararg messages: ChatMessage, - ): CodeResult { - var result = CodeResultImpl( - *messages, - input = input, - api = (api as ChatClient) + if (questions.codePrefix.isNotBlank()) { + chatMessages = (chatMessages.dropLast(1) + listOf( + ChatMessage(Role.assistant, "Code Prefix:\n$TT\n${questions.codePrefix}\n${TT}".toContentList()) + ) + chatMessages.last()).toTypedArray() + } + return chatMessages + + } + + override fun respond( + input: CodeRequest, + api: API, + vararg messages: ChatMessage, + ): CodeResult { + var result = CodeResultImpl( + *messages, + input = input, + api = (api as ChatClient) + ) + if (!input.autoEvaluate) return result + for (i in 0..input.fixIterations) try { + require(result.result.resultValue.length > -1) + return result + } catch (ex: Throwable) { + if (i == input.fixIterations) { + log.info( + "Failed to implement ${ + messages.map { it.content?.joinToString("\n") { it.text ?: "" } }.joinToString("\n") + }" ) - if (!input.autoEvaluate) return result - for (i in 0..input.fixIterations) try { - require(result.result.resultValue.length > -1) - return result - } catch (ex: Throwable) { - if (i == input.fixIterations) { - log.info( - "Failed to implement ${ - messages.map { it.content?.joinToString("\n") { it.text ?: "" } }.joinToString("\n") - }" - ) - throw ex - } - val respondWithCode = fixCommand(api, result.code, ex, *messages, model = model) - val blocks = extractTextBlocks(respondWithCode) - val renderedResponse = getRenderedResponse(blocks) - val codedInstruction = codeInterceptor(getCode(language, blocks)) - log.debug("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin()) - log.debug("New Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin()) - result = CodeResultImpl( - *messages, - input = input, - api = api, - givenCode = codedInstruction, - givenResponse = renderedResponse - ) - } - throw IllegalStateException() + throw ex + } + val respondWithCode = fixCommand(api, result.code, ex, *messages, model = model) + val blocks = extractTextBlocks(respondWithCode) + val renderedResponse = getRenderedResponse(blocks) + val codedInstruction = codeInterceptor(getCode(language, blocks)) + log.debug("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin()) + log.debug("New Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin()) + result = CodeResultImpl( + *messages, + input = input, + api = api, + givenCode = codedInstruction, + givenResponse = renderedResponse + ) } + throw IllegalStateException() + } + + open fun execute(prefix: String, code: String): ExecutionResult { + //language=HTML + log.debug("Running $code") + OutputInterceptor.clearGlobalOutput() + val result = try { + interpreter.run((prefix + "\n" + codeInterceptor(code)).sortCode()) + } catch (e: Exception) { + when { + e is FailedToImplementException -> throw e + e is ScriptException -> throw FailedToImplementException( + cause = e, + message = errorMessage(e, code), + language = language, + code = code, + prefix = prefix, + ) - open fun execute(prefix: String, code: String): ExecutionResult { - //language=HTML - log.debug("Running $code") - OutputInterceptor.clearGlobalOutput() - val result = try { - interpreter.run((prefix + "\n" + codeInterceptor(code)).sortCode()) - } catch (e: Exception) { - when { - e is FailedToImplementException -> throw e - e is ScriptException -> throw FailedToImplementException( - cause = e, - message = errorMessage(e, code), - language = language, - code = code, - prefix = prefix, - ) - - e.cause is ScriptException -> throw FailedToImplementException( - cause = e, - message = errorMessage(e.cause!! as ScriptException, code), - language = language, - code = code, - prefix = prefix, - ) + e.cause is ScriptException -> throw FailedToImplementException( + cause = e, + message = errorMessage(e.cause!! as ScriptException, code), + language = language, + code = code, + prefix = prefix, + ) - else -> throw e - } - } - log.debug("Result: $result") - //language=HTML - val executionResult = ExecutionResult(result.toString(), OutputInterceptor.getThreadOutput()) - OutputInterceptor.clearThreadOutput() - return executionResult + else -> throw e + } } - - inner class CodeResultImpl( - vararg val messages: ChatMessage, - private val input: CodeRequest, - private val api: ChatClient, - private val givenCode: String? = null, - private val givenResponse: String? = null, - ) : CodeResult { - private val implementation by lazy { - if (!givenCode.isNullOrBlank() && !givenResponse.isNullOrBlank()) (givenCode to givenResponse) else try { - implement(model) - } catch (ex: FailedToImplementException) { - if (fallbackModel != model) { - try { - implement(fallbackModel) - } catch (ex: FailedToImplementException) { - log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") - _status = CodeResult.Status.Failure - throw ex - } - } else { - log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") - _status = CodeResult.Status.Failure - throw ex - } - } + log.debug("Result: $result") + //language=HTML + val executionResult = ExecutionResult(result.toString(), OutputInterceptor.getThreadOutput()) + OutputInterceptor.clearThreadOutput() + return executionResult + } + + inner class CodeResultImpl( + vararg val messages: ChatMessage, + private val input: CodeRequest, + private val api: ChatClient, + private val givenCode: String? = null, + private val givenResponse: String? = null, + ) : CodeResult { + private val implementation by lazy { + if (!givenCode.isNullOrBlank() && !givenResponse.isNullOrBlank()) (givenCode to givenResponse) else try { + implement(model) + } catch (ex: FailedToImplementException) { + if (fallbackModel != model) { + try { + implement(fallbackModel) + } catch (ex: FailedToImplementException) { + log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") + _status = CodeResult.Status.Failure + throw ex + } + } else { + log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") + _status = CodeResult.Status.Failure + throw ex } + } + } - private var _status = CodeResult.Status.Coding - - override val status get() = _status - - override val renderedResponse: String = givenResponse ?: implementation.second - override val code: String = givenCode ?: implementation.first - - private fun implement( - model: TextModel, - ): Pair { - val request = ChatRequest(messages = ArrayList(this.messages.toList())) - for (codingAttempt in 0..input.fixRetries) { - try { - val codeBlocks = extractTextBlocks(chat(api, request, model)) - val renderedResponse = getRenderedResponse(codeBlocks) - val codedInstruction = codeInterceptor(getCode(language, codeBlocks)) - log.debug("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin()) - log.debug("New Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin()) - var workingCode = codedInstruction - var workingRenderedResponse = renderedResponse - for (fixAttempt in 0..input.fixIterations) { - try { - val validate = interpreter.validate((input.codePrefix + "\n" + codeInterceptor(workingCode)).sortCode()) - if (validate != null) throw validate - log.debug("Validation succeeded") - _status = CodeResult.Status.Success - return workingCode to workingRenderedResponse - } catch (ex: Throwable) { - if (fixAttempt == input.fixIterations) - throw if (ex is FailedToImplementException) ex else FailedToImplementException( - cause = ex, - message = """ + private var _status = CodeResult.Status.Coding + + override val status get() = _status + + override val renderedResponse: String = givenResponse ?: implementation.second + override val code: String = givenCode ?: implementation.first + + private fun implement( + model: TextModel, + ): Pair { + val request = ChatRequest(messages = ArrayList(this.messages.toList())) + for (codingAttempt in 0..input.fixRetries) { + try { + val codeBlocks = extractTextBlocks(chat(api, request, model)) + val renderedResponse = getRenderedResponse(codeBlocks) + val codedInstruction = codeInterceptor(getCode(language, codeBlocks)) + log.debug("Response: \n\t${renderedResponse.replace("\n", "\n\t", false)}".trimMargin()) + log.debug("New Code: \n\t${codedInstruction.replace("\n", "\n\t", false)}".trimMargin()) + var workingCode = codedInstruction + var workingRenderedResponse = renderedResponse + for (fixAttempt in 0..input.fixIterations) { + try { + val validate = interpreter.validate((input.codePrefix + "\n" + codeInterceptor(workingCode)).sortCode()) + if (validate != null) throw validate + log.debug("Validation succeeded") + _status = CodeResult.Status.Success + return workingCode to workingRenderedResponse + } catch (ex: Throwable) { + if (fixAttempt == input.fixIterations) + throw if (ex is FailedToImplementException) ex else FailedToImplementException( + cause = ex, + message = """ **ERROR** | ${TT}text ${ex.stackTraceToString()} ${TT} """.trim(), - language = language, - code = workingCode, - prefix = input.codePrefix - ) - log.debug("Validation failed - ${ex.message}") - _status = CodeResult.Status.Correcting - val respondWithCode = fixCommand(api, workingCode, ex, *messages, model = model) - val codeBlocks = extractTextBlocks(respondWithCode) - workingRenderedResponse = getRenderedResponse(codeBlocks) - workingCode = codeInterceptor(getCode(language, codeBlocks)) - log.debug( - "Response: \n\t${ - workingRenderedResponse.replace( - "\n", - "\n\t", - false - ) - }".trimMargin() - ) - log.debug("New Code: \n\t${workingCode.replace("\n", "\n\t", false)}".trimMargin()) - } - } - } catch (ex: FailedToImplementException) { - if (codingAttempt == input.fixRetries) { - log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") - throw ex - } - log.debug("Retry failed to implement ${messages.map { it.content }.joinToString("\n")}") - _status = CodeResult.Status.Correcting - } + language = language, + code = workingCode, + prefix = input.codePrefix + ) + log.debug("Validation failed - ${ex.message}") + _status = CodeResult.Status.Correcting + val respondWithCode = fixCommand(api, workingCode, ex, *messages, model = model) + val codeBlocks = extractTextBlocks(respondWithCode) + workingRenderedResponse = getRenderedResponse(codeBlocks) + workingCode = codeInterceptor(getCode(language, codeBlocks)) + log.debug( + "Response: \n\t${ + workingRenderedResponse.replace( + "\n", + "\n\t", + false + ) + }".trimMargin() + ) + log.debug("New Code: \n\t${workingCode.replace("\n", "\n\t", false)}".trimMargin()) } - throw IllegalStateException() + } + } catch (ex: FailedToImplementException) { + if (codingAttempt == input.fixRetries) { + log.debug("Failed to implement ${messages.map { it.content }.joinToString("\n")}") + throw ex + } + log.debug("Retry failed to implement ${messages.map { it.content }.joinToString("\n")}") + _status = CodeResult.Status.Correcting } - - - private val executionResult by lazy { execute(input.codePrefix, code) } - - override val result get() = executionResult + } + throw IllegalStateException() } - private fun fixCommand( - api: ChatClient, - previousCode: String, - error: Throwable, - vararg promptMessages: ChatMessage, - model: TextModel - ): String = chat( - api = api, - request = ChatRequest( - messages = ArrayList( - promptMessages.toList() + listOf( - ChatMessage( - Role.assistant, - """ + + private val executionResult by lazy { execute(input.codePrefix, code) } + + override val result get() = executionResult + } + + private fun fixCommand( + api: ChatClient, + previousCode: String, + error: Throwable, + vararg promptMessages: ChatMessage, + model: TextModel + ): String = chat( + api = api, + request = ChatRequest( + messages = ArrayList( + promptMessages.toList() + listOf( + ChatMessage( + Role.assistant, + """ $TT${language.lowercase()} ${previousCode.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} ${TT} """.trim().toContentList() - ), - ChatMessage( - Role.system, - """ + ), + ChatMessage( + Role.system, + """ The previous code failed with the following error: $TT @@ -346,164 +346,164 @@ ${TT} Correct the code and try again. """.trim().toContentList() - ) - ) - ) - ), - model = model - ) + ) + ) + ) + ), + model = model + ) - private fun chat(api: ChatClient, request: ChatRequest, model: TextModel) = - api.chat(request.copy(model = model.modelName, temperature = temperature), model) - .choices.first().message?.content.orEmpty().trim() - - - override fun withModel(model: ChatModel): CodingActor = CodingActor( - interpreterClass = interpreterClass, - symbols = symbols, - describer = describer, - name = name, - details = details, - model = model, - fallbackModel = fallbackModel, - temperature = temperature, - runtimeSymbols = runtimeSymbols, - codeInterceptor = codeInterceptor - ) + private fun chat(api: ChatClient, request: ChatRequest, model: TextModel) = + api.chat(request.copy(model = model.modelName, temperature = temperature), model) + .choices.first().message?.content.orEmpty().trim() - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(CodingActor::class.java) - fun String.indent(indent: String = " ") = this.replace("\n", "\n$indent") + override fun withModel(model: ChatModel): CodingActor = CodingActor( + interpreterClass = interpreterClass, + symbols = symbols, + describer = describer, + name = name, + details = details, + model = model, + fallbackModel = fallbackModel, + temperature = temperature, + runtimeSymbols = runtimeSymbols, + codeInterceptor = codeInterceptor + ) - fun extractTextBlocks(response: String): List> { - val codeBlockRegex = Regex("(?s)$TT(.*?)\\n(.*?)${TT}") - val languageRegex = Regex("([a-zA-Z0-9-_]+)") + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(CodingActor::class.java) - val result = mutableListOf>() - var startIndex = 0 + fun String.indent(indent: String = " ") = this.replace("\n", "\n$indent") - val matches = codeBlockRegex.findAll(response) - if (matches.count() == 0) return listOf(Pair("text", response)) - for (match in matches) { - // Add non-code block before the current match as "text" - if (startIndex < match.range.first) { - result.add(Pair("text", response.substring(startIndex, match.range.first))) - } + fun extractTextBlocks(response: String): List> { + val codeBlockRegex = Regex("(?s)$TT(.*?)\\n(.*?)${TT}") + val languageRegex = Regex("([a-zA-Z0-9-_]+)") - // Extract language and code - val languageMatch = languageRegex.find(match.groupValues[1]) - val language = languageMatch?.groupValues?.get(0) ?: "code" - val code = match.groupValues[2] + val result = mutableListOf>() + var startIndex = 0 - // Add code block to the result - result.add(Pair(language, code)) + val matches = codeBlockRegex.findAll(response) + if (matches.count() == 0) return listOf(Pair("text", response)) + for (match in matches) { + // Add non-code block before the current match as "text" + if (startIndex < match.range.first) { + result.add(Pair("text", response.substring(startIndex, match.range.first))) + } - // Update the start index - startIndex = match.range.last + 1 - } + // Extract language and code + val languageMatch = languageRegex.find(match.groupValues[1]) + val language = languageMatch?.groupValues?.get(0) ?: "code" + val code = match.groupValues[2] - // Add any remaining non-code text after the last code block as "text" - if (startIndex < response.length) { - result.add(Pair("text", response.substring(startIndex))) - } + // Add code block to the result + result.add(Pair(language, code)) - return result - } + // Update the start index + startIndex = match.range.last + 1 + } - fun getRenderedResponse(respondWithCode: List>, defaultLanguage: String = "") = - respondWithCode.joinToString("\n") { - when (it.first) { - "code" -> "$TT$defaultLanguage\n${it.second.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n${TT}" - "text" -> it.second.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }.toString() - else -> "$TT${it.first}\n${it.second.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n${TT}" - } - } + // Add any remaining non-code text after the last code block as "text" + if (startIndex < response.length) { + result.add(Pair("text", response.substring(startIndex))) + } - fun getCode(language: String, textSegments: List>): String { - if (textSegments.size == 1) return textSegments.joinToString("\n") { it.second } - return textSegments.joinToString("\n") { - if (it.first.lowercase() == "code" || it.first.lowercase() == language.lowercase()) { - it.second.trimMargin().trim() - } else { - "" - } - } - } + return result + } - fun String.sortCode(bodyWrapper: (String) -> String = { it }): String { - val (imports, otherCode) = this.split("\n").partition { it.trim().startsWith("import ") } - return imports.map { it.trim() }.distinct().sorted().joinToString("\n") + "\n\n" + bodyWrapper(otherCode.joinToString("\n")) + fun getRenderedResponse(respondWithCode: List>, defaultLanguage: String = "") = + respondWithCode.joinToString("\n") { + when (it.first) { + "code" -> "$TT$defaultLanguage\n${it.second.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n${TT}" + "text" -> it.second.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }.toString() + else -> "$TT${it.first}\n${it.second.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n${TT}" } - - fun String.camelCase(locale: Locale = Locale.getDefault()): String { - val words = fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() } - return words.first().lowercase(locale) + words.drop(1).joinToString("") { - it.replaceFirstChar { c -> - when { - c.isLowerCase() -> c.titlecase(locale) - else -> c.toString() - } - } - } + } + + fun getCode(language: String, textSegments: List>): String { + if (textSegments.size == 1) return textSegments.joinToString("\n") { it.second } + return textSegments.joinToString("\n") { + if (it.first.lowercase() == "code" || it.first.lowercase() == language.lowercase()) { + it.second.trimMargin().trim() + } else { + "" } + } + } - fun String.pascalCase(locale: Locale = Locale.getDefault()): String = - fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }.joinToString("") { - it.replaceFirstChar { c -> - when { - c.isLowerCase() -> c.titlecase(locale) - else -> c.toString() - } - } - } + fun String.sortCode(bodyWrapper: (String) -> String = { it }): String { + val (imports, otherCode) = this.split("\n").partition { it.trim().startsWith("import ") } + return imports.map { it.trim() }.distinct().sorted().joinToString("\n") + "\n\n" + bodyWrapper(otherCode.joinToString("\n")) + } - // Detect changes in the case of the first letter and prepend a space - private fun String.fromPascalCase(): String = buildString { - var lastChar = ' ' - for (c in this@fromPascalCase) { - if (c.isUpperCase() && lastChar.isLowerCase()) append(' ') - append(c) - lastChar = c - } + fun String.camelCase(locale: Locale = Locale.getDefault()): String { + val words = fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() } + return words.first().lowercase(locale) + words.drop(1).joinToString("") { + it.replaceFirstChar { c -> + when { + c.isLowerCase() -> c.titlecase(locale) + else -> c.toString() + } } + } + } - fun String.upperSnakeCase(locale: Locale = Locale.getDefault()): String = - fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }.joinToString("_") { - it.replaceFirstChar { c -> - when { - c.isLowerCase() -> c.titlecase(locale) - else -> c.toString() - } - } - }.uppercase(locale) - - fun String.imports(): List { - return this.split("\n").filter { it.trim().startsWith("import ") }.distinct().sorted() + fun String.pascalCase(locale: Locale = Locale.getDefault()): String = + fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }.joinToString("") { + it.replaceFirstChar { c -> + when { + c.isLowerCase() -> c.titlecase(locale) + else -> c.toString() + } } + } + + // Detect changes in the case of the first letter and prepend a space + private fun String.fromPascalCase(): String = buildString { + var lastChar = ' ' + for (c in this@fromPascalCase) { + if (c.isUpperCase() && lastChar.isLowerCase()) append(' ') + append(c) + lastChar = c + } + } - fun String.stripImports(): String { - return this.split("\n").filter { !it.trim().startsWith("import ") }.joinToString("\n") + fun String.upperSnakeCase(locale: Locale = Locale.getDefault()): String = + fromPascalCase().split(" ").map { it.trim() }.filter { it.isNotEmpty() }.joinToString("_") { + it.replaceFirstChar { c -> + when { + c.isLowerCase() -> c.titlecase(locale) + else -> c.toString() + } } + }.uppercase(locale) - fun errorMessage(ex: ScriptException, code: String) = try { - """ + fun String.imports(): List { + return this.split("\n").filter { it.trim().startsWith("import ") }.distinct().sorted() + } + + fun String.stripImports(): String { + return this.split("\n").filter { !it.trim().startsWith("import ") }.joinToString("\n") + } + + fun errorMessage(ex: ScriptException, code: String) = try { + """ |${TT}text |${ex.message ?: ""} at line ${ex.lineNumber} column ${ex.columnNumber} | ${if (ex.lineNumber > 0) code.split("\n")[ex.lineNumber - 1] else ""} | ${if (ex.columnNumber > 0) " ".repeat(ex.columnNumber - 1) + "^" else ""} |${TT} """.trimMargin().trim() - } catch (_: Exception) { - ex.message ?: "" - } + } catch (_: Exception) { + ex.message ?: "" } - - class FailedToImplementException( - cause: Throwable? = null, - message: String = "Failed to implement", - val language: String? = null, - val code: String? = null, - val prefix: String? = null, - ) : RuntimeException(message, cause) + } + + class FailedToImplementException( + cause: Throwable? = null, + message: String = "Failed to implement", + val language: String? = null, + val code: String? = null, + val prefix: String? = null, + ) : RuntimeException(message, cause) } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageActor.kt index 7e0669f7..45ca7693 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageActor.kt @@ -15,87 +15,87 @@ import java.net.URL import javax.imageio.ImageIO open class ImageActor( - prompt: String = "Transform the user request into an image generation prompt that the user will like", - name: String? = null, - textModel: TextModel, - val imageModel: ImageModels = ImageModels.DallE2, - temperature: Double = 0.3, - val width: Int = 1024, - val height: Int = 1024, + prompt: String = "Transform the user request into an image generation prompt that the user will like", + name: String? = null, + textModel: TextModel, + val imageModel: ImageModels = ImageModels.DallE2, + temperature: Double = 0.3, + val width: Int = 1024, + val height: Int = 1024, ) : BaseActor, ImageResponse>( - prompt = prompt, - name = name, - model = textModel, - temperature = temperature, + prompt = prompt, + name = name, + model = textModel, + temperature = temperature, ) { - override fun chatMessages(questions: List) = arrayOf( - ChatMessage( - role = ApiModel.Role.system, - content = prompt.toContentList() - ), - ) + questions.map { - ChatMessage( - role = ApiModel.Role.user, - content = it.toContentList() - ) - } + override fun chatMessages(questions: List) = arrayOf( + ChatMessage( + role = ApiModel.Role.system, + content = prompt.toContentList() + ), + ) + questions.map { + ChatMessage( + role = ApiModel.Role.user, + content = it.toContentList() + ) + } - inner class ImageResponseImpl( - override val text: String, - private val api: API - ) : ImageResponse { - private val _image: BufferedImage by lazy { render(text, api) } - override val image: BufferedImage get() = _image - } + inner class ImageResponseImpl( + override val text: String, + private val api: API + ) : ImageResponse { + private val _image: BufferedImage by lazy { render(text, api) } + override val image: BufferedImage get() = _image + } - open fun render( - text: String, - api: API, - ): BufferedImage { - val url = (api as OpenAIClient).createImage( - ImageGenerationRequest( - prompt = text, - model = imageModel.modelName, - size = "${width}x$height" - ) - ).data.first().url - return ImageIO.read(URL(url)) - } + open fun render( + text: String, + api: API, + ): BufferedImage { + val url = (api as OpenAIClient).createImage( + ImageGenerationRequest( + prompt = text, + model = imageModel.modelName, + size = "${width}x$height" + ) + ).data.first().url + return ImageIO.read(URL(url)) + } - override fun respond(input: List, api: API, vararg messages: ChatMessage): ImageResponse { - var text = response(*messages, api = api).choices.first().message?.content - ?: throw RuntimeException("No response") - while (imageModel.maxPrompt <= text.length && null != openAI) { - text = response( - *listOf( - messages.toList(), - listOf( - text.toChatMessage(), - "Please shorten the description".toChatMessage(), - ), - ).flatten().toTypedArray(), - model = imageModel, - api = api - ).choices.first().message?.content ?: throw RuntimeException("No response") - } - return ImageResponseImpl(text, api = this.openAI!!) + override fun respond(input: List, api: API, vararg messages: ChatMessage): ImageResponse { + var text = response(*messages, api = api).choices.first().message?.content + ?: throw RuntimeException("No response") + while (imageModel.maxPrompt <= text.length && null != openAI) { + text = response( + *listOf( + messages.toList(), + listOf( + text.toChatMessage(), + "Please shorten the description".toChatMessage(), + ), + ).flatten().toTypedArray(), + model = imageModel, + api = api + ).choices.first().message?.content ?: throw RuntimeException("No response") } + return ImageResponseImpl(text, api = this.openAI!!) + } - override fun withModel(model: ChatModel): ImageActor = ImageActor( - prompt = prompt, - name = name, - textModel = model, - imageModel = imageModel, - temperature = temperature, - width = width, - height = height, - ) + override fun withModel(model: ChatModel): ImageActor = ImageActor( + prompt = prompt, + name = name, + textModel = model, + imageModel = imageModel, + temperature = temperature, + width = width, + height = height, + ) - var openAI: OpenAIClient? = null - fun setImageAPI(openAI: OpenAIClient): ImageActor { - this.openAI = openAI - return this - } + var openAI: OpenAIClient? = null + fun setImageAPI(openAI: OpenAIClient): ImageActor { + this.openAI = openAI + return this + } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageResponse.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageResponse.kt index 19703f0c..55346297 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageResponse.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ImageResponse.kt @@ -3,6 +3,6 @@ package com.simiacryptus.skyenet.core.actors import java.awt.image.BufferedImage interface ImageResponse { - val text: String - val image: BufferedImage + val text: String + val image: BufferedImage } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedActor.kt index e2cb3719..f438b53b 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedActor.kt @@ -15,57 +15,57 @@ import org.slf4j.LoggerFactory import java.util.function.Function open class ParsedActor( - var resultClass: Class? = null, - val exampleInstance: T? = resultClass?.getConstructor()?.newInstance(), - prompt: String = "", - name: String? = resultClass?.simpleName, - model: TextModel = OpenAIModels.GPT4o, - temperature: Double = 0.3, - val parsingModel: TextModel = OpenAIModels.GPT4oMini, - val deserializerRetries: Int = 2, - open val describer: TypeDescriber = object : AbbrevWhitelistYamlDescriber( - "com.simiacryptus", "com.github.simiacryptus" - ) { - override val includeMethods: Boolean get() = false - }, - var parserPrompt: String? = null, + var resultClass: Class? = null, + val exampleInstance: T? = resultClass?.getConstructor()?.newInstance(), + prompt: String = "", + name: String? = resultClass?.simpleName, + model: TextModel = OpenAIModels.GPT4o, + temperature: Double = 0.3, + val parsingModel: TextModel = OpenAIModels.GPT4oMini, + val deserializerRetries: Int = 2, + open val describer: TypeDescriber = object : AbbrevWhitelistYamlDescriber( + "com.simiacryptus", "com.github.simiacryptus" + ) { + override val includeMethods: Boolean get() = false + }, + var parserPrompt: String? = null, ) : BaseActor, ParsedResponse>( - prompt = prompt, - name = name, - model = model, - temperature = temperature, + prompt = prompt, + name = name, + model = model, + temperature = temperature, ) { - init { - requireNotNull(resultClass) { - "Result class is required" - } + init { + requireNotNull(resultClass) { + "Result class is required" } + } - override fun chatMessages(questions: List) = arrayOf( - ApiModel.ChatMessage( - role = ApiModel.Role.system, - content = prompt.toContentList() - ), - ) + questions.map { - ApiModel.ChatMessage( - role = ApiModel.Role.user, - content = it.toContentList() - ) - } + override fun chatMessages(questions: List) = arrayOf( + ApiModel.ChatMessage( + role = ApiModel.Role.system, + content = prompt.toContentList() + ), + ) + questions.map { + ApiModel.ChatMessage( + role = ApiModel.Role.user, + content = it.toContentList() + ) + } - private inner class ParsedResponseImpl(api: API, vararg messages: ApiModel.ChatMessage) : - ParsedResponse(resultClass!!) { - override val text = - response(*messages, api = api).choices.first().message?.content ?: throw RuntimeException("No response") - private val _obj: T by lazy { getParser(api, parserPrompt).apply(text) } - override val obj get() = _obj - } + private inner class ParsedResponseImpl(api: API, vararg messages: ApiModel.ChatMessage) : + ParsedResponse(resultClass!!) { + override val text = + response(*messages, api = api).choices.first().message?.content ?: throw RuntimeException("No response") + private val _obj: T by lazy { getParser(api, parserPrompt).apply(text) } + override val obj get() = _obj + } - fun getParser(api: API, promptSuffix: String? = null) = Function { input -> - describer.coverMethods = false - val describe = resultClass?.let { describer.describe(it) } ?: "" - val exceptions = mutableListOf() - val prompt = """ + fun getParser(api: API, promptSuffix: String? = null) = Function { input -> + describer.coverMethods = false + val describe = resultClass?.let { describer.describe(it) } ?: "" + val exceptions = mutableListOf() + val prompt = """ |Parse the user's message into a json object described by: | |```yaml @@ -79,84 +79,84 @@ open class ParsedActor( |${promptSuffix?.let { "\n$it" } ?: ""} | """.trimMargin() - for (i in 0 until deserializerRetries) { - try { - val content = (api as ChatClient).chat( - ApiModel.ChatRequest( - messages = listOf( - ApiModel.ChatMessage(role = ApiModel.Role.system, content = prompt.toContentList()), - ApiModel.ChatMessage( - role = ApiModel.Role.user, - content = "The user message to parse:\n\n$input".toContentList() - ), - ), - temperature = temperature, - model = parsingModel.modelName, - ), - model = parsingModel, - ).choices.first().message?.content - var contentUnwrapped = content?.trim() ?: throw RuntimeException("No response") + for (i in 0 until deserializerRetries) { + try { + val content = (api as ChatClient).chat( + ApiModel.ChatRequest( + messages = listOf( + ApiModel.ChatMessage(role = ApiModel.Role.system, content = prompt.toContentList()), + ApiModel.ChatMessage( + role = ApiModel.Role.user, + content = "The user message to parse:\n\n$input".toContentList() + ), + ), + temperature = temperature, + model = parsingModel.modelName, + ), + model = parsingModel, + ).choices.first().message?.content + var contentUnwrapped = content?.trim() ?: throw RuntimeException("No response") - // If Plaintext is found before the { or ```, strip it - if (!contentUnwrapped.startsWith("{") && !contentUnwrapped.startsWith("```")) { - val start = contentUnwrapped.indexOf("{").coerceAtMost(contentUnwrapped.indexOf("```")) - val end = - contentUnwrapped.lastIndexOf("}").coerceAtLeast(contentUnwrapped.lastIndexOf("```") + 2) + 1 - if (start < end && start >= 0) contentUnwrapped = contentUnwrapped.substring(start, end) - } + // If Plaintext is found before the { or ```, strip it + if (!contentUnwrapped.startsWith("{") && !contentUnwrapped.startsWith("```")) { + val start = contentUnwrapped.indexOf("{").coerceAtMost(contentUnwrapped.indexOf("```")) + val end = + contentUnwrapped.lastIndexOf("}").coerceAtLeast(contentUnwrapped.lastIndexOf("```") + 2) + 1 + if (start < end && start >= 0) contentUnwrapped = contentUnwrapped.substring(start, end) + } - // if input is wrapped in a ```json block, remove the block - if (contentUnwrapped.startsWith("```json")) { - val endIndex = contentUnwrapped.lastIndexOf("```") - if (endIndex > 7) { - contentUnwrapped = contentUnwrapped.substring(7, endIndex) - } else { - throw RuntimeException( - "Failed to parse response: ${ - contentUnwrapped.replace( - "\n", - "\n " - ) - }" - ) - } - } + // if input is wrapped in a ```json block, remove the block + if (contentUnwrapped.startsWith("```json")) { + val endIndex = contentUnwrapped.lastIndexOf("```") + if (endIndex > 7) { + contentUnwrapped = contentUnwrapped.substring(7, endIndex) + } else { + throw RuntimeException( + "Failed to parse response: ${ + contentUnwrapped.replace( + "\n", + "\n " + ) + }" + ) + } + } - contentUnwrapped.let { - return@Function JsonUtil.fromJson( - it, resultClass - ?: throw RuntimeException("Result class undefined") - ) - } - } catch (e: Exception) { - log.info("Failed to parse response", e) - exceptions.add(e) - } + contentUnwrapped.let { + return@Function JsonUtil.fromJson( + it, resultClass + ?: throw RuntimeException("Result class undefined") + ) } - throw MultiExeption(exceptions) + } catch (e: Exception) { + log.info("Failed to parse response", e) + exceptions.add(e) + } } + throw MultiExeption(exceptions) + } - override fun respond(input: List, api: API, vararg messages: ApiModel.ChatMessage): ParsedResponse { - try { - return ParsedResponseImpl(api, *messages) - } catch (e: Exception) { - log.info("Failed to parse response", e) - throw e - } + override fun respond(input: List, api: API, vararg messages: ApiModel.ChatMessage): ParsedResponse { + try { + return ParsedResponseImpl(api, *messages) + } catch (e: Exception) { + log.info("Failed to parse response", e) + throw e } + } - override fun withModel(model: ChatModel): ParsedActor = ParsedActor( - resultClass = resultClass, - prompt = prompt, - name = name, - model = model, - temperature = temperature, - parsingModel = parsingModel, - ) + override fun withModel(model: ChatModel): ParsedActor = ParsedActor( + resultClass = resultClass, + prompt = prompt, + name = name, + model = model, + temperature = temperature, + parsingModel = parsingModel, + ) - companion object { - private val log = LoggerFactory.getLogger(ParsedActor::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(ParsedActor::class.java) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedResponse.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedResponse.kt index d943132a..0fdf50e0 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedResponse.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/ParsedResponse.kt @@ -1,24 +1,24 @@ package com.simiacryptus.skyenet.core.actors abstract class ParsedResponse(val clazz: Class) { - abstract val text: String - abstract val obj: T - override fun toString() = text - open fun map(cls: Class, fn: (T) -> V): ParsedResponse = MappedResponse(cls, this.clazz, fn, this) + abstract val text: String + abstract val obj: T + override fun toString() = text + open fun map(cls: Class, fn: (T) -> V): ParsedResponse = MappedResponse(cls, this.clazz, fn, this) } class MappedResponse( - clazz: Class, - private val cls: Class, - private val fn: (F) -> T, - private val inner: ParsedResponse + clazz: Class, + private val cls: Class, + private val fn: (F) -> T, + private val inner: ParsedResponse ) : ParsedResponse(clazz) { - override val text: String - get() = inner.text - override val obj: T - get() = fn(inner.obj) + override val text: String + get() = inner.text + override val obj: T + get() = fn(inner.obj) - override fun map(cls: Class, fn: (T) -> V): ParsedResponse { - return MappedResponse(cls, this.clazz, fn, this) - } + override fun map(cls: Class, fn: (T) -> V): ParsedResponse { + return MappedResponse(cls, this.clazz, fn, this) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SimpleActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SimpleActor.kt index 606cf0dc..10af8b83 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SimpleActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SimpleActor.kt @@ -7,36 +7,36 @@ import com.simiacryptus.jopenai.models.TextModel import com.simiacryptus.jopenai.util.ClientUtil.toContentList open class SimpleActor( - prompt: String, - name: String? = null, - model: TextModel, - temperature: Double = 0.3, + prompt: String, + name: String? = null, + model: TextModel, + temperature: Double = 0.3, ) : BaseActor, String>( - prompt = prompt, - name = name, - model = model, - temperature = temperature, + prompt = prompt, + name = name, + model = model, + temperature = temperature, ) { - override fun respond(input: List, api: API, vararg messages: ApiModel.ChatMessage): String = - response(*messages, api = api).choices.first().message?.content ?: throw RuntimeException("No response") - - override fun chatMessages(questions: List) = arrayOf( - ApiModel.ChatMessage( - role = ApiModel.Role.system, - content = prompt.toContentList() - ), - ) + questions.map { - ApiModel.ChatMessage( - role = ApiModel.Role.user, - content = it.toContentList() - ) - } + override fun respond(input: List, api: API, vararg messages: ApiModel.ChatMessage): String = + response(*messages, api = api).choices.first().message?.content ?: throw RuntimeException("No response") - override fun withModel(model: ChatModel): SimpleActor = SimpleActor( - prompt = prompt, - name = name, - model = model, - temperature = temperature, + override fun chatMessages(questions: List) = arrayOf( + ApiModel.ChatMessage( + role = ApiModel.Role.system, + content = prompt.toContentList() + ), + ) + questions.map { + ApiModel.ChatMessage( + role = ApiModel.Role.user, + content = it.toContentList() ) + } + + override fun withModel(model: ChatModel): SimpleActor = SimpleActor( + prompt = prompt, + name = name, + model = model, + temperature = temperature, + ) } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SpeechResponse.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SpeechResponse.kt index b555568d..b3e7990f 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SpeechResponse.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/SpeechResponse.kt @@ -1,5 +1,5 @@ package com.simiacryptus.skyenet.core.actors interface SpeechResponse { - val mp3data: ByteArray? + val mp3data: ByteArray? } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/TextToSpeechActor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/TextToSpeechActor.kt index fd412bbf..878e44b2 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/TextToSpeechActor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/actors/TextToSpeechActor.kt @@ -9,57 +9,57 @@ import com.simiacryptus.jopenai.models.ChatModel import com.simiacryptus.jopenai.util.ClientUtil.toContentList open class TextToSpeechActor( - name: String? = null, - val audioModel: AudioModels = AudioModels.TTS_HD, - val voice: String = "alloy", - val speed: Double = 1.0, - val models: ChatModel, + name: String? = null, + val audioModel: AudioModels = AudioModels.TTS_HD, + val voice: String = "alloy", + val speed: Double = 1.0, + val models: ChatModel, ) : BaseActor, SpeechResponse>( - prompt = "", - name = name, - model = models, + prompt = "", + name = name, + model = models, ) { - var openAI: OpenAIClient? = null - fun setOpenAI(openAI: OpenAIClient): TextToSpeechActor { - this.openAI = openAI - return this - } + var openAI: OpenAIClient? = null + fun setOpenAI(openAI: OpenAIClient): TextToSpeechActor { + this.openAI = openAI + return this + } - override fun chatMessages(questions: List) = questions.map { - ChatMessage( - role = ApiModel.Role.user, - content = it.toContentList() - ) - }.toTypedArray() + override fun chatMessages(questions: List) = questions.map { + ChatMessage( + role = ApiModel.Role.user, + content = it.toContentList() + ) + }.toTypedArray() - inner class SpeechResponseImpl( - val text: String, - private val api: API - ) : SpeechResponse { - private val _image: ByteArray? by lazy { render(text, api) } - override val mp3data: ByteArray? get() = _image - } + inner class SpeechResponseImpl( + val text: String, + private val api: API + ) : SpeechResponse { + private val _image: ByteArray? by lazy { render(text, api) } + override val mp3data: ByteArray? get() = _image + } - open fun render( - text: String, - api: API, - ): ByteArray = (api as OpenAIClient).createSpeech( - ApiModel.SpeechRequest( - input = text, - model = audioModel.modelName, - voice = voice, - speed = speed, - ) - ) ?: throw RuntimeException("No response") + open fun render( + text: String, + api: API, + ): ByteArray = (api as OpenAIClient).createSpeech( + ApiModel.SpeechRequest( + input = text, + model = audioModel.modelName, + voice = voice, + speed = speed, + ) + ) ?: throw RuntimeException("No response") - override fun respond(input: List, api: API, vararg messages: ChatMessage) = - SpeechResponseImpl( - messages.joinToString("\n") { it.content?.joinToString("\n") { it.text ?: "" } ?: "" }, - api = this.openAI ?: throw RuntimeException("OpenAI client not set") - ) + override fun respond(input: List, api: API, vararg messages: ChatMessage) = + SpeechResponseImpl( + messages.joinToString("\n") { it.content?.joinToString("\n") { it.text ?: "" } ?: "" }, + api = this.openAI ?: throw RuntimeException("OpenAI client not set") + ) - override fun withModel(model: ChatModel) = TextToSpeechActor(name, audioModel, voice, speed, model) - .also { it.openAI = this.openAI } + override fun withModel(model: ChatModel) = TextToSpeechActor(name, audioModel, voice, speed, model) + .also { it.openAI = this.openAI } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ApplicationServices.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ApplicationServices.kt index 5bd95bbc..4f0da390 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ApplicationServices.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ApplicationServices.kt @@ -17,53 +17,53 @@ import java.io.File import java.util.concurrent.ThreadPoolExecutor object ApplicationServices { - var authorizationManager: AuthorizationInterface = AuthorizationManager() - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var userSettingsManager: UserSettingsInterface = UserSettingsManager() - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var authenticationManager: AuthenticationInterface = AuthenticationManager() - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var dataStorageFactory: (File) -> StorageInterface = { DataStorage(it) } - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var metadataStorageFactory: (File) -> MetadataStorageInterface = { HSQLMetadataStorage(it) } - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var clientManager: ClientManager = ClientManager() - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } + var authorizationManager: AuthorizationInterface = AuthorizationManager() + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var userSettingsManager: UserSettingsInterface = UserSettingsManager() + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var authenticationManager: AuthenticationInterface = AuthenticationManager() + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var dataStorageFactory: (File) -> StorageInterface = { DataStorage(it) } + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var metadataStorageFactory: (File) -> MetadataStorageInterface = { HSQLMetadataStorage(it) } + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var clientManager: ClientManager = ClientManager() + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } - var cloud: CloudPlatformInterface? = AwsPlatform.get() - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } + var cloud: CloudPlatformInterface? = AwsPlatform.get() + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } - var seleniumFactory: ((ThreadPoolExecutor, Array?) -> Selenium)? = null - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var usageManager: UsageInterface = HSQLUsageManager(File(dataStorageRoot, "usage")) - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } + var seleniumFactory: ((ThreadPoolExecutor, Array?) -> Selenium)? = null + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var usageManager: UsageInterface = HSQLUsageManager(File(dataStorageRoot, "usage")) + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt index 0a68078e..188440b9 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/AwsPlatform.kt @@ -16,99 +16,99 @@ import java.util.* open class AwsPlatform( - private val bucket: String = System.getProperty("share_bucket", "share.simiacrypt.us"), - override val shareBase: String = System.getProperty("share_base", "https://" + bucket), - private val region: Region? = Region.US_EAST_1, - private val profileName: String = "default", + private val bucket: String = System.getProperty("share_bucket", "share.simiacrypt.us"), + override val shareBase: String = System.getProperty("share_base", "https://" + bucket), + private val region: Region? = Region.US_EAST_1, + private val profileName: String = "default", ) : CloudPlatformInterface { - open val credentialsProvider: ProfileCredentialsProvider? = ProfileCredentialsProvider.create(profileName) - private val log = LoggerFactory.getLogger(AwsPlatform::class.java) + open val credentialsProvider: ProfileCredentialsProvider? = ProfileCredentialsProvider.create(profileName) + private val log = LoggerFactory.getLogger(AwsPlatform::class.java) - protected open val kmsClient: KmsClient by lazy { - log.debug("Initializing KMS client for region: {}", Region.US_EAST_1) - var clientBuilder = KmsClient.builder().region(Region.US_EAST_1) - if(null != credentialsProvider) clientBuilder = clientBuilder.credentialsProvider(credentialsProvider) - clientBuilder.build() - } + protected open val kmsClient: KmsClient by lazy { + log.debug("Initializing KMS client for region: {}", Region.US_EAST_1) + var clientBuilder = KmsClient.builder().region(Region.US_EAST_1) + if (null != credentialsProvider) clientBuilder = clientBuilder.credentialsProvider(credentialsProvider) + clientBuilder.build() + } - protected open val s3Client: S3Client by lazy { - log.debug("Initializing S3 client for region: {}", region) - var clientBuilder = S3Client.builder() - if(null != credentialsProvider) clientBuilder = clientBuilder.credentialsProvider(credentialsProvider) - clientBuilder = clientBuilder.region(region) - clientBuilder.build() - } + protected open val s3Client: S3Client by lazy { + log.debug("Initializing S3 client for region: {}", region) + var clientBuilder = S3Client.builder() + if (null != credentialsProvider) clientBuilder = clientBuilder.credentialsProvider(credentialsProvider) + clientBuilder = clientBuilder.region(region) + clientBuilder.build() + } - override fun upload( - path: String, - contentType: String, - bytes: ByteArray - ): String { - log.info("Uploading {} bytes to S3 path: {}", bytes.size, path) - s3Client.putObject( - PutObjectRequest.builder() - .bucket(bucket).key(path.replace("/{2,}".toRegex(), "/").removePrefix("/")) - .contentType(contentType) - .build(), - RequestBody.fromBytes(bytes) - ) - log.debug("Upload completed successfully") - return "$shareBase/$path" - } + override fun upload( + path: String, + contentType: String, + bytes: ByteArray + ): String { + log.info("Uploading {} bytes to S3 path: {}", bytes.size, path) + s3Client.putObject( + PutObjectRequest.builder() + .bucket(bucket).key(path.replace("/{2,}".toRegex(), "/").removePrefix("/")) + .contentType(contentType) + .build(), + RequestBody.fromBytes(bytes) + ) + log.debug("Upload completed successfully") + return "$shareBase/$path" + } - override fun upload( - path: String, - contentType: String, - request: String - ): String { - log.info("Uploading string content to S3 path: {}", path) - s3Client.putObject( - PutObjectRequest.builder() - .bucket(bucket).key(path.replace("/{2,}".toRegex(), "/").removePrefix("/")) - .contentType(contentType) - .build(), - RequestBody.fromString(request) - ) - log.debug("Upload completed successfully") - return "$shareBase/$path" - } + override fun upload( + path: String, + contentType: String, + request: String + ): String { + log.info("Uploading string content to S3 path: {}", path) + s3Client.putObject( + PutObjectRequest.builder() + .bucket(bucket).key(path.replace("/{2,}".toRegex(), "/").removePrefix("/")) + .contentType(contentType) + .build(), + RequestBody.fromString(request) + ) + log.debug("Upload completed successfully") + return "$shareBase/$path" + } - override fun encrypt(fileBytes: ByteArray, keyId: String): String? { - log.info("Encrypting {} bytes using KMS key: {}", fileBytes.size, keyId) - val encryptedData = Base64.getEncoder().encodeToString( - kmsClient.encrypt( - EncryptRequest.builder() - .keyId(keyId) - .plaintext(SdkBytes.fromByteArray(fileBytes)) - .build() - ).ciphertextBlob().asByteArray() - ) - log.debug("Encryption completed successfully") - return encryptedData - } + override fun encrypt(fileBytes: ByteArray, keyId: String): String? { + log.info("Encrypting {} bytes using KMS key: {}", fileBytes.size, keyId) + val encryptedData = Base64.getEncoder().encodeToString( + kmsClient.encrypt( + EncryptRequest.builder() + .keyId(keyId) + .plaintext(SdkBytes.fromByteArray(fileBytes)) + .build() + ).ciphertextBlob().asByteArray() + ) + log.debug("Encryption completed successfully") + return encryptedData + } - override fun decrypt(encryptedData: ByteArray): String { - log.info("Decrypting {} bytes of data", encryptedData.size) - val decryptedData = String( - kmsClient.decrypt( - DecryptRequest.builder() - .ciphertextBlob(SdkBytes.fromByteArray(Base64.getDecoder().decode(encryptedData))) - .build() - ).plaintext().asByteArray(), StandardCharsets.UTF_8 - ) - log.debug("Decryption completed successfully") - return decryptedData - } + override fun decrypt(encryptedData: ByteArray): String { + log.info("Decrypting {} bytes of data", encryptedData.size) + val decryptedData = String( + kmsClient.decrypt( + DecryptRequest.builder() + .ciphertextBlob(SdkBytes.fromByteArray(Base64.getDecoder().decode(encryptedData))) + .build() + ).plaintext().asByteArray(), StandardCharsets.UTF_8 + ) + log.debug("Decryption completed successfully") + return decryptedData + } - companion object { - val log = LoggerFactory.getLogger(AwsPlatform::class.java) - fun get() = try { - log.info("Initializing AwsPlatform") - AwsPlatform() - } catch (e: Throwable) { - log.warn("Error initializing AWS platform", e) - null - } + companion object { + val log = LoggerFactory.getLogger(AwsPlatform::class.java) + fun get() = try { + log.info("Initializing AwsPlatform") + AwsPlatform() + } catch (e: Throwable) { + log.warn("Error initializing AWS platform", e) + null } + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ClientManager.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ClientManager.kt index 4ecafe60..cee771a6 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ClientManager.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/ClientManager.kt @@ -18,169 +18,169 @@ import java.util.concurrent.* open class ClientManager { - private data class SessionKey(val session: Session, val user: User?) + private data class SessionKey(val session: Session, val user: User?) - private val chatCache = mutableMapOf() - fun getChatClient( - session: Session, - user: User?, - ): ChatClient { - log.debug("Fetching client for session: {}, user: {}", session, user) - val key = SessionKey(session, user) - return chatCache.getOrPut(key) { createChatClient(session, user)!! } - } + private val chatCache = mutableMapOf() + fun getChatClient( + session: Session, + user: User?, + ): ChatClient { + log.debug("Fetching client for session: {}, user: {}", session, user) + val key = SessionKey(session, user) + return chatCache.getOrPut(key) { createChatClient(session, user)!! } + } - private val openAICache = mutableMapOf() - fun getOpenAIClient( - session: Session, - user: User?, - ): OpenAIClient { - log.debug("Fetching client for session: {}, user: {}", session, user) - val key = SessionKey(session, user) - return openAICache.getOrPut(key) { createOpenAIClient(session, user)!! } - } + private val openAICache = mutableMapOf() + fun getOpenAIClient( + session: Session, + user: User?, + ): OpenAIClient { + log.debug("Fetching client for session: {}, user: {}", session, user) + val key = SessionKey(session, user) + return openAICache.getOrPut(key) { createOpenAIClient(session, user)!! } + } - private val poolCache = mutableMapOf() - protected open fun createPool(session: Session, user: User?) = - ThreadPoolExecutor( - 0, Integer.MAX_VALUE, - 500, TimeUnit.MILLISECONDS, - ArrayBlockingQueue(1), - RecordingThreadFactory(session, user) - ) + private val poolCache = mutableMapOf() + protected open fun createPool(session: Session, user: User?) = + ThreadPoolExecutor( + 8, Integer.MAX_VALUE, + 500, TimeUnit.MILLISECONDS, + SynchronousQueue(), + RecordingThreadFactory(session, user) + ) - private val scheduledPoolCache = mutableMapOf() - protected open fun createScheduledPool(session: Session, user: User?, dataStorage: StorageInterface?) = - MoreExecutors.listeningDecorator(ScheduledThreadPoolExecutor(1)) + private val scheduledPoolCache = mutableMapOf() + protected open fun createScheduledPool(session: Session, user: User?, dataStorage: StorageInterface?) = + MoreExecutors.listeningDecorator(ScheduledThreadPoolExecutor(1)) - fun getPool( - session: Session, - user: User?, - ): ThreadPoolExecutor { - log.debug("Fetching thread pool for session: {}, user: {}", session, user) - val key = SessionKey(session, user) - return poolCache.getOrPut(key) { - createPool(session, user) - } + fun getPool( + session: Session, + user: User?, + ): ThreadPoolExecutor { + log.debug("Fetching thread pool for session: {}, user: {}", session, user) + val key = SessionKey(session, user) + return poolCache.getOrPut(key) { + createPool(session, user) } + } - fun getScheduledPool( - session: Session, - user: User?, - dataStorage: StorageInterface?, - ): ListeningScheduledExecutorService { - log.debug("Fetching scheduled pool for session: {}, user: {}", session, user) - val key = SessionKey(session, user) - return scheduledPoolCache.getOrPut(key) { - createScheduledPool(session, user, dataStorage) - } + fun getScheduledPool( + session: Session, + user: User?, + dataStorage: StorageInterface?, + ): ListeningScheduledExecutorService { + log.debug("Fetching scheduled pool for session: {}, user: {}", session, user) + val key = SessionKey(session, user) + return scheduledPoolCache.getOrPut(key) { + createScheduledPool(session, user, dataStorage) } + } - inner class RecordingThreadFactory( - val session: Session, - val user: User? - ) : ThreadFactory { - private val inner = ThreadFactoryBuilder().setNameFormat("Session $session; User $user; #%d").build() - val threads = mutableSetOf() - override fun newThread(r: Runnable): Thread { - log.debug("Creating new thread for session: {}, user: {}", session, user) - inner.newThread(r).also { - threads.add(it) - return it - } - } + inner class RecordingThreadFactory( + val session: Session, + val user: User? + ) : ThreadFactory { + private val inner = ThreadFactoryBuilder().setNameFormat("Session $session; User $user; #%d").build() + val threads = mutableSetOf() + override fun newThread(r: Runnable): Thread { + log.debug("Creating new thread for session: {}, user: {}", session, user) + inner.newThread(r).also { + threads.add(it) + return it + } } + } - protected open fun createChatClient( - session: Session, - user: User?, - ): ChatClient? { - log.debug("Creating client for session: {}, user: {}", session, user) - val sessionDir = dataStorageFactory(dataStorageRoot).getDataDir(user, session).apply { mkdirs() } - if (user != null) { - val userSettings = userSettingsManager.getUserSettings(user) - val userApi = - if (userSettings.apiKeys.isNotEmpty()) { - /* - MonitoredClient( - key = userSettings.apiKeys, - apiBase = userSettings.apiBase, - logfile = sessionDir.resolve("openai.log"), - session = session, - user = user, - workPool = getPool(session, user), - )*/ - ChatClient( - key = userSettings.apiKeys, - apiBase = userSettings.apiBase, - workPool = getPool(session, user), - ).apply { - this.session = session - this.user = user - logStreams += sessionDir.resolve("openai.log").outputStream().buffered() - } - } else null - if (userApi != null) return userApi - } - val canUseGlobalKey = ApplicationServices.authorizationManager.isAuthorized( - null, user, OperationType.GlobalKey - ) - if (!canUseGlobalKey) throw RuntimeException("No API key") - return (if (ClientUtil.keyMap.isNotEmpty()) { - ChatClient( - key = ClientUtil.keyMap.mapKeys { APIProvider.valueOf(it.key) }, - workPool = getPool(session, user), - ).apply { - this.session = session - this.user = user - logStreams += sessionDir.resolve("openai.log").outputStream().buffered() - } - } else { - null - })!! + protected open fun createChatClient( + session: Session, + user: User?, + ): ChatClient? { + log.debug("Creating client for session: {}, user: {}", session, user) + val sessionDir = dataStorageFactory(dataStorageRoot).getDataDir(user, session).apply { mkdirs() } + if (user != null) { + val userSettings = userSettingsManager.getUserSettings(user) + val userApi = + if (userSettings.apiKeys.isNotEmpty()) { + /* + MonitoredClient( + key = userSettings.apiKeys, + apiBase = userSettings.apiBase, + logfile = sessionDir.resolve("openai.log"), + session = session, + user = user, + workPool = getPool(session, user), + )*/ + ChatClient( + key = userSettings.apiKeys, + apiBase = userSettings.apiBase, + workPool = getPool(session, user), + ).apply { + this.session = session + this.user = user + logStreams += sessionDir.resolve("openai.log").outputStream().buffered() + } + } else null + if (userApi != null) return userApi } + val canUseGlobalKey = ApplicationServices.authorizationManager.isAuthorized( + null, user, OperationType.GlobalKey + ) + if (!canUseGlobalKey) throw RuntimeException("No API key") + return (if (ClientUtil.keyMap.isNotEmpty()) { + ChatClient( + key = ClientUtil.keyMap.mapKeys { APIProvider.valueOf(it.key) }, + workPool = getPool(session, user), + ).apply { + this.session = session + this.user = user + logStreams += sessionDir.resolve("openai.log").outputStream().buffered() + } + } else { + null + })!! + } - protected open fun createOpenAIClient( - session: Session, - user: User?, - ): OpenAIClient? { - log.debug("Creating client for session: {}, user: {}", session, user) - val sessionDir = dataStorageFactory(dataStorageRoot).getDataDir(user, session).apply { mkdirs() } - if (user != null) { - val userSettings = userSettingsManager.getUserSettings(user) - val userApi = - if (userSettings.apiKeys.isNotEmpty()) { - OpenAIClient( - key = userSettings.apiKeys, - apiBase = userSettings.apiBase, - workPool = getPool(session, user), - ).apply { - this.session = session - this.user = user - logStreams += sessionDir.resolve("openai.log").outputStream().buffered() - } - } else null - if (userApi != null) return userApi - } - val canUseGlobalKey = ApplicationServices.authorizationManager.isAuthorized( - null, user, OperationType.GlobalKey - ) - if (!canUseGlobalKey) throw RuntimeException("No API key") - return (if (ClientUtil.keyMap.isNotEmpty()) { - OpenAIClient( - key = ClientUtil.keyMap.mapKeys { APIProvider.valueOf(it.key) }, - workPool = getPool(session, user), - ).apply { - this.session = session - this.user = user - logStreams += sessionDir.resolve("openai.log").outputStream().buffered() - } - } else { - null - })!! + protected open fun createOpenAIClient( + session: Session, + user: User?, + ): OpenAIClient? { + log.debug("Creating client for session: {}, user: {}", session, user) + val sessionDir = dataStorageFactory(dataStorageRoot).getDataDir(user, session).apply { mkdirs() } + if (user != null) { + val userSettings = userSettingsManager.getUserSettings(user) + val userApi = + if (userSettings.apiKeys.isNotEmpty()) { + OpenAIClient( + key = userSettings.apiKeys, + apiBase = userSettings.apiBase, + workPool = getPool(session, user), + ).apply { + this.session = session + this.user = user + logStreams += sessionDir.resolve("openai.log").outputStream().buffered() + } + } else null + if (userApi != null) return userApi } + val canUseGlobalKey = ApplicationServices.authorizationManager.isAuthorized( + null, user, OperationType.GlobalKey + ) + if (!canUseGlobalKey) throw RuntimeException("No API key") + return (if (ClientUtil.keyMap.isNotEmpty()) { + OpenAIClient( + key = ClientUtil.keyMap.mapKeys { APIProvider.valueOf(it.key) }, + workPool = getPool(session, user), + ).apply { + this.session = session + this.user = user + logStreams += sessionDir.resolve("openai.log").outputStream().buffered() + } + } else { + null + })!! + } - companion object { - private val log = LoggerFactory.getLogger(ClientManager::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(ClientManager::class.java) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/Session.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/Session.kt index 21641dd2..b5e68b95 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/Session.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/Session.kt @@ -7,52 +7,52 @@ import java.util.Base64 import kotlin.random.Random data class Session( - val sessionId: String + val sessionId: String ) { - init { - validateSessionId() + init { + validateSessionId() + } + + override fun toString() = sessionId + fun isGlobal(): Boolean = sessionId.startsWith("G-") + + companion object { + fun long64() = Base64.getEncoder().encodeToString(ByteBuffer.allocate(8).putLong(Random.Default.nextLong()).array()) + .toString().replace("=", "").replace("/", ".").replace("+", "-") + + fun validateSessionId(session: Session) { + session.validateSessionId() + } + + fun newGlobalID(): Session { + val yyyyMMdd = LocalDate.now().toString().replace("-", "") + return Session("G-$yyyyMMdd-${id2()}") + } + + fun newUserID(): Session { + val yyyyMMdd = LocalDate.now().toString().replace("-", "") + return Session("U-$yyyyMMdd-${id2()}") } - override fun toString() = sessionId - fun isGlobal(): Boolean = sessionId.startsWith("G-") - - companion object { - fun long64() = Base64.getEncoder().encodeToString(ByteBuffer.allocate(8).putLong(Random.Default.nextLong()).array()) - .toString().replace("=", "").replace("/", ".").replace("+", "-") - - fun validateSessionId(session: Session) { - session.validateSessionId() - } - - fun newGlobalID(): Session { - val yyyyMMdd = LocalDate.now().toString().replace("-", "") - return Session("G-$yyyyMMdd-${id2()}") - } - - fun newUserID(): Session { - val yyyyMMdd = LocalDate.now().toString().replace("-", "") - return Session("U-$yyyyMMdd-${id2()}") - } - - private fun id2() = long64().filter { - when (it) { - in 'a'..'z' -> true - in 'A'..'Z' -> true - in '0'..'9' -> true - else -> false - } - }.take(4) - - fun parseSessionID(sessionID: String): Session { - val session = Session(sessionID) - session.validateSessionId() - return session - } + private fun id2() = long64().filter { + when (it) { + in 'a'..'z' -> true + in 'A'..'Z' -> true + in '0'..'9' -> true + else -> false + } + }.take(4) + + fun parseSessionID(sessionID: String): Session { + val session = Session(sessionID) + session.validateSessionId() + return session } + } - private fun validateSessionId() { - if (!sessionId.matches("""([GU]-)?\d{8}-[\w+-.]{4}""".toRegex())) { - throw IllegalArgumentException("Invalid session ID: $this") - } + private fun validateSessionId() { + if (!sessionId.matches("""([GU]-)?\d{8}-[\w+-.]{4}""".toRegex())) { + throw IllegalArgumentException("Invalid session ID: $this") } + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/AuthenticationManager.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/AuthenticationManager.kt index 0c2da530..f1cd8dae 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/AuthenticationManager.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/AuthenticationManager.kt @@ -5,18 +5,18 @@ import com.simiacryptus.skyenet.core.platform.model.User open class AuthenticationManager : AuthenticationInterface { - private val users = HashMap() + private val users = HashMap() - override fun getUser(accessToken: String?) = if (null == accessToken) null else users[accessToken] + override fun getUser(accessToken: String?) = if (null == accessToken) null else users[accessToken] - override fun putUser(accessToken: String, user: User): User { - users[accessToken] = user - return user - } + override fun putUser(accessToken: String, user: User): User { + users[accessToken] = user + return user + } - override fun logout(accessToken: String, user: User) { - require(users[accessToken] == user) { "Invalid user" } - users.remove(accessToken) - } + override fun logout(accessToken: String, user: User) { + require(users[accessToken] == user) { "Invalid user" } + users.remove(accessToken) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/AuthorizationManager.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/AuthorizationManager.kt index 85361131..ab094371 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/AuthorizationManager.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/AuthorizationManager.kt @@ -6,76 +6,80 @@ import java.util.* open class AuthorizationManager : AuthorizationInterface { - override fun isAuthorized( - applicationClass: Class<*>?, - user: User?, - operationType: AuthorizationInterface.OperationType, - ) = try { - log.debug("Checking authorization for user: {}, operation: {}, application: {}", user, operationType, applicationClass) - if (isUserAuthorized("/permissions/${operationType.name.lowercase(Locale.getDefault())}.txt", user)) { - log.info("User {} authorized for {} globally", user, operationType) - true - } else if (null != applicationClass) { - val packagePath = applicationClass.`package`.name.replace('.', '/') - val opName = operationType.name.lowercase(Locale.getDefault()) - log.debug("Checking application-specific authorization at path: /permissions/{}/{}.txt", packagePath, opName) - if (isUserAuthorized("/permissions/$packagePath/$opName.txt", user)) { - log.info("User {} authorized for {} on {}", user, operationType, applicationClass) - true - } else { - log.warn("User {} not authorized for {} on {}", user, operationType, applicationClass) - false - } - } else { - log.warn("User {} not authorized for {} globally", user, operationType) - false - } - } catch (e: Exception) { - log.error("Error checking authorization", e) + override fun isAuthorized( + applicationClass: Class<*>?, + user: User?, + operationType: AuthorizationInterface.OperationType, + ) = try { + log.debug("Checking authorization for user: {}, operation: {}, application: {}", user, operationType, applicationClass) + if (isUserAuthorized("/permissions/${operationType.name.lowercase(Locale.getDefault())}.txt", user)) { + log.info("User {} authorized for {} globally", user, operationType) + true + } else if (null != applicationClass) { + val packagePath = applicationClass.`package`.name.replace('.', '/') + val opName = operationType.name.lowercase(Locale.getDefault()) + log.debug("Checking application-specific authorization at path: /permissions/{}/{}.txt", packagePath, opName) + if (isUserAuthorized("/permissions/$packagePath/$opName.txt", user)) { + log.info("User {} authorized for {} on {}", user, operationType, applicationClass) + true + } else { + log.warn("User {} not authorized for {} on {}", user, operationType, applicationClass) false + } + } else { + log.warn("User {} not authorized for {} globally", user, operationType) + false } + } catch (e: Exception) { + log.error("Error checking authorization", e) + false + } - private fun isUserAuthorized(permissionPath: String, user: User?): Boolean { - log.debug("Checking user authorization at path: {}", permissionPath) - return javaClass.getResourceAsStream(permissionPath)?.use { stream -> - val lines = stream.bufferedReader().readLines() - log.trace("Permission file contents: {}", lines) - lines.any { line -> - matches(user, line) - } - } ?: run { - log.warn("Permission file not found: {}", permissionPath) - false - } + private fun isUserAuthorized(permissionPath: String, user: User?): Boolean { + log.debug("Checking user authorization at path: {}", permissionPath) + return javaClass.getResourceAsStream(permissionPath)?.use { stream -> + val lines = stream.bufferedReader().readLines() + log.trace("Permission file contents: {}", lines) + lines.any { line -> + matches(user, line) + } + } ?: run { + log.warn("Permission file not found: {}", permissionPath) + false } + } - open fun matches(user: User?, line: String): Boolean { - log.trace("Matching user {} against line: {}", user, line) - return when { - line.equals(user?.email, ignoreCase = true) -> { - log.debug("Exact match found for user: {}", user) - true - } - line.startsWith("@") && user?.email?.endsWith(line.substring(1)) == true -> { - log.debug("Domain match found for user: {}", user) - true - } - line == "." && user != null -> { - log.debug("Any authenticated user match for: {}", user) - true - } - line == "*" -> { - log.debug("Any user (including anonymous) match") - true - } - else -> { - log.trace("No match found for user: {} and line: {}", user, line) - false - } - } - } + open fun matches(user: User?, line: String): Boolean { + log.trace("Matching user {} against line: {}", user, line) + return when { + line.equals(user?.email, ignoreCase = true) -> { + log.debug("Exact match found for user: {}", user) + true + } + + line.startsWith("@") && user?.email?.endsWith(line.substring(1)) == true -> { + log.debug("Domain match found for user: {}", user) + true + } + + line == "." && user != null -> { + log.debug("Any authenticated user match for: {}", user) + true + } - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(AuthorizationManager::class.java) + line == "*" -> { + log.debug("Any user (including anonymous) match") + true + } + + else -> { + log.trace("No match found for user: {} and line: {}", user, line) + false + } } + } + + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(AuthorizationManager::class.java) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/DataStorage.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/DataStorage.kt index b533f7f5..bc7c56c3 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/DataStorage.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/DataStorage.kt @@ -9,187 +9,187 @@ import java.io.File import java.util.* open class DataStorage( - private val dataDir: File, + private val dataDir: File, ) : StorageInterface { - init { - log.debug("Data directory: ${dataDir.absolutePath}", RuntimeException()) + init { + log.debug("Data directory: ${dataDir.absolutePath}", RuntimeException()) + } + + override fun getMessages( + user: User?, + session: Session + ): LinkedHashMap { + Session.validateSessionId(session) + log.debug("Fetching messages for session: ${session.sessionId}, user: ${user?.email}") + val messageDir = + getDataDir(user, session).resolve("messages/") + .apply { mkdirs() } + val messages = LinkedHashMap() + getMessageIds(user, session).forEach { messageId -> + val file = File(messageDir, "$messageId.json") + if (file.exists()) { + val message = JsonUtil.objectMapper().readValue(file, String::class.java) + messages[messageId] = message + } } - - override fun getMessages( - user: User?, - session: Session - ): LinkedHashMap { - Session.validateSessionId(session) - log.debug("Fetching messages for session: ${session.sessionId}, user: ${user?.email}") - val messageDir = - getDataDir(user, session).resolve("messages/") - .apply { mkdirs() } - val messages = LinkedHashMap() - getMessageIds(user, session).forEach { messageId -> - val file = File(messageDir, "$messageId.json") - if (file.exists()) { - val message = JsonUtil.objectMapper().readValue(file, String::class.java) - messages[messageId] = message - } + log.debug("Loaded ${messages.size} messages for session: ${session.sessionId}") + return messages + } + + override fun getSessionDir( + user: User?, + session: Session + ) = if (sessionPaths.containsKey(session)) { + sessionPaths[session]!! + } else { + getDataDir(user, session).apply { mkdirs() } + } + + override fun getDataDir( + user: User?, + session: Session + ): File { + Session.validateSessionId(session) + log.debug("Getting data directory for session: ${session.sessionId}, user: ${user?.email}") + val parts = session.sessionId.split("-") + return when (parts.size) { + 3 -> { + val root = when { + parts[0] == "G" -> dataDir.resolve("global") + parts[0] == "U" -> dataDir.resolve("user-sessions/$user") + else -> throw IllegalArgumentException("Invalid session ID: $session") } - log.debug("Loaded ${messages.size} messages for session: ${session.sessionId}") - return messages + val dateDir = File(root, parts[1]) + val sessionDir = File(dateDir, parts[2]) + log.debug("Session directory for session: ${session.sessionId} is ${sessionDir.absolutePath}") + sessionDir + } + + 2 -> { + val dateDir = dataDir.resolve("global").resolve(parts[0]) + val sessionDir = dateDir.resolve(parts[1]) + log.debug("Session directory for session: ${session.sessionId} is ${sessionDir.absolutePath}") + sessionDir + } + + else -> { + throw IllegalArgumentException("Invalid session ID: $session") + } } - - override fun getSessionDir( - user: User?, - session: Session - ) = if (sessionPaths.containsKey(session)) { - sessionPaths[session]!! - } else { - getDataDir(user, session).apply { mkdirs() } + } + + override fun listSessions( + user: User?, + path: String + ): List { + log.debug("Listing sessions for user: ${user?.email}") + val globalSessions = listSessions(dataDir.resolve("global"), path) + val userSessions = if (user == null) listOf() else ApplicationServices.metadataStorageFactory(dataDir).listSessions( + path + ) + log.debug("Found ${globalSessions.size} global sessions and ${userSessions.size} user sessions for user: ${user?.email}") + return ((globalSessions.map { + try { + Session("G-$it") + } catch (e: Exception) { + null + } + }).toList() + (userSessions.map { + try { + Session("U-$it") + } catch (e: Exception) { + null + } + }).toList()).filterNotNull() + } + + override fun setJson( + user: User?, + session: Session, + filename: String, + settings: T + ) = setJson(getDataDir(user, session), filename, settings) + + private fun setJson(sessionDir: File, filename: String, settings: T): T { + log.debug("Setting JSON for session directory: ${sessionDir.absolutePath}, filename: $filename") + val settingsFile = sessionDir.resolve(filename).apply { parentFile.mkdirs() } + JsonUtil.objectMapper().writeValue(settingsFile, settings) + return settings + } + + override fun updateMessage( + user: User?, + session: Session, + messageId: String, + value: String + ) { + Session.validateSessionId(session) + log.debug("Updating message for session: ${session.sessionId}, messageId: $messageId, user: ${user?.email}") + val file = + getDataDir(user, session).resolve("messages/$messageId.json") + .apply { parentFile.mkdirs() } + if (!file.exists()) { + file.parentFile.mkdirs() + addMessageID(user, session, messageId) } - - override fun getDataDir( - user: User?, - session: Session - ): File { - Session.validateSessionId(session) - log.debug("Getting data directory for session: ${session.sessionId}, user: ${user?.email}") - val parts = session.sessionId.split("-") - return when (parts.size) { - 3 -> { - val root = when { - parts[0] == "G" -> dataDir.resolve("global") - parts[0] == "U" -> dataDir.resolve("user-sessions/$user") - else -> throw IllegalArgumentException("Invalid session ID: $session") - } - val dateDir = File(root, parts[1]) - val sessionDir = File(dateDir, parts[2]) - log.debug("Session directory for session: ${session.sessionId} is ${sessionDir.absolutePath}") - sessionDir - } - - 2 -> { - val dateDir = dataDir.resolve("global").resolve(parts[0]) - val sessionDir = dateDir.resolve(parts[1]) - log.debug("Session directory for session: ${session.sessionId} is ${sessionDir.absolutePath}") - sessionDir - } - - else -> { - throw IllegalArgumentException("Invalid session ID: $session") - } - } - } - - override fun listSessions( - user: User?, - path: String - ): List { - log.debug("Listing sessions for user: ${user?.email}") - val globalSessions = listSessions(dataDir.resolve("global"), path) - val userSessions = if (user == null) listOf() else ApplicationServices.metadataStorageFactory(dataDir).listSessions( - path - ) - log.debug("Found ${globalSessions.size} global sessions and ${userSessions.size} user sessions for user: ${user?.email}") - return ((globalSessions.map { - try { - Session("G-$it") - } catch (e: Exception) { - null - } - }).toList() + (userSessions.map { - try { - Session("U-$it") - } catch (e: Exception) { - null - } - }).toList()).filterNotNull() - } - - override fun setJson( - user: User?, - session: Session, - filename: String, - settings: T - ) = setJson(getDataDir(user, session), filename, settings) - - private fun setJson(sessionDir: File, filename: String, settings: T): T { - log.debug("Setting JSON for session directory: ${sessionDir.absolutePath}, filename: $filename") - val settingsFile = sessionDir.resolve(filename).apply { parentFile.mkdirs() } - JsonUtil.objectMapper().writeValue(settingsFile, settings) - return settings + JsonUtil.objectMapper().writeValue(file, value) + } + + protected open fun addMessageID( + user: User?, + session: Session, + messageId: String + ) { + synchronized(this) { + log.debug("Adding message ID for session: ${session.sessionId}, messageId: $messageId, user: ${user?.email}") + setMessageIds(user, session, getMessageIds(user, session) + messageId) } + } - override fun updateMessage( - user: User?, - session: Session, - messageId: String, - value: String - ) { - Session.validateSessionId(session) - log.debug("Updating message for session: ${session.sessionId}, messageId: $messageId, user: ${user?.email}") - val file = - getDataDir(user, session).resolve("messages/$messageId.json") - .apply { parentFile.mkdirs() } - if (!file.exists()) { - file.parentFile.mkdirs() - addMessageID(user, session, messageId) - } - JsonUtil.objectMapper().writeValue(file, value) - } - - protected open fun addMessageID( - user: User?, - session: Session, - messageId: String - ) { - synchronized(this) { - log.debug("Adding message ID for session: ${session.sessionId}, messageId: $messageId, user: ${user?.email}") - setMessageIds(user, session, getMessageIds(user, session) + messageId) - } + override fun userRoot(user: User?) = dataDir.resolve("users").resolve( + if (user?.email != null) { + user.email + } else { + throw IllegalArgumentException("User required for private session") } + ).apply { mkdirs() } - override fun userRoot(user: User?) = dataDir.resolve("users").resolve( - if (user?.email != null) { - user.email - } else { - throw IllegalArgumentException("User required for private session") - } - ).apply { mkdirs() } - - override fun deleteSession(user: User?, session: Session) { - Session.validateSessionId(session) - log.debug("Deleting session: ${session.sessionId}, user: ${user?.email}") - val sessionDir = getDataDir(user, session) - ApplicationServices.metadataStorageFactory(dataDir).deleteSession(user, session) - sessionDir.deleteRecursively() - } + override fun deleteSession(user: User?, session: Session) { + Session.validateSessionId(session) + log.debug("Deleting session: ${session.sessionId}, user: ${user?.email}") + val sessionDir = getDataDir(user, session) + ApplicationServices.metadataStorageFactory(dataDir).deleteSession(user, session) + sessionDir.deleteRecursively() + } - override fun listSessions(dir: File, path: String): List = ApplicationServices.metadataStorageFactory(dataDir).listSessions(path) + override fun listSessions(dir: File, path: String): List = ApplicationServices.metadataStorageFactory(dataDir).listSessions(path) - override fun getSessionName( - user: User?, - session: Session - ): String = ApplicationServices.metadataStorageFactory(dataDir).getSessionName(user, session) + override fun getSessionName( + user: User?, + session: Session + ): String = ApplicationServices.metadataStorageFactory(dataDir).getSessionName(user, session) - override fun getMessageIds( - user: User?, - session: Session - ): List = ApplicationServices.metadataStorageFactory(dataDir).getMessageIds(user, session) + override fun getMessageIds( + user: User?, + session: Session + ): List = ApplicationServices.metadataStorageFactory(dataDir).getMessageIds(user, session) - override fun setMessageIds( - user: User?, - session: Session, - ids: List - ) = ApplicationServices.metadataStorageFactory(dataDir).setMessageIds(user, session, ids) + override fun setMessageIds( + user: User?, + session: Session, + ids: List + ) = ApplicationServices.metadataStorageFactory(dataDir).setMessageIds(user, session, ids) - override fun getSessionTime( - user: User?, - session: Session - ): Date? = ApplicationServices.metadataStorageFactory(dataDir).getSessionTime(user, session) + override fun getSessionTime( + user: User?, + session: Session + ): Date? = ApplicationServices.metadataStorageFactory(dataDir).getSessionTime(user, session) - companion object { + companion object { - val log = org.slf4j.LoggerFactory.getLogger(DataStorage::class.java) - val sessionPaths = mutableMapOf() + val log = org.slf4j.LoggerFactory.getLogger(DataStorage::class.java) + val sessionPaths = mutableMapOf() - } + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/MetadataStorage.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/MetadataStorage.kt index 9e9a0594..3dadd60f 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/MetadataStorage.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/MetadataStorage.kt @@ -13,128 +13,128 @@ import kotlin.reflect.typeOf class MetadataStorage(private val dataDir: File) : MetadataStorageInterface { - private val log = org.slf4j.LoggerFactory.getLogger(MetadataStorage::class.java) - - override fun getSessionName(user: User?, session: Session): String { - log.debug("Fetching session name for session: ${session.sessionId}, user: ${user?.email}") - val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) - val settings = getSettings(sessionDir, "settings.json") - if (settings.containsKey("name")) return settings["name"] as String - val userMessage = messageFiles(session, sessionDir).entries.minByOrNull { it.key.lastModified() }?.value - return if (null != userMessage) { - setJson(sessionDir, "settings.json", settings.plus("name" to userMessage)) - log.debug("Session name for session: ${session.sessionId} is $userMessage") - userMessage - } else { - log.debug("Session ${session.sessionId} has no messages") - session.sessionId - } - } - - override fun setSessionName(user: User?, session: Session, name: String) { - log.debug("Setting session name for session: ${session.sessionId}, user: ${user?.email} to $name") - val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) - val settings = getSettings(sessionDir, "settings.json") - setJson(sessionDir, "settings.json", settings.plus("name" to name)) + private val log = org.slf4j.LoggerFactory.getLogger(MetadataStorage::class.java) + + override fun getSessionName(user: User?, session: Session): String { + log.debug("Fetching session name for session: ${session.sessionId}, user: ${user?.email}") + val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) + val settings = getSettings(sessionDir, "settings.json") + if (settings.containsKey("name")) return settings["name"] as String + val userMessage = messageFiles(session, sessionDir).entries.minByOrNull { it.key.lastModified() }?.value + return if (null != userMessage) { + setJson(sessionDir, "settings.json", settings.plus("name" to userMessage)) + log.debug("Session name for session: ${session.sessionId} is $userMessage") + userMessage + } else { + log.debug("Session ${session.sessionId} has no messages") + session.sessionId } - - - override fun getMessageIds(user: User?, session: Session): List { - log.debug("Fetching message IDs for session: ${session.sessionId}, user: ${user?.email}") - val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) - val settings = getSettings(sessionDir, "internal.json") - if (settings.containsKey("ids")) return settings["ids"].toString().split(",").toList() - val ids = messageFiles(session, sessionDir).entries.sortedBy { it.key.lastModified() } - .map { it.key.nameWithoutExtension }.toList() - setJson(sessionDir, "internal.json", settings.plus("ids" to ids.joinToString(","))) - log.debug("Message IDs for session: ${session.sessionId} are $ids") - return ids + } + + override fun setSessionName(user: User?, session: Session, name: String) { + log.debug("Setting session name for session: ${session.sessionId}, user: ${user?.email} to $name") + val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) + val settings = getSettings(sessionDir, "settings.json") + setJson(sessionDir, "settings.json", settings.plus("name" to name)) + } + + + override fun getMessageIds(user: User?, session: Session): List { + log.debug("Fetching message IDs for session: ${session.sessionId}, user: ${user?.email}") + val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) + val settings = getSettings(sessionDir, "internal.json") + if (settings.containsKey("ids")) return settings["ids"].toString().split(",").toList() + val ids = messageFiles(session, sessionDir).entries.sortedBy { it.key.lastModified() } + .map { it.key.nameWithoutExtension }.toList() + setJson(sessionDir, "internal.json", settings.plus("ids" to ids.joinToString(","))) + log.debug("Message IDs for session: ${session.sessionId} are $ids") + return ids + } + + override fun setMessageIds(user: User?, session: Session, ids: List) { + log.debug("Setting message IDs for session: ${session.sessionId}, user: ${user?.email} to $ids") + val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) + val settings = getSettings(sessionDir, "internal.json") + setJson(sessionDir, "internal.json", settings.plus("ids" to ids.joinToString(","))) + } + + override fun getSessionTime(user: User?, session: Session): Date? { + log.debug("Fetching session time for session: ${session.sessionId}, user: ${user?.email}") + val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) + val settings = getSettings(sessionDir, "internal.json") + val dateFormat = SimpleDateFormat.getDateTimeInstance() + if (settings.containsKey("time")) return dateFormat.parse(settings["time"] as String) + val messageFiles = messageFiles(session, sessionDir) + val file = messageFiles.entries.minByOrNull { it.key.lastModified() }?.key + return if (null != file) { + val date = Date(file.lastModified()) + setJson(sessionDir, "internal.json", settings.plus("time" to dateFormat.format(date))) + log.debug("Session time for session: ${session.sessionId} is $date") + date + } else { + log.debug("Session ${session.sessionId} has no messages") + null } - - override fun setMessageIds(user: User?, session: Session, ids: List) { - log.debug("Setting message IDs for session: ${session.sessionId}, user: ${user?.email} to $ids") - val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) - val settings = getSettings(sessionDir, "internal.json") - setJson(sessionDir, "internal.json", settings.plus("ids" to ids.joinToString(","))) - } - - override fun getSessionTime(user: User?, session: Session): Date? { - log.debug("Fetching session time for session: ${session.sessionId}, user: ${user?.email}") - val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) - val settings = getSettings(sessionDir, "internal.json") - val dateFormat = SimpleDateFormat.getDateTimeInstance() - if (settings.containsKey("time")) return dateFormat.parse(settings["time"] as String) - val messageFiles = messageFiles(session, sessionDir) - val file = messageFiles.entries.minByOrNull { it.key.lastModified() }?.key - return if (null != file) { - val date = Date(file.lastModified()) - setJson(sessionDir, "internal.json", settings.plus("time" to dateFormat.format(date))) - log.debug("Session time for session: ${session.sessionId} is $date") - date + } + + override fun setSessionTime(user: User?, session: Session, time: Date) { + log.debug("Setting session time for session: ${session.sessionId}, user: ${user?.email} to $time") + val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) + val settings = getSettings(sessionDir, "internal.json") + val dateFormat = SimpleDateFormat.getDateTimeInstance() + setJson(sessionDir, "internal.json", settings.plus("time" to dateFormat.format(time))) + } + + + override fun listSessions(path: String): List { + log.debug("Listing sessions in dataDir.absolutePath}") + val files = dataDir.listFiles() + ?.flatMap { it.listFiles()?.toList() ?: listOf() } + ?.filter { sessionDir -> + val resolve = sessionDir.resolve("info.json") + if (!resolve.exists()) return@filter false + val infoJson = resolve.readText() + val infoData = JsonUtil.fromJson>(infoJson, typeOf>().javaType) + path == infoData["path"] + }?.sortedBy { it.lastModified() } ?: listOf() + log.debug("Found ${files.size} sessions in directory: ${dataDir.absolutePath}") + return files.map { it.parentFile.name + "-" + it.name } + } + + private fun getSettings(sessionDir: File, filename: String): Map<*, *> { + val settingsFile = sessionDir.resolve(filename) + return if (!settingsFile.exists()) mapOf() + else JsonUtil.objectMapper().readValue(settingsFile, Map::class.java) as Map<*, *> + } + + private fun setJson(sessionDir: File, filename: String, settings: T): T { + log.debug("Setting JSON for session directory: ${sessionDir.absolutePath}, filename: $filename") + val settingsFile = sessionDir.resolve(filename).apply { parentFile.mkdirs() } + JsonUtil.objectMapper().writeValue(settingsFile, settings) + return settings + } + + private fun messageFiles(session: Session, sessionDir: File): Map { + return sessionDir.resolve("messages") + .apply { mkdirs() }.listFiles() + ?.filter { file -> file.isFile } + ?.map { messageFile -> + val fileText = messageFile.readText() + val split = fileText.split("

") + if (split.size < 2) { + log.debug("Session ${session.sessionId} has no messages in file ${messageFile.name}") + messageFile to "" } else { - log.debug("Session ${session.sessionId} has no messages") - null + val stringList = split[1].split("

") + if (stringList.isEmpty()) { + log.debug("Session ${session.sessionId} has no messages in file ${messageFile.name}") + messageFile to "" + } else { + messageFile to stringList.first() + } } - } - - override fun setSessionTime(user: User?, session: Session, time: Date) { - log.debug("Setting session time for session: ${session.sessionId}, user: ${user?.email} to $time") - val sessionDir: File = ApplicationServices.dataStorageFactory.invoke(dataDir).getDataDir(user, session) - val settings = getSettings(sessionDir, "internal.json") - val dateFormat = SimpleDateFormat.getDateTimeInstance() - setJson(sessionDir, "internal.json", settings.plus("time" to dateFormat.format(time))) - } - - - override fun listSessions(path: String): List { - log.debug("Listing sessions in dataDir.absolutePath}") - val files = dataDir.listFiles() - ?.flatMap { it.listFiles()?.toList() ?: listOf() } - ?.filter { sessionDir -> - val resolve = sessionDir.resolve("info.json") - if (!resolve.exists()) return@filter false - val infoJson = resolve.readText() - val infoData = JsonUtil.fromJson>(infoJson, typeOf>().javaType) - path == infoData["path"] - }?.sortedBy { it.lastModified() } ?: listOf() - log.debug("Found ${files.size} sessions in directory: ${dataDir.absolutePath}") - return files.map { it.parentFile.name + "-" + it.name } - } - - private fun getSettings(sessionDir: File, filename: String): Map<*, *> { - val settingsFile = sessionDir.resolve(filename) - return if (!settingsFile.exists()) mapOf() - else JsonUtil.objectMapper().readValue(settingsFile, Map::class.java) as Map<*, *> - } - - private fun setJson(sessionDir: File, filename: String, settings: T): T { - log.debug("Setting JSON for session directory: ${sessionDir.absolutePath}, filename: $filename") - val settingsFile = sessionDir.resolve(filename).apply { parentFile.mkdirs() } - JsonUtil.objectMapper().writeValue(settingsFile, settings) - return settings - } - - private fun messageFiles(session: Session, sessionDir: File): Map { - return sessionDir.resolve("messages") - .apply { mkdirs() }.listFiles() - ?.filter { file -> file.isFile } - ?.map { messageFile -> - val fileText = messageFile.readText() - val split = fileText.split("

") - if (split.size < 2) { - log.debug("Session ${session.sessionId} has no messages in file ${messageFile.name}") - messageFile to "" - } else { - val stringList = split[1].split("

") - if (stringList.isEmpty()) { - log.debug("Session ${session.sessionId} has no messages in file ${messageFile.name}") - messageFile to "" - } else { - messageFile to stringList.first() - } - } - }?.filter { it.second.isNotEmpty() }?.toList()?.toMap() ?: mapOf() - } + }?.filter { it.second.isNotEmpty() }?.toList()?.toMap() ?: mapOf() + } - override fun deleteSession(user: User?, session: Session) {} + override fun deleteSession(user: User?, session: Session) {} } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UsageManager.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UsageManager.kt index bc649477..69690b2f 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UsageManager.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UsageManager.kt @@ -14,188 +14,188 @@ import java.util.concurrent.TimeUnit open class UsageManager(val root: File) : UsageInterface { - private val scheduler = Executors.newSingleThreadScheduledExecutor() - private val txLogFile = File(root, "log.csv") - - @Volatile - private var txLogFileWriter: FileWriter? - private val usagePerSession = ConcurrentHashMap() - private val sessionsByUser = ConcurrentHashMap>() - private val usersBySession = ConcurrentHashMap>() - - init { - txLogFile.parentFile.mkdirs() - loadFromLog(txLogFile) - txLogFileWriter = FileWriter(txLogFile, true) - scheduler.scheduleAtFixedRate({ saveCounters() }, 1, 1, TimeUnit.HOURS) - } - - @Suppress("MemberVisibilityCanBePrivate") - private fun loadFromLog(file: File) { - if (file.exists()) { - try { - file.readLines().forEach { line -> - val (sessionId, user, model, value, direction) = line.split(",") - try { - val modelEnum = listOf( - ChatModel.values(), - CompletionModels.values(), - EditModels.values(), - EmbeddingModels.values() - ).flatMap { it.values }.find { model == it.modelName } - ?: throw RuntimeException("Unknown model $model") - when (direction) { - "input" -> incrementUsage( - Session(sessionId), - User(email = user), - modelEnum, - ApiModel.Usage(prompt_tokens = value.toLong()) - ) - - "output" -> incrementUsage( - Session(sessionId), - User(email = user), - modelEnum, - ApiModel.Usage(completion_tokens = value.toLong()) - ) - - "cost" -> incrementUsage( - session = Session(sessionId = sessionId), - user = User(email = user), - model = modelEnum, - tokens = ApiModel.Usage(cost = value.toDouble()) - ) - - else -> throw RuntimeException("Unknown direction $direction") - } - } catch (e: Exception) { - //log.debug("Error loading log line: ${e.message}") - } - } - } catch (e: Exception) { - log.warn("Error loading log file", e) + private val scheduler = Executors.newSingleThreadScheduledExecutor() + private val txLogFile = File(root, "log.csv") + + @Volatile + private var txLogFileWriter: FileWriter? + private val usagePerSession = ConcurrentHashMap() + private val sessionsByUser = ConcurrentHashMap>() + private val usersBySession = ConcurrentHashMap>() + + init { + txLogFile.parentFile.mkdirs() + loadFromLog(txLogFile) + txLogFileWriter = FileWriter(txLogFile, true) + scheduler.scheduleAtFixedRate({ saveCounters() }, 1, 1, TimeUnit.HOURS) + } + + @Suppress("MemberVisibilityCanBePrivate") + private fun loadFromLog(file: File) { + if (file.exists()) { + try { + file.readLines().forEach { line -> + val (sessionId, user, model, value, direction) = line.split(",") + try { + val modelEnum = listOf( + ChatModel.values(), + CompletionModels.values(), + EditModels.values(), + EmbeddingModels.values() + ).flatMap { it.values }.find { model == it.modelName } + ?: throw RuntimeException("Unknown model $model") + when (direction) { + "input" -> incrementUsage( + Session(sessionId), + User(email = user), + modelEnum, + ApiModel.Usage(prompt_tokens = value.toLong()) + ) + + "output" -> incrementUsage( + Session(sessionId), + User(email = user), + modelEnum, + ApiModel.Usage(completion_tokens = value.toLong()) + ) + + "cost" -> incrementUsage( + session = Session(sessionId = sessionId), + user = User(email = user), + model = modelEnum, + tokens = ApiModel.Usage(cost = value.toDouble()) + ) + + else -> throw RuntimeException("Unknown direction $direction") } + } catch (e: Exception) { + //log.debug("Error loading log line: ${e.message}") + } } + } catch (e: Exception) { + log.warn("Error loading log file", e) + } } - - @Suppress("MemberVisibilityCanBePrivate") - private fun writeCompactLog(file: File) { - FileWriter(file).use { writer -> - usagePerSession.forEach { (sessionId, usage) -> - val apiKey = usersBySession[sessionId]?.firstOrNull() - usage.tokensPerModel.forEach { (model, counter) -> - writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.inputTokens.get()},input\n") - writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.outputTokens.get()},output\n") - writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.cost.get()},cost\n") - } - } - writer.flush() + } + + @Suppress("MemberVisibilityCanBePrivate") + private fun writeCompactLog(file: File) { + FileWriter(file).use { writer -> + usagePerSession.forEach { (sessionId, usage) -> + val apiKey = usersBySession[sessionId]?.firstOrNull() + usage.tokensPerModel.forEach { (model, counter) -> + writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.inputTokens.get()},input\n") + writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.outputTokens.get()},output\n") + writer.write("$sessionId,${apiKey},${model.model.modelName},${counter.cost.get()},cost\n") } + } + writer.flush() } - - private fun saveCounters() { - txLogFileWriter = FileWriter(txLogFile, true) - val timedFile = File(txLogFile.absolutePath + "." + System.currentTimeMillis()) - writeCompactLog(timedFile) - val swapFile = File(txLogFile.absolutePath + ".old") - synchronized(txLogFile) { - try { - txLogFileWriter?.close() - } catch (e: Exception) { - log.warn("Error closing log file", e) - } - try { - txLogFile.renameTo(swapFile) - } catch (e: Exception) { - log.warn("Error renaming log file", e) - } - try { - timedFile.renameTo(txLogFile) - } catch (e: Exception) { - log.warn("Error renaming log file", e) - } - try { - swapFile.renameTo(timedFile) - } catch (e: Exception) { - log.warn("Error renaming log file", e) - } - txLogFileWriter = FileWriter(txLogFile, true) - } - val text = JsonUtil.toJson(usagePerSession) - File(root, "counters.json").writeText(text) - val toClean = txLogFile.parentFile.listFiles() - ?.filter { it.name.startsWith(txLogFile.name) && it.name != txLogFile.absolutePath } - ?.sortedBy { it.lastModified() } // oldest first - ?.dropLast(2) // keep 2 newest - ?.drop(2) // keep 2 oldest - toClean?.forEach { it.delete() } + } + + private fun saveCounters() { + txLogFileWriter = FileWriter(txLogFile, true) + val timedFile = File(txLogFile.absolutePath + "." + System.currentTimeMillis()) + writeCompactLog(timedFile) + val swapFile = File(txLogFile.absolutePath + ".old") + synchronized(txLogFile) { + try { + txLogFileWriter?.close() + } catch (e: Exception) { + log.warn("Error closing log file", e) + } + try { + txLogFile.renameTo(swapFile) + } catch (e: Exception) { + log.warn("Error renaming log file", e) + } + try { + timedFile.renameTo(txLogFile) + } catch (e: Exception) { + log.warn("Error renaming log file", e) + } + try { + swapFile.renameTo(timedFile) + } catch (e: Exception) { + log.warn("Error renaming log file", e) + } + txLogFileWriter = FileWriter(txLogFile, true) } - - override fun incrementUsage( - session: Session, - apiKey: String?, - model: OpenAIModel, - tokens: ApiModel.Usage - ) { - usagePerSession.computeIfAbsent(session) { UsageCounters() } - .tokensPerModel.computeIfAbsent(UsageKey(session, apiKey, model)) { UsageValues() } - .addAndGet(tokens) - if (apiKey != null) { - sessionsByUser.computeIfAbsent(apiKey) { HashSet() }.add(session) - } - try { - val txLogFileWriter = txLogFileWriter - if (null != txLogFileWriter) { - synchronized(txLogFile) { - txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.prompt_tokens},input\n") - txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.completion_tokens},output\n") - txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.cost},cost\n") - txLogFileWriter.flush() - } - } - } catch (e: Exception) { - log.warn("Error incrementing usage", e) - } - } - - override fun getUserUsageSummary(apiKey: String): Map { - return sessionsByUser[apiKey]?.flatMap { sessionId -> - val usage = usagePerSession[sessionId] - usage?.tokensPerModel?.entries?.map { (model, counter) -> - model.model to counter.toUsage() - } ?: emptyList() - }?.groupBy { it.first }?.mapValues { - it.value.map { it.second }.reduce { a, b -> - ApiModel.Usage( - prompt_tokens = a.prompt_tokens + b.prompt_tokens, - completion_tokens = a.completion_tokens + b.completion_tokens, - cost = (a.cost ?: 0.0) + (b.cost ?: 0.0) - ) - } - } ?: emptyMap() - } - - override fun getSessionUsageSummary(session: Session): Map = - usagePerSession[session]?.tokensPerModel?.entries?.map { (model, counter) -> - model.model to counter.toUsage() - }?.groupBy { it.first }?.mapValues { - it.value.map { it.second }.reduce { a, b -> - ApiModel.Usage( - prompt_tokens = a.prompt_tokens + b.prompt_tokens, - completion_tokens = a.completion_tokens + b.completion_tokens, - cost = (a.cost ?: 0.0) + (b.cost ?: 0.0) - ) - } - } ?: emptyMap() - - override fun clear() { - usagePerSession.clear() - sessionsByUser.clear() - usersBySession.clear() - saveCounters() + val text = JsonUtil.toJson(usagePerSession) + File(root, "counters.json").writeText(text) + val toClean = txLogFile.parentFile.listFiles() + ?.filter { it.name.startsWith(txLogFile.name) && it.name != txLogFile.absolutePath } + ?.sortedBy { it.lastModified() } // oldest first + ?.dropLast(2) // keep 2 newest + ?.drop(2) // keep 2 oldest + toClean?.forEach { it.delete() } + } + + override fun incrementUsage( + session: Session, + apiKey: String?, + model: OpenAIModel, + tokens: ApiModel.Usage + ) { + usagePerSession.computeIfAbsent(session) { UsageCounters() } + .tokensPerModel.computeIfAbsent(UsageKey(session, apiKey, model)) { UsageValues() } + .addAndGet(tokens) + if (apiKey != null) { + sessionsByUser.computeIfAbsent(apiKey) { HashSet() }.add(session) } - - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(UsageManager::class.java) + try { + val txLogFileWriter = txLogFileWriter + if (null != txLogFileWriter) { + synchronized(txLogFile) { + txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.prompt_tokens},input\n") + txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.completion_tokens},output\n") + txLogFileWriter.write("$session,${apiKey},${model.modelName},${tokens.cost},cost\n") + txLogFileWriter.flush() + } + } + } catch (e: Exception) { + log.warn("Error incrementing usage", e) } + } + + override fun getUserUsageSummary(apiKey: String): Map { + return sessionsByUser[apiKey]?.flatMap { sessionId -> + val usage = usagePerSession[sessionId] + usage?.tokensPerModel?.entries?.map { (model, counter) -> + model.model to counter.toUsage() + } ?: emptyList() + }?.groupBy { it.first }?.mapValues { + it.value.map { it.second }.reduce { a, b -> + ApiModel.Usage( + prompt_tokens = a.prompt_tokens + b.prompt_tokens, + completion_tokens = a.completion_tokens + b.completion_tokens, + cost = (a.cost ?: 0.0) + (b.cost ?: 0.0) + ) + } + } ?: emptyMap() + } + + override fun getSessionUsageSummary(session: Session): Map = + usagePerSession[session]?.tokensPerModel?.entries?.map { (model, counter) -> + model.model to counter.toUsage() + }?.groupBy { it.first }?.mapValues { + it.value.map { it.second }.reduce { a, b -> + ApiModel.Usage( + prompt_tokens = a.prompt_tokens + b.prompt_tokens, + completion_tokens = a.completion_tokens + b.completion_tokens, + cost = (a.cost ?: 0.0) + (b.cost ?: 0.0) + ) + } + } ?: emptyMap() + + override fun clear() { + usagePerSession.clear() + sessionsByUser.clear() + usersBySession.clear() + saveCounters() + } + + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(UsageManager::class.java) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UserSettingsManager.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UserSettingsManager.kt index a151a2fb..ab1ca085 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UserSettingsManager.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/file/UserSettingsManager.kt @@ -9,41 +9,41 @@ import java.io.File open class UserSettingsManager : UserSettingsInterface { - private val userSettings = HashMap() - private val userConfigDirectory by lazy { dataStorageRoot.resolve("users").apply { mkdirs() } } + private val userSettings = HashMap() + private val userConfigDirectory by lazy { dataStorageRoot.resolve("users").apply { mkdirs() } } - override fun getUserSettings(user: User): UserSettings { - log.debug("Retrieving user settings for user: {}", user) - return userSettings.getOrPut(user) { - val file = File(userConfigDirectory, "$user.json") - if (file.exists()) { - try { - log.info("Loading existing user settings for user: {} from file: {}", user, file) - return@getOrPut JsonUtil.fromJson(file.readText(), UserSettings::class.java) - } catch (e: Throwable) { - log.error("Failed to load user settings for user: {} from file: {}. Creating new settings.", user, file, e) - } - } - log.info("User settings file not found for user: {}. Creating new settings at: {}", user, file) - return@getOrPut UserSettings() - } - } - - override fun updateUserSettings(user: User, settings: UserSettings) { - log.debug("Updating user settings for user: {}", user) - userSettings[user] = settings - val file = File(userConfigDirectory, "$user.json") - file.parentFile.mkdirs() + override fun getUserSettings(user: User): UserSettings { + log.debug("Retrieving user settings for user: {}", user) + return userSettings.getOrPut(user) { + val file = File(userConfigDirectory, "$user.json") + if (file.exists()) { try { - file.writeText(JsonUtil.toJson(settings)) - log.info("Successfully updated user settings for user: {} at file: {}", user, file) - } catch (e: Exception) { - log.error("Failed to write user settings for user: {} to file: {}", user, file, e) + log.info("Loading existing user settings for user: {} from file: {}", user, file) + return@getOrPut JsonUtil.fromJson(file.readText(), UserSettings::class.java) + } catch (e: Throwable) { + log.error("Failed to load user settings for user: {} from file: {}. Creating new settings.", user, file, e) } + } + log.info("User settings file not found for user: {}. Creating new settings at: {}", user, file) + return@getOrPut UserSettings() } + } - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(UserSettingsManager::class.java) + override fun updateUserSettings(user: User, settings: UserSettings) { + log.debug("Updating user settings for user: {}", user) + userSettings[user] = settings + val file = File(userConfigDirectory, "$user.json") + file.parentFile.mkdirs() + try { + file.writeText(JsonUtil.toJson(settings)) + log.info("Successfully updated user settings for user: {} at file: {}", user, file) + } catch (e: Exception) { + log.error("Failed to write user settings for user: {} to file: {}", user, file, e) } + } + + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(UserSettingsManager::class.java) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/hsql/HSQLMetadataStorage.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/hsql/HSQLMetadataStorage.kt index 987fb84c..9a56ee1a 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/hsql/HSQLMetadataStorage.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/hsql/HSQLMetadataStorage.kt @@ -11,21 +11,21 @@ import java.sql.Timestamp import java.util.* class HSQLMetadataStorage(private val dbFile: File) : MetadataStorageInterface { - private val log = LoggerFactory.getLogger(javaClass) + private val log = LoggerFactory.getLogger(javaClass) - private val connection: Connection by lazy { - log.info("Initializing HSQLMetadataStorage with database file: ${dbFile.absolutePath}") - Class.forName("org.hsqldb.jdbc.JDBCDriver") - val connection = DriverManager.getConnection("jdbc:hsqldb:file:${dbFile.absolutePath}/metadata;shutdown=true", "SA", "") - log.info("Database connection established successfully") - createSchema(connection) - connection - } + private val connection: Connection by lazy { + log.info("Initializing HSQLMetadataStorage with database file: ${dbFile.absolutePath}") + Class.forName("org.hsqldb.jdbc.JDBCDriver") + val connection = DriverManager.getConnection("jdbc:hsqldb:file:${dbFile.absolutePath}/metadata;shutdown=true", "SA", "") + log.info("Database connection established successfully") + createSchema(connection) + connection + } - private fun createSchema(connection: Connection) { - log.debug("Attempting to create database schema if not exists") - connection.createStatement().executeUpdate( - """ + private fun createSchema(connection: Connection) { + log.debug("Attempting to create database schema if not exists") + connection.createStatement().executeUpdate( + """ CREATE TABLE IF NOT EXISTS metadata ( session_id VARCHAR(255), user_email VARCHAR(255), @@ -35,150 +35,150 @@ class HSQLMetadataStorage(private val dbFile: File) : MetadataStorageInterface { PRIMARY KEY (session_id, user_email, key) ) """ - ) - log.info("Database schema creation completed") - } + ) + log.info("Database schema creation completed") + } - override fun getSessionName(user: User?, session: Session): String { - log.debug("Fetching session name for session: ${session.sessionId}, user: ${user?.email}") - val statement = connection.prepareStatement( - "SELECT value FROM metadata WHERE session_id = ? AND user_email = ? AND key = 'name'" - ) - statement.setString(1, session.sessionId) - statement.setString(2, user?.email ?: "") - val resultSet = statement.executeQuery() - return if (resultSet.next()) { - val name = resultSet.getString("value") - log.debug("Retrieved session name: $name for session: ${session.sessionId}") - name - } else { - session.sessionId - } + override fun getSessionName(user: User?, session: Session): String { + log.debug("Fetching session name for session: ${session.sessionId}, user: ${user?.email}") + val statement = connection.prepareStatement( + "SELECT value FROM metadata WHERE session_id = ? AND user_email = ? AND key = 'name'" + ) + statement.setString(1, session.sessionId) + statement.setString(2, user?.email ?: "") + val resultSet = statement.executeQuery() + return if (resultSet.next()) { + val name = resultSet.getString("value") + log.debug("Retrieved session name: $name for session: ${session.sessionId}") + name + } else { + session.sessionId } + } - override fun setSessionName(user: User?, session: Session, name: String) { - log.debug("Setting session name for session: ${session.sessionId}, user: ${user?.email} to $name") - val statement = connection.prepareStatement( - """ + override fun setSessionName(user: User?, session: Session, name: String) { + log.debug("Setting session name for session: ${session.sessionId}, user: ${user?.email} to $name") + val statement = connection.prepareStatement( + """ MERGE INTO metadata USING (VALUES(?, ?, ?, ?, ?)) AS vals(session_id, user_email, key, value, timestamp) ON metadata.session_id = vals.session_id AND metadata.user_email = vals.user_email AND metadata.key = vals.key WHEN MATCHED THEN UPDATE SET metadata.value = vals.value, metadata.timestamp = vals.timestamp WHEN NOT MATCHED THEN INSERT VALUES vals.session_id, vals.user_email, vals.key, vals.value, vals.timestamp """ - ) - statement.setString(1, session.sessionId) - statement.setString(2, user?.email ?: "") - statement.setString(3, "name") - statement.setString(4, name) - statement.setTimestamp(5, Timestamp(System.currentTimeMillis())) - statement.executeUpdate() - log.info("Session name set successfully for session: ${session.sessionId}") - } + ) + statement.setString(1, session.sessionId) + statement.setString(2, user?.email ?: "") + statement.setString(3, "name") + statement.setString(4, name) + statement.setTimestamp(5, Timestamp(System.currentTimeMillis())) + statement.executeUpdate() + log.info("Session name set successfully for session: ${session.sessionId}") + } - override fun getMessageIds(user: User?, session: Session): List { - log.debug("Fetching message IDs for session: ${session.sessionId}, user: ${user?.email}") - val statement = connection.prepareStatement( - "SELECT value FROM metadata WHERE session_id = ? AND user_email = ? AND key = 'message_ids'" - ) - statement.setString(1, session.sessionId) - statement.setString(2, user?.email ?: "") - val resultSet = statement.executeQuery() - return if (resultSet.next()) { - val ids = resultSet.getString("value").split(",") - log.debug("Retrieved ${ids.size} message IDs for session: ${session.sessionId}") - ids - } else { - log.debug("No message IDs found for session: ${session.sessionId}") - emptyList() - } + override fun getMessageIds(user: User?, session: Session): List { + log.debug("Fetching message IDs for session: ${session.sessionId}, user: ${user?.email}") + val statement = connection.prepareStatement( + "SELECT value FROM metadata WHERE session_id = ? AND user_email = ? AND key = 'message_ids'" + ) + statement.setString(1, session.sessionId) + statement.setString(2, user?.email ?: "") + val resultSet = statement.executeQuery() + return if (resultSet.next()) { + val ids = resultSet.getString("value").split(",") + log.debug("Retrieved ${ids.size} message IDs for session: ${session.sessionId}") + ids + } else { + log.debug("No message IDs found for session: ${session.sessionId}") + emptyList() } + } - override fun setMessageIds(user: User?, session: Session, ids: List) { - log.debug("Setting message IDs for session: ${session.sessionId}, user: ${user?.email} to $ids") - val statement = connection.prepareStatement( - """ + override fun setMessageIds(user: User?, session: Session, ids: List) { + log.debug("Setting message IDs for session: ${session.sessionId}, user: ${user?.email} to $ids") + val statement = connection.prepareStatement( + """ MERGE INTO metadata USING (VALUES(?, ?, ?, ?, ?)) AS vals(session_id, user_email, key, value, timestamp) ON metadata.session_id = vals.session_id AND metadata.user_email = vals.user_email AND metadata.key = vals.key WHEN MATCHED THEN UPDATE SET metadata.value = vals.value, metadata.timestamp = vals.timestamp WHEN NOT MATCHED THEN INSERT VALUES vals.session_id, vals.user_email, vals.key, vals.value, vals.timestamp """ - ) - statement.setString(1, session.sessionId) - statement.setString(2, user?.email ?: "") - statement.setString(3, "message_ids") - statement.setString(4, ids.joinToString(",")) - statement.setTimestamp(5, Timestamp(System.currentTimeMillis())) - statement.executeUpdate() - log.info("Set ${ids.size} message IDs for session: ${session.sessionId}") - } + ) + statement.setString(1, session.sessionId) + statement.setString(2, user?.email ?: "") + statement.setString(3, "message_ids") + statement.setString(4, ids.joinToString(",")) + statement.setTimestamp(5, Timestamp(System.currentTimeMillis())) + statement.executeUpdate() + log.info("Set ${ids.size} message IDs for session: ${session.sessionId}") + } - override fun getSessionTime(user: User?, session: Session): Date? { - log.debug("Fetching session time for session: ${session.sessionId}, user: ${user?.email}") - val statement = connection.prepareStatement( - "SELECT value, timestamp FROM metadata WHERE session_id = ? AND user_email = ? AND key = 'session_time'" - ) - statement.setString(1, session.sessionId) - statement.setString(2, user?.email ?: "") - val resultSet = statement.executeQuery() - return if (resultSet.next()) { - val time = resultSet.getString("value") - try { - Date(time.toLong()).also { - log.debug("Retrieved session time: $it for session: ${session.sessionId}") - } - } catch (e: NumberFormatException) { - log.warn("Invalid session time value: $time, falling back to timestamp for session: ${session.sessionId}") - resultSet.getTimestamp("timestamp") - } - } else { - Date() + override fun getSessionTime(user: User?, session: Session): Date? { + log.debug("Fetching session time for session: ${session.sessionId}, user: ${user?.email}") + val statement = connection.prepareStatement( + "SELECT value, timestamp FROM metadata WHERE session_id = ? AND user_email = ? AND key = 'session_time'" + ) + statement.setString(1, session.sessionId) + statement.setString(2, user?.email ?: "") + val resultSet = statement.executeQuery() + return if (resultSet.next()) { + val time = resultSet.getString("value") + try { + Date(time.toLong()).also { + log.debug("Retrieved session time: $it for session: ${session.sessionId}") } + } catch (e: NumberFormatException) { + log.warn("Invalid session time value: $time, falling back to timestamp for session: ${session.sessionId}") + resultSet.getTimestamp("timestamp") + } + } else { + Date() } + } - override fun setSessionTime(user: User?, session: Session, time: Date) { - log.debug("Setting session time for session: ${session.sessionId}, user: ${user?.email} to $time") - val statement = connection.prepareStatement( - """ + override fun setSessionTime(user: User?, session: Session, time: Date) { + log.debug("Setting session time for session: ${session.sessionId}, user: ${user?.email} to $time") + val statement = connection.prepareStatement( + """ MERGE INTO metadata USING (VALUES(?, ?, ?, ?, ?)) AS vals(session_id, user_email, key, value, timestamp) ON metadata.session_id = vals.session_id AND metadata.user_email = vals.user_email AND metadata.key = vals.key WHEN MATCHED THEN UPDATE SET metadata.value = vals.value, metadata.timestamp = vals.timestamp WHEN NOT MATCHED THEN INSERT VALUES vals.session_id, vals.user_email, vals.key, vals.value, vals.timestamp """ - ) - statement.setString(1, session.sessionId) - statement.setString(2, user?.email ?: "") - statement.setString(3, "session_time") - statement.setString(4, time.time.toString()) - statement.setTimestamp(5, Timestamp(time.time)) - statement.executeUpdate() - log.info("Session time set to $time for session: ${session.sessionId}") - } + ) + statement.setString(1, session.sessionId) + statement.setString(2, user?.email ?: "") + statement.setString(3, "session_time") + statement.setString(4, time.time.toString()) + statement.setTimestamp(5, Timestamp(time.time)) + statement.executeUpdate() + log.info("Session time set to $time for session: ${session.sessionId}") + } - override fun listSessions(path: String): List { - log.debug("Listing sessions for path: $path") - val statement = connection.prepareStatement( - "SELECT DISTINCT session_id FROM metadata WHERE value = ? AND key = 'path'" - ) - statement.setString(1, path) - val resultSet = statement.executeQuery() - val sessions = mutableListOf() - while (resultSet.next()) { - sessions.add(resultSet.getString("session_id")) - } - log.info("Found ${sessions.size} sessions for path: $path") - return sessions + override fun listSessions(path: String): List { + log.debug("Listing sessions for path: $path") + val statement = connection.prepareStatement( + "SELECT DISTINCT session_id FROM metadata WHERE value = ? AND key = 'path'" + ) + statement.setString(1, path) + val resultSet = statement.executeQuery() + val sessions = mutableListOf() + while (resultSet.next()) { + sessions.add(resultSet.getString("session_id")) } + log.info("Found ${sessions.size} sessions for path: $path") + return sessions + } - override fun deleteSession(user: User?, session: Session) { - log.debug("Deleting session: ${session.sessionId}, user: ${user?.email}") - val statement = connection.prepareStatement( - "DELETE FROM metadata WHERE session_id = ? AND user_email = ?" - ) - statement.setString(1, session.sessionId) - statement.setString(2, user?.email ?: "") - statement.executeUpdate() - log.info("Deleted session: ${session.sessionId} for user: ${user?.email ?: "anonymous"}") - } + override fun deleteSession(user: User?, session: Session) { + log.debug("Deleting session: ${session.sessionId}, user: ${user?.email}") + val statement = connection.prepareStatement( + "DELETE FROM metadata WHERE session_id = ? AND user_email = ?" + ) + statement.setString(1, session.sessionId) + statement.setString(2, user?.email ?: "") + statement.executeUpdate() + log.info("Deleted session: ${session.sessionId} for user: ${user?.email ?: "anonymous"}") + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/hsql/HSQLUsageManager.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/hsql/HSQLUsageManager.kt index 82a744ac..c6918b13 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/hsql/HSQLUsageManager.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/hsql/HSQLUsageManager.kt @@ -16,19 +16,19 @@ import java.util.concurrent.atomic.AtomicLong class HSQLUsageManager(private val dbFile: File) : UsageInterface { - private val connection: Connection by lazy { - log.info("Initializing HSQLUsageManager with database file: ${dbFile.absolutePath}") - Class.forName("org.hsqldb.jdbc.JDBCDriver") - val connection = DriverManager.getConnection("jdbc:hsqldb:file:${dbFile.absolutePath}/usage;shutdown=true", "SA", "") - log.debug("Database connection established: $connection") - createSchema(connection) - connection - } - - private fun createSchema(connection: Connection) { - log.info("Creating database schema if not exists") - connection.createStatement().executeUpdate( - """ + private val connection: Connection by lazy { + log.info("Initializing HSQLUsageManager with database file: ${dbFile.absolutePath}") + Class.forName("org.hsqldb.jdbc.JDBCDriver") + val connection = DriverManager.getConnection("jdbc:hsqldb:file:${dbFile.absolutePath}/usage;shutdown=true", "SA", "") + log.debug("Database connection established: $connection") + createSchema(connection) + connection + } + + private fun createSchema(connection: Connection) { + log.info("Creating database schema if not exists") + connection.createStatement().executeUpdate( + """ CREATE TABLE IF NOT EXISTS usage ( session_id VARCHAR(255), api_key VARCHAR(255), @@ -40,136 +40,136 @@ class HSQLUsageManager(private val dbFile: File) : UsageInterface { PRIMARY KEY (session_id, api_key, model, prompt_tokens, completion_tokens, cost, datetime) ) """ - ) - } - - - private fun updateSchema() { - log.info("Updating database schema if needed") - // Add schema update logic here if needed - } - - private fun deleteSchema() { - log.info("Deleting database schema if exists") - connection.createStatement().executeUpdate("DROP TABLE IF EXISTS usage") - log.debug("Schema deleted") - } - - override fun incrementUsage(session: Session, apiKey: String?, model: OpenAIModel, tokens: ApiModel.Usage) { - try { - log.debug("Incrementing usage for session: ${session.sessionId}, apiKey: $apiKey, model: ${model.modelName}") - val usageKey = UsageInterface.UsageKey(session, apiKey, model) - val usageValues = UsageInterface.UsageValues() //getUsageValues(usageKey) - usageValues.addAndGet(tokens) - saveUsageValues(usageKey, usageValues) - log.debug("Usage incremented for session: ${session.sessionId}, apiKey: $apiKey, model: ${model.modelName}") - } catch (e: Exception) { - log.error("Error incrementing usage", e) - } + ) + } + + + private fun updateSchema() { + log.info("Updating database schema if needed") + // Add schema update logic here if needed + } + + private fun deleteSchema() { + log.info("Deleting database schema if exists") + connection.createStatement().executeUpdate("DROP TABLE IF EXISTS usage") + log.debug("Schema deleted") + } + + override fun incrementUsage(session: Session, apiKey: String?, model: OpenAIModel, tokens: ApiModel.Usage) { + try { + log.debug("Incrementing usage for session: ${session.sessionId}, apiKey: $apiKey, model: ${model.modelName}") + val usageKey = UsageInterface.UsageKey(session, apiKey, model) + val usageValues = UsageInterface.UsageValues() //getUsageValues(usageKey) + usageValues.addAndGet(tokens) + saveUsageValues(usageKey, usageValues) + log.debug("Usage incremented for session: ${session.sessionId}, apiKey: $apiKey, model: ${model.modelName}") + } catch (e: Exception) { + log.error("Error incrementing usage", e) } + } - override fun getUserUsageSummary(apiKey: String): Map { - log.debug("Executing SQL query to get user usage summary for apiKey: $apiKey") - val statement = connection.prepareStatement( - """ + override fun getUserUsageSummary(apiKey: String): Map { + log.debug("Executing SQL query to get user usage summary for apiKey: $apiKey") + val statement = connection.prepareStatement( + """ SELECT model, SUM(prompt_tokens), SUM(completion_tokens), SUM(cost) FROM usage WHERE api_key = ? GROUP BY model """ - ) - statement.setString(1, apiKey) - val resultSet = statement.executeQuery() - return generateUsageSummary(resultSet) - } - - override fun getSessionUsageSummary(session: Session): Map { - log.info("Getting session usage summary for session: ${session.sessionId}") - val statement = connection.prepareStatement( - """ + ) + statement.setString(1, apiKey) + val resultSet = statement.executeQuery() + return generateUsageSummary(resultSet) + } + + override fun getSessionUsageSummary(session: Session): Map { + log.info("Getting session usage summary for session: ${session.sessionId}") + val statement = connection.prepareStatement( + """ SELECT model, SUM(prompt_tokens), SUM(completion_tokens), SUM(cost) FROM usage WHERE session_id = ? GROUP BY model """ - ) - statement.setString(1, session.sessionId) - val resultSet = statement.executeQuery() - return generateUsageSummary(resultSet) - } - - override fun clear() { - log.debug("Executing SQL statement to clear all usage data") - connection.createStatement().executeUpdate("DELETE FROM usage") - } - - private fun getUsageValues(usageKey: UsageInterface.UsageKey): UsageInterface.UsageValues { - log.debug("Getting usage values for session: ${usageKey.session.sessionId}, apiKey: ${usageKey.apiKey}, model: ${usageKey.model.modelName}") - val statement = connection.prepareStatement( - """ + ) + statement.setString(1, session.sessionId) + val resultSet = statement.executeQuery() + return generateUsageSummary(resultSet) + } + + override fun clear() { + log.debug("Executing SQL statement to clear all usage data") + connection.createStatement().executeUpdate("DELETE FROM usage") + } + + private fun getUsageValues(usageKey: UsageInterface.UsageKey): UsageInterface.UsageValues { + log.debug("Getting usage values for session: ${usageKey.session.sessionId}, apiKey: ${usageKey.apiKey}, model: ${usageKey.model.modelName}") + val statement = connection.prepareStatement( + """ SELECT COALESCE(SUM(prompt_tokens), 0), COALESCE(SUM(completion_tokens), 0), COALESCE(SUM(cost), 0) FROM usage WHERE session_id = ? AND api_key = ? AND model = ? """ - ) - statement.setString(1, usageKey.session.sessionId) - statement.setString(2, usageKey.apiKey ?: "") - statement.setString(3, usageKey.model.toString()) - val resultSet = statement.executeQuery() - resultSet.next() - return UsageInterface.UsageValues( - AtomicLong(resultSet.getLong(1)), - AtomicLong(resultSet.getLong(2)), - AtomicDouble(resultSet.getDouble(3)) - ) - } - - private fun saveUsageValues(usageKey: UsageInterface.UsageKey, usageValues: UsageInterface.UsageValues) { - log.debug("Saving usage values for session: ${usageKey.session.sessionId}, apiKey: ${usageKey.apiKey}, model: ${usageKey.model.modelName}") - val statement = connection.prepareStatement( - """ + ) + statement.setString(1, usageKey.session.sessionId) + statement.setString(2, usageKey.apiKey ?: "") + statement.setString(3, usageKey.model.toString()) + val resultSet = statement.executeQuery() + resultSet.next() + return UsageInterface.UsageValues( + AtomicLong(resultSet.getLong(1)), + AtomicLong(resultSet.getLong(2)), + AtomicDouble(resultSet.getDouble(3)) + ) + } + + private fun saveUsageValues(usageKey: UsageInterface.UsageKey, usageValues: UsageInterface.UsageValues) { + log.debug("Saving usage values for session: ${usageKey.session.sessionId}, apiKey: ${usageKey.apiKey}, model: ${usageKey.model.modelName}") + val statement = connection.prepareStatement( + """ INSERT INTO usage (session_id, api_key, model, prompt_tokens, completion_tokens, cost, datetime) VALUES (?, ?, ?, ?, ?, ?, ?) """ - ) - statement.setString(1, usageKey.session.sessionId) - statement.setString(2, usageKey.apiKey ?: "") - statement.setString(3, usageKey.model.modelName) - statement.setLong(4, usageValues.inputTokens.get()) - statement.setLong(5, usageValues.outputTokens.get()) - statement.setDouble(6, usageValues.cost.get()) - statement.setTimestamp(7, Timestamp(System.currentTimeMillis())) - log.debug("Executing statement: $statement") - log.debug("With parameters: ${usageKey.session.sessionId}, ${usageKey.apiKey}, ${usageKey.model.modelName}, ${usageValues.inputTokens.get()}, ${usageValues.outputTokens.get()}, ${usageValues.cost.get()}") - statement.executeUpdate() - } - - private fun generateUsageSummary(resultSet: ResultSet): Map { - log.debug("Generating usage summary from result set") - val summary = mutableMapOf() - while (resultSet.next()) { - val string = resultSet.getString(1) - val model = openAIModel(string) ?: continue - val usage = ApiModel.Usage( - prompt_tokens = resultSet.getLong(2), - completion_tokens = resultSet.getLong(3), - cost = resultSet.getDouble(4) - ) - summary[model] = usage - } - return summary - } - - private fun openAIModel(string: String): OpenAIModel? { - log.debug("Retrieving OpenAI model for string: $string") - val model = ChatModel.values().filter { - it.key == string || it.value.modelName == string || it.value.name == string - }.toList().firstOrNull()?.second ?: return null - log.debug("OpenAI model retrieved: $model") - return model - } - - companion object { - private val log = LoggerFactory.getLogger(HSQLUsageManager::class.java) + ) + statement.setString(1, usageKey.session.sessionId) + statement.setString(2, usageKey.apiKey ?: "") + statement.setString(3, usageKey.model.modelName) + statement.setLong(4, usageValues.inputTokens.get()) + statement.setLong(5, usageValues.outputTokens.get()) + statement.setDouble(6, usageValues.cost.get()) + statement.setTimestamp(7, Timestamp(System.currentTimeMillis())) + log.debug("Executing statement: $statement") + log.debug("With parameters: ${usageKey.session.sessionId}, ${usageKey.apiKey}, ${usageKey.model.modelName}, ${usageValues.inputTokens.get()}, ${usageValues.outputTokens.get()}, ${usageValues.cost.get()}") + statement.executeUpdate() + } + + private fun generateUsageSummary(resultSet: ResultSet): Map { + log.debug("Generating usage summary from result set") + val summary = mutableMapOf() + while (resultSet.next()) { + val string = resultSet.getString(1) + val model = openAIModel(string) ?: continue + val usage = ApiModel.Usage( + prompt_tokens = resultSet.getLong(2), + completion_tokens = resultSet.getLong(3), + cost = resultSet.getDouble(4) + ) + summary[model] = usage } + return summary + } + + private fun openAIModel(string: String): OpenAIModel? { + log.debug("Retrieving OpenAI model for string: $string") + val model = ChatModel.values().filter { + it.key == string || it.value.modelName == string || it.value.name == string + }.toList().firstOrNull()?.second ?: return null + log.debug("OpenAI model retrieved: $model") + return model + } + + companion object { + private val log = LoggerFactory.getLogger(HSQLUsageManager::class.java) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/ApplicationServicesConfig.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/ApplicationServicesConfig.kt index 0febf3cf..266db019 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/ApplicationServicesConfig.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/ApplicationServicesConfig.kt @@ -4,14 +4,14 @@ import java.io.File object ApplicationServicesConfig { - var isLocked: Boolean = false - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } - var dataStorageRoot: File = File(System.getProperty("user.home"), ".skyenet") - set(value) { - require(!isLocked) { "ApplicationServices is locked" } - field = value - } + var isLocked: Boolean = false + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } + var dataStorageRoot: File = File(System.getProperty("user.home"), ".skyenet") + set(value) { + require(!isLocked) { "ApplicationServices is locked" } + field = value + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/AuthenticationInterface.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/AuthenticationInterface.kt index 0f227dc3..77d1ea46 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/AuthenticationInterface.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/AuthenticationInterface.kt @@ -1,13 +1,13 @@ package com.simiacryptus.skyenet.core.platform.model interface AuthenticationInterface { - fun getUser(accessToken: String?): User? + fun getUser(accessToken: String?): User? - fun putUser(accessToken: String, user: User): User - fun logout(accessToken: String, user: User) - //fun removeToken(accessToken: String) + fun putUser(accessToken: String, user: User): User + fun logout(accessToken: String, user: User) + //fun removeToken(accessToken: String) - companion object { - const val AUTH_COOKIE = "sessionId" - } + companion object { + const val AUTH_COOKIE = "sessionId" + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/AuthorizationInterface.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/AuthorizationInterface.kt index 2c28a419..9a01202e 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/AuthorizationInterface.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/AuthorizationInterface.kt @@ -1,20 +1,20 @@ package com.simiacryptus.skyenet.core.platform.model interface AuthorizationInterface { - enum class OperationType { - Read, - Write, - Public, - Share, - Execute, - Delete, - Admin, - GlobalKey, - } + enum class OperationType { + Read, + Write, + Public, + Share, + Execute, + Delete, + Admin, + GlobalKey, + } - fun isAuthorized( - applicationClass: Class<*>?, - user: User?, - operationType: OperationType, - ): Boolean + fun isAuthorized( + applicationClass: Class<*>?, + user: User?, + operationType: OperationType, + ): Boolean } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/CloudPlatformInterface.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/CloudPlatformInterface.kt index 2d4f05ef..bdfe9f2b 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/CloudPlatformInterface.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/CloudPlatformInterface.kt @@ -1,20 +1,20 @@ package com.simiacryptus.skyenet.core.platform.model interface CloudPlatformInterface { - val shareBase: String + val shareBase: String - fun upload( - path: String, - contentType: String, - bytes: ByteArray - ): String + fun upload( + path: String, + contentType: String, + bytes: ByteArray + ): String - fun upload( - path: String, - contentType: String, - request: String - ): String + fun upload( + path: String, + contentType: String, + request: String + ): String - fun encrypt(fileBytes: ByteArray, keyId: String): String? - fun decrypt(encryptedData: ByteArray): String + fun encrypt(fileBytes: ByteArray, keyId: String): String? + fun decrypt(encryptedData: ByteArray): String } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/MetadataStorageInterface.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/MetadataStorageInterface.kt index a76d02ce..319337be 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/MetadataStorageInterface.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/MetadataStorageInterface.kt @@ -4,12 +4,12 @@ import com.simiacryptus.skyenet.core.platform.Session import java.util.Date interface MetadataStorageInterface { - fun getSessionName(user: User?, session: Session): String - fun setSessionName(user: User?, session: Session, name: String) - fun getMessageIds(user: User?, session: Session): List - fun setMessageIds(user: User?, session: Session, ids: List) - fun getSessionTime(user: User?, session: Session): Date? - fun setSessionTime(user: User?, session: Session, time: Date) - fun listSessions(path: String): List - fun deleteSession(user: User?, session: Session) + fun getSessionName(user: User?, session: Session): String + fun setSessionName(user: User?, session: Session, name: String) + fun getMessageIds(user: User?, session: Session): List + fun setMessageIds(user: User?, session: Session, ids: List) + fun getSessionTime(user: User?, session: Session): Date? + fun setSessionTime(user: User?, session: Session, time: Date) + fun listSessions(path: String): List + fun deleteSession(user: User?, session: Session) } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/StorageInterface.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/StorageInterface.kt index 0bf58723..15260344 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/StorageInterface.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/StorageInterface.kt @@ -7,91 +7,91 @@ import java.util.LinkedHashMap interface StorageInterface { - fun getMessages( - user: User?, - session: Session - ): LinkedHashMap - - fun getSessionDir( - user: User?, - session: Session - ): File - - fun getDataDir( - user: User?, - session: Session - ): File - - @Deprecated("Use metadataStorage instead") - fun getSessionName( - user: User?, - session: Session - ): String - - @Deprecated("Use metadataStorage instead") - fun getSessionTime( - user: User?, - session: Session - ): Date? - - fun listSessions( - user: User?, - path: String, - ): List - - fun setJson( - user: User?, - session: Session, - filename: String, - settings: T - ): T - - fun updateMessage( - user: User?, - session: Session, - messageId: String, - value: String - ) - - @Deprecated("Use metadataStorage instead") - fun listSessions(dir: File, path: String): List - fun userRoot(user: User?): File - fun deleteSession(user: User?, session: Session) - - @Deprecated("Use metadataStorage instead") - fun getMessageIds( - user: User?, - session: Session - ): List - - @Deprecated("Use metadataStorage instead") - fun setMessageIds( - user: User?, - session: Session, - ids: List - ) - - companion object { - @Deprecated("Use Session.long64() instead", ReplaceWith("Session.long64()")) - inline fun long64() = Session.long64() - - @Deprecated("Use Session.validateSessionId(session) instead", ReplaceWith("Session.validateSessionId(session)")) - inline fun validateSessionId(session: Session) = Session.validateSessionId(session) - - @Deprecated("Use Session.newGlobalID() instead", ReplaceWith("Session.newGlobalID()")) - inline fun newGlobalID(): Session = Session.newGlobalID() - - @Deprecated("Use Session.newUserID() instead", ReplaceWith("Session.newUserID()")) - inline fun newUserID(): Session = Session.newUserID() - - - @Deprecated("Use Session.parseSessionID(sessionID) instead", ReplaceWith("Session.parseSessionID(sessionID)")) - inline fun parseSessionID(sessionID: String): Session = Session.parseSessionID(sessionID) - - @Deprecated("Use Session.id2() instead") - private inline fun id2() = Session.long64().filter { - it in 'a'..'z' || it in 'A'..'Z' || it in '0'..'9' - }.take(4) - - } + fun getMessages( + user: User?, + session: Session + ): LinkedHashMap + + fun getSessionDir( + user: User?, + session: Session + ): File + + fun getDataDir( + user: User?, + session: Session + ): File + + @Deprecated("Use metadataStorage instead") + fun getSessionName( + user: User?, + session: Session + ): String + + @Deprecated("Use metadataStorage instead") + fun getSessionTime( + user: User?, + session: Session + ): Date? + + fun listSessions( + user: User?, + path: String, + ): List + + fun setJson( + user: User?, + session: Session, + filename: String, + settings: T + ): T + + fun updateMessage( + user: User?, + session: Session, + messageId: String, + value: String + ) + + @Deprecated("Use metadataStorage instead") + fun listSessions(dir: File, path: String): List + fun userRoot(user: User?): File + fun deleteSession(user: User?, session: Session) + + @Deprecated("Use metadataStorage instead") + fun getMessageIds( + user: User?, + session: Session + ): List + + @Deprecated("Use metadataStorage instead") + fun setMessageIds( + user: User?, + session: Session, + ids: List + ) + + companion object { + @Deprecated("Use Session.long64() instead", ReplaceWith("Session.long64()")) + inline fun long64() = Session.long64() + + @Deprecated("Use Session.validateSessionId(session) instead", ReplaceWith("Session.validateSessionId(session)")) + inline fun validateSessionId(session: Session) = Session.validateSessionId(session) + + @Deprecated("Use Session.newGlobalID() instead", ReplaceWith("Session.newGlobalID()")) + inline fun newGlobalID(): Session = Session.newGlobalID() + + @Deprecated("Use Session.newUserID() instead", ReplaceWith("Session.newUserID()")) + inline fun newUserID(): Session = Session.newUserID() + + + @Deprecated("Use Session.parseSessionID(sessionID) instead", ReplaceWith("Session.parseSessionID(sessionID)")) + inline fun parseSessionID(sessionID: String): Session = Session.parseSessionID(sessionID) + + @Deprecated("Use Session.id2() instead") + private inline fun id2() = Session.long64().filter { + it in 'a'..'z' || it in 'A'..'Z' || it in '0'..'9' + }.take(4) + + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/UsageInterface.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/UsageInterface.kt index dafae9bc..d384b8d7 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/UsageInterface.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/UsageInterface.kt @@ -10,56 +10,56 @@ import com.simiacryptus.skyenet.core.platform.Session import java.util.concurrent.atomic.AtomicLong interface UsageInterface { - fun incrementUsage(session: Session, user: User?, model: OpenAIModel, tokens: ApiModel.Usage) = incrementUsage( - session, when (user) { - null -> null - else -> { - val userSettings = ApplicationServices.userSettingsManager.getUserSettings(user) - userSettings.apiKeys[if (model is ChatModel) { - model.provider - } else { - APIProvider.Companion.OpenAI - }] - } - }, model, tokens - ) + fun incrementUsage(session: Session, user: User?, model: OpenAIModel, tokens: ApiModel.Usage) = incrementUsage( + session, when (user) { + null -> null + else -> { + val userSettings = ApplicationServices.userSettingsManager.getUserSettings(user) + userSettings.apiKeys[if (model is ChatModel) { + model.provider + } else { + APIProvider.Companion.OpenAI + }] + } + }, model, tokens + ) - fun incrementUsage(session: Session, apiKey: String?, model: OpenAIModel, tokens: ApiModel.Usage) + fun incrementUsage(session: Session, apiKey: String?, model: OpenAIModel, tokens: ApiModel.Usage) - fun getUserUsageSummary(user: User): Map = getUserUsageSummary( - ApplicationServices.userSettingsManager.getUserSettings(user).apiKeys[APIProvider.Companion.OpenAI]!! // TODO: Support other providers - ) + fun getUserUsageSummary(user: User): Map = getUserUsageSummary( + ApplicationServices.userSettingsManager.getUserSettings(user).apiKeys[APIProvider.Companion.OpenAI]!! // TODO: Support other providers + ) - fun getUserUsageSummary(apiKey: String): Map + fun getUserUsageSummary(apiKey: String): Map - fun getSessionUsageSummary(session: Session): Map - fun clear() + fun getSessionUsageSummary(session: Session): Map + fun clear() - data class UsageKey( - val session: Session, - val apiKey: String?, - val model: OpenAIModel, - ) + data class UsageKey( + val session: Session, + val apiKey: String?, + val model: OpenAIModel, + ) - class UsageValues( - val inputTokens: AtomicLong = AtomicLong(), - val outputTokens: AtomicLong = AtomicLong(), - val cost: AtomicDouble = AtomicDouble(), - ) { - fun addAndGet(tokens: ApiModel.Usage) { - inputTokens.addAndGet(tokens.prompt_tokens) - outputTokens.addAndGet(tokens.completion_tokens) - cost.addAndGet(tokens.cost ?: 0.0) - } - - fun toUsage() = ApiModel.Usage( - prompt_tokens = inputTokens.get(), - completion_tokens = outputTokens.get(), - cost = cost.get() - ) + class UsageValues( + val inputTokens: AtomicLong = AtomicLong(), + val outputTokens: AtomicLong = AtomicLong(), + val cost: AtomicDouble = AtomicDouble(), + ) { + fun addAndGet(tokens: ApiModel.Usage) { + inputTokens.addAndGet(tokens.prompt_tokens) + outputTokens.addAndGet(tokens.completion_tokens) + cost.addAndGet(tokens.cost ?: 0.0) } - class UsageCounters( - val tokensPerModel: java.util.HashMap = HashMap(), + fun toUsage() = ApiModel.Usage( + prompt_tokens = inputTokens.get(), + completion_tokens = outputTokens.get(), + cost = cost.get() ) + } + + class UsageCounters( + val tokensPerModel: java.util.HashMap = HashMap(), + ) } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/User.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/User.kt index f0ba412a..d8a04aa6 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/User.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/User.kt @@ -4,25 +4,25 @@ import com.fasterxml.jackson.annotation.JsonIgnore import com.fasterxml.jackson.annotation.JsonProperty data class User( - @get:JsonProperty("email") val email: String, - @get:JsonProperty("name") val name: String? = null, - @get:JsonProperty("id") val id: String? = null, - @get:JsonProperty("picture") val picture: String? = null, - @get:JsonIgnore val credential: Any? = null, + @get:JsonProperty("email") val email: String, + @get:JsonProperty("name") val name: String? = null, + @get:JsonProperty("id") val id: String? = null, + @get:JsonProperty("picture") val picture: String? = null, + @get:JsonIgnore val credential: Any? = null, ) { - override fun toString() = email + override fun toString() = email - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (javaClass != other?.javaClass) return false + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false - other as User + other as User - return email == other.email - } + return email == other.email + } - override fun hashCode(): Int { - return email.hashCode() - } + override fun hashCode(): Int { + return email.hashCode() + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/UserSettingsInterface.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/UserSettingsInterface.kt index f49d886c..cef5c03a 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/UserSettingsInterface.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/model/UserSettingsInterface.kt @@ -3,12 +3,12 @@ package com.simiacryptus.skyenet.core.platform.model import com.simiacryptus.jopenai.models.APIProvider interface UserSettingsInterface { - data class UserSettings( - val apiKeys: Map = APIProvider.Companion.values().associateWith { "" }, - val apiBase: Map = APIProvider.Companion.values().associateWith { it.base ?: "" }, - ) + data class UserSettings( + val apiKeys: Map = APIProvider.Companion.values().associateWith { "" }, + val apiBase: Map = APIProvider.Companion.values().associateWith { it.base ?: "" }, + ) - fun getUserSettings(user: User): UserSettings + fun getUserSettings(user: User): UserSettings - fun updateUserSettings(user: User, settings: UserSettings) + fun updateUserSettings(user: User, settings: UserSettings) } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthenticationInterfaceTest.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthenticationInterfaceTest.kt index 6ef61fa9..2b9fb461 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthenticationInterfaceTest.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthenticationInterfaceTest.kt @@ -5,44 +5,44 @@ import org.junit.jupiter.api.Test import java.util.* open class AuthenticationInterfaceTest( - private val authInterface: AuthenticationInterface + private val authInterface: AuthenticationInterface ) { - private val validAccessToken = UUID.randomUUID().toString() - private val newUser = User( - email = "newuser@example.com", - name = "Jane Smith", - id = "2", - picture = "http://example.com/newpicture.jpg" - ) - - @Test - fun `getUser should return null when no user is associated with access token`() { - val user = authInterface.getUser(validAccessToken) - assertNull(user) - } - - @Test - fun `putUser should add a new user and return the user`() { - val returnedUser = authInterface.putUser(validAccessToken, newUser) - assertEquals(newUser, returnedUser) - } - - @Test - fun `getUser should return User after putUser is called`() { - authInterface.putUser(validAccessToken, newUser) - val user: User? = authInterface.getUser(validAccessToken) - assertNotNull(user) - assertEquals(newUser, user) - } - - @Test - fun `logout should remove the user associated with the access token`() { - authInterface.putUser(validAccessToken, newUser) - assertNotNull(authInterface.getUser(validAccessToken)) - - authInterface.logout(validAccessToken, newUser) - assertNull(authInterface.getUser(validAccessToken)) - } + private val validAccessToken = UUID.randomUUID().toString() + private val newUser = User( + email = "newuser@example.com", + name = "Jane Smith", + id = "2", + picture = "http://example.com/newpicture.jpg" + ) + + @Test + fun `getUser should return null when no user is associated with access token`() { + val user = authInterface.getUser(validAccessToken) + assertNull(user) + } + + @Test + fun `putUser should add a new user and return the user`() { + val returnedUser = authInterface.putUser(validAccessToken, newUser) + assertEquals(newUser, returnedUser) + } + + @Test + fun `getUser should return User after putUser is called`() { + authInterface.putUser(validAccessToken, newUser) + val user: User? = authInterface.getUser(validAccessToken) + assertNotNull(user) + assertEquals(newUser, user) + } + + @Test + fun `logout should remove the user associated with the access token`() { + authInterface.putUser(validAccessToken, newUser) + assertNotNull(authInterface.getUser(validAccessToken)) + + authInterface.logout(validAccessToken, newUser) + assertNull(authInterface.getUser(validAccessToken)) + } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthorizationInterfaceTest.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthorizationInterfaceTest.kt index b378aa1c..4dfb2177 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthorizationInterfaceTest.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/AuthorizationInterfaceTest.kt @@ -4,19 +4,19 @@ import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Test open class AuthorizationInterfaceTest( - private val authInterface: AuthorizationInterface + private val authInterface: AuthorizationInterface ) { - open val user = User( - email = "newuser@example.com", - name = "Jane Smith", - id = "2", - picture = "http://example.com/newpicture.jpg" - ) + open val user = User( + email = "newuser@example.com", + name = "Jane Smith", + id = "2", + picture = "http://example.com/newpicture.jpg" + ) - @Test - fun `newUser has admin`() { - assertFalse(authInterface.isAuthorized(this.javaClass, user, AuthorizationInterface.OperationType.Admin)) - } + @Test + fun `newUser has admin`() { + assertFalse(authInterface.isAuthorized(this.javaClass, user, AuthorizationInterface.OperationType.Admin)) + } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/MetadataStorageInterfaceTest.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/MetadataStorageInterfaceTest.kt index 58cf6414..ad54b1a6 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/MetadataStorageInterfaceTest.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/MetadataStorageInterfaceTest.kt @@ -5,163 +5,163 @@ import com.simiacryptus.skyenet.core.platform.model.MetadataStorageInterface import com.simiacryptus.skyenet.core.platform.model.User import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.Test -import java.util.* import org.slf4j.LoggerFactory +import java.util.* abstract class MetadataStorageInterfaceTest(val storage: MetadataStorageInterface) { - companion object { - private val log = LoggerFactory.getLogger(MetadataStorageInterfaceTest::class.java) - } - - - @Test - fun testGetSessionName() { - log.info("Starting testGetSessionName") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - - // Act - log.debug("Retrieving session name for user {} and session {}", user.email, session.sessionId) - val sessionName = storage.getSessionName(user, session) - - // Assert - log.debug("Retrieved session name: {}", sessionName) - assertNotNull(sessionName) - assertTrue(sessionName is String) - log.info("Completed testGetSessionName successfully") - } - - @Test - fun testSetSessionName() { - log.info("Starting testSetSessionName") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - val newName = "Test Session" - - // Act - log.debug("Setting session name to '{}' for user {} and session {}", newName, user.email, session.sessionId) - storage.setSessionName(user, session, newName) - log.debug("Retrieving session name for verification") - val retrievedName = storage.getSessionName(user, session) - - // Assert - log.debug("Retrieved session name: {}", retrievedName) - assertEquals(newName, retrievedName) - log.info("Completed testSetSessionName successfully") - } - - @Test - fun testGetMessageIds() { - log.info("Starting testGetMessageIds") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - - // Act - log.debug("Retrieving message IDs for user {} and session {}", user.email, session.sessionId) - val messageIds = storage.getMessageIds(user, session) - - // Assert - log.debug("Retrieved message IDs: {}", messageIds) - assertNotNull(messageIds) - assertTrue(messageIds is List<*>) - log.info("Completed testGetMessageIds successfully") - } - - @Test - fun testSetMessageIds() { - log.info("Starting testSetMessageIds") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - val newIds = listOf("msg001", "msg002", "msg003") - - // Act - log.debug("Setting message IDs {} for user {} and session {}", newIds, user.email, session.sessionId) - storage.setMessageIds(user, session, newIds) - log.debug("Retrieving message IDs for verification") - val retrievedIds = storage.getMessageIds(user, session) - - // Assert - log.debug("Retrieved message IDs: {}", retrievedIds) - assertEquals(newIds, retrievedIds) - log.info("Completed testSetMessageIds successfully") - } - -// @Test - fun testGetSessionTime() { - log.info("Starting testGetSessionTime") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - - // Act - log.debug("Retrieving session time for user {} and session {}", user.email, session.sessionId) - val sessionTime = storage.getSessionTime(user, session) - - // Assert - log.debug("Retrieved session time: {}", sessionTime) - assertNotNull(sessionTime) - assertTrue(sessionTime is Date) - log.info("Completed testGetSessionTime successfully") - } - - @Test - fun testSetSessionTime() { - log.info("Starting testSetSessionTime") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - val newTime = Date() - - // Act - log.debug("Setting session time to {} for user {} and session {}", newTime, user.email, session.sessionId) - storage.setSessionTime(user, session, newTime) - log.debug("Retrieving session time for verification") - val retrievedTime = storage.getSessionTime(user, session) - - // Assert - log.debug("Retrieved session time: {}", retrievedTime) - assertEquals(newTime.toString(), retrievedTime.toString()) - log.info("Completed testSetSessionTime successfully") - } - - @Test - fun testListSessions() { - log.info("Starting testListSessions") - // Arrange - val path = "" - - // Act - log.debug("Listing sessions for path: {}", path) - val sessions = storage.listSessions(path) - - // Assert - log.debug("Retrieved sessions: {}", sessions) - assertNotNull(sessions) - assertTrue(sessions is List<*>) - log.info("Completed testListSessions successfully") - } - - @Test - fun testDeleteSession() { - log.info("Starting testDeleteSession") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - - // Act and Assert - try { - log.debug("Attempting to delete session {} for user {}", session.sessionId, user.email) - storage.deleteSession(user, session) - log.info("Session deleted successfully") - // If no exception is thrown, the test passes. - } catch (e: Exception) { - log.error("Failed to delete session: {}", e.message, e) - fail("Exception should not be thrown") - } - log.info("Completed testDeleteSession successfully") + companion object { + private val log = LoggerFactory.getLogger(MetadataStorageInterfaceTest::class.java) + } + + + @Test + fun testGetSessionName() { + log.info("Starting testGetSessionName") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + + // Act + log.debug("Retrieving session name for user {} and session {}", user.email, session.sessionId) + val sessionName = storage.getSessionName(user, session) + + // Assert + log.debug("Retrieved session name: {}", sessionName) + assertNotNull(sessionName) + assertTrue(sessionName is String) + log.info("Completed testGetSessionName successfully") + } + + @Test + fun testSetSessionName() { + log.info("Starting testSetSessionName") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + val newName = "Test Session" + + // Act + log.debug("Setting session name to '{}' for user {} and session {}", newName, user.email, session.sessionId) + storage.setSessionName(user, session, newName) + log.debug("Retrieving session name for verification") + val retrievedName = storage.getSessionName(user, session) + + // Assert + log.debug("Retrieved session name: {}", retrievedName) + assertEquals(newName, retrievedName) + log.info("Completed testSetSessionName successfully") + } + + @Test + fun testGetMessageIds() { + log.info("Starting testGetMessageIds") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + + // Act + log.debug("Retrieving message IDs for user {} and session {}", user.email, session.sessionId) + val messageIds = storage.getMessageIds(user, session) + + // Assert + log.debug("Retrieved message IDs: {}", messageIds) + assertNotNull(messageIds) + assertTrue(messageIds is List<*>) + log.info("Completed testGetMessageIds successfully") + } + + @Test + fun testSetMessageIds() { + log.info("Starting testSetMessageIds") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + val newIds = listOf("msg001", "msg002", "msg003") + + // Act + log.debug("Setting message IDs {} for user {} and session {}", newIds, user.email, session.sessionId) + storage.setMessageIds(user, session, newIds) + log.debug("Retrieving message IDs for verification") + val retrievedIds = storage.getMessageIds(user, session) + + // Assert + log.debug("Retrieved message IDs: {}", retrievedIds) + assertEquals(newIds, retrievedIds) + log.info("Completed testSetMessageIds successfully") + } + + // @Test + fun testGetSessionTime() { + log.info("Starting testGetSessionTime") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + + // Act + log.debug("Retrieving session time for user {} and session {}", user.email, session.sessionId) + val sessionTime = storage.getSessionTime(user, session) + + // Assert + log.debug("Retrieved session time: {}", sessionTime) + assertNotNull(sessionTime) + assertTrue(sessionTime is Date) + log.info("Completed testGetSessionTime successfully") + } + + @Test + fun testSetSessionTime() { + log.info("Starting testSetSessionTime") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + val newTime = Date() + + // Act + log.debug("Setting session time to {} for user {} and session {}", newTime, user.email, session.sessionId) + storage.setSessionTime(user, session, newTime) + log.debug("Retrieving session time for verification") + val retrievedTime = storage.getSessionTime(user, session) + + // Assert + log.debug("Retrieved session time: {}", retrievedTime) + assertEquals(newTime.toString(), retrievedTime.toString()) + log.info("Completed testSetSessionTime successfully") + } + + @Test + fun testListSessions() { + log.info("Starting testListSessions") + // Arrange + val path = "" + + // Act + log.debug("Listing sessions for path: {}", path) + val sessions = storage.listSessions(path) + + // Assert + log.debug("Retrieved sessions: {}", sessions) + assertNotNull(sessions) + assertTrue(sessions is List<*>) + log.info("Completed testListSessions successfully") + } + + @Test + fun testDeleteSession() { + log.info("Starting testDeleteSession") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + + // Act and Assert + try { + log.debug("Attempting to delete session {} for user {}", session.sessionId, user.email) + storage.deleteSession(user, session) + log.info("Session deleted successfully") + // If no exception is thrown, the test passes. + } catch (e: Exception) { + log.error("Failed to delete session: {}", e.message, e) + fail("Exception should not be thrown") } + log.info("Completed testDeleteSession successfully") + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/StorageInterfaceTest.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/StorageInterfaceTest.kt index 5465cb05..1b23da01 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/StorageInterfaceTest.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/StorageInterfaceTest.kt @@ -12,216 +12,216 @@ import java.io.File import java.util.* abstract class StorageInterfaceTest(val storage: StorageInterface) { - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(StorageInterfaceTest::class.java) + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(StorageInterfaceTest::class.java) + } + + + @Test + fun testGetJson() { + log.info("Starting testGetJson") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + val filename = "test.json" + + // Act + log.debug("Attempting to read JSON file: {}", filename) + val settingsFile = File(storage.getSessionDir(user, session), filename) + val result = if (!settingsFile.exists()) null else { + JsonUtil.objectMapper().readValue(settingsFile, Any::class.java) as Any } - - @Test - fun testGetJson() { - log.info("Starting testGetJson") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - val filename = "test.json" - - // Act - log.debug("Attempting to read JSON file: {}", filename) - val settingsFile = File(storage.getSessionDir(user, session), filename) - val result = if (!settingsFile.exists()) null else { - JsonUtil.objectMapper().readValue(settingsFile, Any::class.java) as Any - } - - // Assert - log.info("Asserting result is null for non-existing JSON file") - Assertions.assertNull(result, "Expected null result for non-existing JSON file") - log.info("testGetJson completed successfully") - } - - @Test - fun testGetMessages() { - log.info("Starting testGetMessages") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - - // Act - log.debug("Retrieving messages for user: {} and session: {}", user.email, session.sessionId) - val messages = storage.getMessages(user, session) - - // Assert - log.info("Asserting messages type is LinkedHashMap") - assertTrue(messages is LinkedHashMap<*, *>, "Expected LinkedHashMap type for messages") - log.info("testGetMessages completed successfully") - } - - @Test - fun testGetSessionDir() { - log.info("Starting testGetSessionDir") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - - // Act - log.debug("Getting session directory for user: {} and session: {}", user.email, session.sessionId) - val sessionDir = storage.getSessionDir(user, session) - - // Assert - log.info("Asserting session directory is of type File") - assertTrue(sessionDir is File, "Expected File type for session directory") - log.info("testGetSessionDir completed successfully") - } - - @Test - fun testGetSessionName() { - log.info("Starting testGetSessionName") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - - // Act - log.debug("Getting session name for user: {} and session: {}", user.email, session.sessionId) - val sessionName = storage.getSessionName(user, session) - - // Assert - log.info("Asserting session name is not null and is of type String") - Assertions.assertNotNull(sessionName) - assertTrue(sessionName is String) - log.info("testGetSessionName completed successfully") - } - - @Test - fun testGetSessionTime() { - log.info("Starting testGetSessionTime") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - log.debug("Updating message for user: {} and session: {}", user.email, session.sessionId) - storage.updateMessage(user, session, "msg001", "

Hello, World!

Hello, World!

") - - // Act - log.debug("Getting session time for user: {} and session: {}", user.email, session.sessionId) - val sessionTime = storage.getSessionTime(user, session) - - // Assert - log.info("Asserting session time is not null and is of type Date") - Assertions.assertNotNull(sessionTime) - assertTrue(sessionTime is Date) - log.info("testGetSessionTime completed successfully") - } - - @Test - fun testListSessions() { - log.info("Starting testListSessions") - // Arrange - val user = User(email = "test@example.com") - - // Act - log.debug("Listing sessions for user: {}", user.email) - val sessions = storage.listSessions(user, "") - - // Assert - log.info("Asserting sessions list is not null and is of type List") - Assertions.assertNotNull(sessions) - assertTrue(sessions is List<*>) - log.info("testListSessions completed successfully") - } - - @Test - fun testSetJson() { - log.info("Starting testSetJson") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - val filename = "settings.json" - val settings = mapOf("theme" to "dark") - - // Act - log.debug("Setting JSON for user: {} and session: {}", user.email, session.sessionId) - val result = storage.setJson(user, session, filename, settings) - - // Assert - log.info("Asserting JSON setting result is not null and matches input") - Assertions.assertNotNull(result) - assertEquals(settings, result) - log.info("testSetJson completed successfully") - } - - @Test - fun testUpdateMessage() { - log.info("Starting testUpdateMessage") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - val messageId = "msg001" - val value = "Hello, World!" - - // Act and Assert - try { - log.debug("Updating message for user: {} and session: {}", user.email, session.sessionId) - storage.updateMessage(user, session, messageId, value) - log.info("Message updated successfully") - // If no exception is thrown, the test passes. - } catch (e: Exception) { - log.error("Exception thrown while updating message", e) - Assertions.fail("Exception should not be thrown") - } - log.info("testUpdateMessage completed successfully") + // Assert + log.info("Asserting result is null for non-existing JSON file") + Assertions.assertNull(result, "Expected null result for non-existing JSON file") + log.info("testGetJson completed successfully") + } + + @Test + fun testGetMessages() { + log.info("Starting testGetMessages") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + + // Act + log.debug("Retrieving messages for user: {} and session: {}", user.email, session.sessionId) + val messages = storage.getMessages(user, session) + + // Assert + log.info("Asserting messages type is LinkedHashMap") + assertTrue(messages is LinkedHashMap<*, *>, "Expected LinkedHashMap type for messages") + log.info("testGetMessages completed successfully") + } + + @Test + fun testGetSessionDir() { + log.info("Starting testGetSessionDir") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + + // Act + log.debug("Getting session directory for user: {} and session: {}", user.email, session.sessionId) + val sessionDir = storage.getSessionDir(user, session) + + // Assert + log.info("Asserting session directory is of type File") + assertTrue(sessionDir is File, "Expected File type for session directory") + log.info("testGetSessionDir completed successfully") + } + + @Test + fun testGetSessionName() { + log.info("Starting testGetSessionName") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + + // Act + log.debug("Getting session name for user: {} and session: {}", user.email, session.sessionId) + val sessionName = storage.getSessionName(user, session) + + // Assert + log.info("Asserting session name is not null and is of type String") + Assertions.assertNotNull(sessionName) + assertTrue(sessionName is String) + log.info("testGetSessionName completed successfully") + } + + @Test + fun testGetSessionTime() { + log.info("Starting testGetSessionTime") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + log.debug("Updating message for user: {} and session: {}", user.email, session.sessionId) + storage.updateMessage(user, session, "msg001", "

Hello, World!

Hello, World!

") + + // Act + log.debug("Getting session time for user: {} and session: {}", user.email, session.sessionId) + val sessionTime = storage.getSessionTime(user, session) + + // Assert + log.info("Asserting session time is not null and is of type Date") + Assertions.assertNotNull(sessionTime) + assertTrue(sessionTime is Date) + log.info("testGetSessionTime completed successfully") + } + + @Test + fun testListSessions() { + log.info("Starting testListSessions") + // Arrange + val user = User(email = "test@example.com") + + // Act + log.debug("Listing sessions for user: {}", user.email) + val sessions = storage.listSessions(user, "") + + // Assert + log.info("Asserting sessions list is not null and is of type List") + Assertions.assertNotNull(sessions) + assertTrue(sessions is List<*>) + log.info("testListSessions completed successfully") + } + + @Test + fun testSetJson() { + log.info("Starting testSetJson") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + val filename = "settings.json" + val settings = mapOf("theme" to "dark") + + // Act + log.debug("Setting JSON for user: {} and session: {}", user.email, session.sessionId) + val result = storage.setJson(user, session, filename, settings) + + // Assert + log.info("Asserting JSON setting result is not null and matches input") + Assertions.assertNotNull(result) + assertEquals(settings, result) + log.info("testSetJson completed successfully") + } + + @Test + fun testUpdateMessage() { + log.info("Starting testUpdateMessage") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + val messageId = "msg001" + val value = "Hello, World!" + + // Act and Assert + try { + log.debug("Updating message for user: {} and session: {}", user.email, session.sessionId) + storage.updateMessage(user, session, messageId, value) + log.info("Message updated successfully") + // If no exception is thrown, the test passes. + } catch (e: Exception) { + log.error("Exception thrown while updating message", e) + Assertions.fail("Exception should not be thrown") } - - @Test - fun testListSessionsWithDir() { - log.info("Starting testListSessionsWithDir") - // Arrange - val directory = File(System.getProperty("user.dir")) // Example directory - - // Act - log.debug("Listing sessions for directory: {}", directory.absolutePath) - val sessionList = storage.listSessions(directory, "") - - // Assert - log.info("Asserting session list is not null and is of type List") - Assertions.assertNotNull(sessionList) - assertTrue(sessionList is List<*>) - log.info("testListSessionsWithDir completed successfully") - } - - @Test - fun testUserRoot() { - log.info("Starting testUserRoot") - // Arrange - val user = User(email = "test@example.com") - - // Act - log.debug("Getting user root for user: {}", user.email) - val userRoot = storage.userRoot(user) - - // Assert - log.info("Asserting user root is not null and is of type File") - Assertions.assertNotNull(userRoot) - assertTrue(userRoot is File) - log.info("testUserRoot completed successfully") - } - - @Test - fun testDeleteSession() { - log.info("Starting testDeleteSession") - // Arrange - val user = User(email = "test@example.com") - val session = Session("G-20230101-1234") - - // Act and Assert - try { - log.debug("Deleting session for user: {} and session: {}", user.email, session.sessionId) - storage.deleteSession(user, session) - log.info("Session deleted successfully") - // If no exception is thrown, the test passes. - } catch (e: Exception) { - log.error("Exception thrown while deleting session", e) - Assertions.fail("Exception should not be thrown") - } - log.info("testDeleteSession completed successfully") + log.info("testUpdateMessage completed successfully") + } + + @Test + fun testListSessionsWithDir() { + log.info("Starting testListSessionsWithDir") + // Arrange + val directory = File(System.getProperty("user.dir")) // Example directory + + // Act + log.debug("Listing sessions for directory: {}", directory.absolutePath) + val sessionList = storage.listSessions(directory, "") + + // Assert + log.info("Asserting session list is not null and is of type List") + Assertions.assertNotNull(sessionList) + assertTrue(sessionList is List<*>) + log.info("testListSessionsWithDir completed successfully") + } + + @Test + fun testUserRoot() { + log.info("Starting testUserRoot") + // Arrange + val user = User(email = "test@example.com") + + // Act + log.debug("Getting user root for user: {}", user.email) + val userRoot = storage.userRoot(user) + + // Assert + log.info("Asserting user root is not null and is of type File") + Assertions.assertNotNull(userRoot) + assertTrue(userRoot is File) + log.info("testUserRoot completed successfully") + } + + @Test + fun testDeleteSession() { + log.info("Starting testDeleteSession") + // Arrange + val user = User(email = "test@example.com") + val session = Session("G-20230101-1234") + + // Act and Assert + try { + log.debug("Deleting session for user: {} and session: {}", user.email, session.sessionId) + storage.deleteSession(user, session) + log.info("Session deleted successfully") + // If no exception is thrown, the test passes. + } catch (e: Exception) { + log.error("Exception thrown while deleting session", e) + Assertions.fail("Exception should not be thrown") } - // Continue writing tests for each method in StorageInterface... - // ... + log.info("testDeleteSession completed successfully") + } + // Continue writing tests for each method in StorageInterface... + // ... } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UsageTest.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UsageTest.kt index af5b3e29..61b4df8f 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UsageTest.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UsageTest.kt @@ -11,153 +11,153 @@ import org.junit.jupiter.api.Test import kotlin.random.Random abstract class UsageTest(private val impl: UsageInterface) { - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(UsageTest::class.java) - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(UsageTest::class.java) + } - private val testUser = User( - email = "test@example.com", - name = "Test User", - id = Random.nextInt().toString() - ) + private val testUser = User( + email = "test@example.com", + name = "Test User", + id = Random.nextInt().toString() + ) - @BeforeEach - fun setup() { - log.info("Setting up UsageTest: Clearing all usage data") - impl.clear() - } + @BeforeEach + fun setup() { + log.info("Setting up UsageTest: Clearing all usage data") + impl.clear() + } - @Test - fun `incrementUsage should increment usage for session`() { - log.debug("Starting test: incrementUsage should increment usage for session") - val model = OpenAIModels.GPT4oMini - val session = Session.newGlobalID() - val usage = ApiModel.Usage( - prompt_tokens = 10, - completion_tokens = 20, - cost = 30.0, - ) - log.info("Incrementing usage for session {} with model {}", session, model) - impl.incrementUsage(session, testUser, model, usage) - val usageSummary = impl.getSessionUsageSummary(session) - Assertions.assertEquals(usage, usageSummary[model]) - val userUsageSummary = impl.getUserUsageSummary(testUser) - Assertions.assertEquals(usage, userUsageSummary[model]) - } + @Test + fun `incrementUsage should increment usage for session`() { + log.debug("Starting test: incrementUsage should increment usage for session") + val model = OpenAIModels.GPT4oMini + val session = Session.newGlobalID() + val usage = ApiModel.Usage( + prompt_tokens = 10, + completion_tokens = 20, + cost = 30.0, + ) + log.info("Incrementing usage for session {} with model {}", session, model) + impl.incrementUsage(session, testUser, model, usage) + val usageSummary = impl.getSessionUsageSummary(session) + Assertions.assertEquals(usage, usageSummary[model]) + val userUsageSummary = impl.getUserUsageSummary(testUser) + Assertions.assertEquals(usage, userUsageSummary[model]) + } - @Test - fun `getUserUsageSummary should return correct usage summary`() { - log.debug("Starting test: getUserUsageSummary should return correct usage summary") - val model = OpenAIModels.GPT4oMini - val session = Session.newGlobalID() - val usage = ApiModel.Usage( - prompt_tokens = 15, - completion_tokens = 25, - cost = 35.0, - ) - log.info("Incrementing usage for user {} with model {}", testUser.email, model) - impl.incrementUsage(session, testUser, model, usage) - val userUsageSummary = impl.getUserUsageSummary(testUser) - Assertions.assertEquals(usage, userUsageSummary[model]) - } + @Test + fun `getUserUsageSummary should return correct usage summary`() { + log.debug("Starting test: getUserUsageSummary should return correct usage summary") + val model = OpenAIModels.GPT4oMini + val session = Session.newGlobalID() + val usage = ApiModel.Usage( + prompt_tokens = 15, + completion_tokens = 25, + cost = 35.0, + ) + log.info("Incrementing usage for user {} with model {}", testUser.email, model) + impl.incrementUsage(session, testUser, model, usage) + val userUsageSummary = impl.getUserUsageSummary(testUser) + Assertions.assertEquals(usage, userUsageSummary[model]) + } - @Test - fun `clear should reset all usage data`() { - log.debug("Starting test: clear should reset all usage data") - val model = OpenAIModels.GPT4oMini - val session = Session.newGlobalID() - val usage = ApiModel.Usage( - prompt_tokens = 20, - completion_tokens = 30, - cost = 40.0, - ) - log.info("Incrementing usage before clearing") - impl.incrementUsage(session, testUser, model, usage) - log.info("Clearing all usage data") - impl.clear() - val usageSummary = impl.getSessionUsageSummary(session) - Assertions.assertTrue(usageSummary.isEmpty()) - val userUsageSummary = impl.getUserUsageSummary(testUser) - Assertions.assertTrue(userUsageSummary.isEmpty()) - } + @Test + fun `clear should reset all usage data`() { + log.debug("Starting test: clear should reset all usage data") + val model = OpenAIModels.GPT4oMini + val session = Session.newGlobalID() + val usage = ApiModel.Usage( + prompt_tokens = 20, + completion_tokens = 30, + cost = 40.0, + ) + log.info("Incrementing usage before clearing") + impl.incrementUsage(session, testUser, model, usage) + log.info("Clearing all usage data") + impl.clear() + val usageSummary = impl.getSessionUsageSummary(session) + Assertions.assertTrue(usageSummary.isEmpty()) + val userUsageSummary = impl.getUserUsageSummary(testUser) + Assertions.assertTrue(userUsageSummary.isEmpty()) + } - @Test - fun `incrementUsage should handle multiple models correctly`() { - log.debug("Starting test: incrementUsage should handle multiple models correctly") - val model1 = OpenAIModels.GPT4oMini - val model2 = OpenAIModels.GPT4Turbo - val session = Session.newGlobalID() - val usage1 = ApiModel.Usage( - prompt_tokens = 10, - completion_tokens = 20, - cost = 30.0, - ) - val usage2 = ApiModel.Usage( - prompt_tokens = 5, - completion_tokens = 10, - cost = 15.0, - ) - log.info("Incrementing usage for model1 {} and model2 {}", model1, model2) - impl.incrementUsage(session, testUser, model1, usage1) - impl.incrementUsage(session, testUser, model2, usage2) - log.debug("Verifying usage summaries for session and user") - val usageSummary = impl.getSessionUsageSummary(session) - Assertions.assertEquals(usage1, usageSummary[model1]) - Assertions.assertEquals(usage2, usageSummary[model2]) - val userUsageSummary = impl.getUserUsageSummary(testUser) - Assertions.assertEquals(usage1, userUsageSummary[model1]) - Assertions.assertEquals(usage2, userUsageSummary[model2]) - } + @Test + fun `incrementUsage should handle multiple models correctly`() { + log.debug("Starting test: incrementUsage should handle multiple models correctly") + val model1 = OpenAIModels.GPT4oMini + val model2 = OpenAIModels.GPT4Turbo + val session = Session.newGlobalID() + val usage1 = ApiModel.Usage( + prompt_tokens = 10, + completion_tokens = 20, + cost = 30.0, + ) + val usage2 = ApiModel.Usage( + prompt_tokens = 5, + completion_tokens = 10, + cost = 15.0, + ) + log.info("Incrementing usage for model1 {} and model2 {}", model1, model2) + impl.incrementUsage(session, testUser, model1, usage1) + impl.incrementUsage(session, testUser, model2, usage2) + log.debug("Verifying usage summaries for session and user") + val usageSummary = impl.getSessionUsageSummary(session) + Assertions.assertEquals(usage1, usageSummary[model1]) + Assertions.assertEquals(usage2, usageSummary[model2]) + val userUsageSummary = impl.getUserUsageSummary(testUser) + Assertions.assertEquals(usage1, userUsageSummary[model1]) + Assertions.assertEquals(usage2, userUsageSummary[model2]) + } - @Test - fun `incrementUsage should accumulate usage for the same model`() { - log.debug("Starting test: incrementUsage should accumulate usage for the same model") - val model = OpenAIModels.GPT4oMini - val session = Session.newGlobalID() - val usage1 = ApiModel.Usage( - prompt_tokens = 10, - completion_tokens = 20, - cost = 30.0, - ) - val usage2 = ApiModel.Usage( - prompt_tokens = 5, - completion_tokens = 10, - cost = 15.0, - ) - log.info("Incrementing usage twice for model {}", model) - impl.incrementUsage(session, testUser, model, usage1) - impl.incrementUsage(session, testUser, model, usage2) - log.debug("Verifying accumulated usage") - val usageSummary = impl.getSessionUsageSummary(session) - val expectedUsage = ApiModel.Usage( - prompt_tokens = 15, - completion_tokens = 30, - cost = 45.0, - ) - Assertions.assertEquals(expectedUsage, usageSummary[model]) - val userUsageSummary = impl.getUserUsageSummary(testUser) - Assertions.assertEquals(expectedUsage, userUsageSummary[model]) - } + @Test + fun `incrementUsage should accumulate usage for the same model`() { + log.debug("Starting test: incrementUsage should accumulate usage for the same model") + val model = OpenAIModels.GPT4oMini + val session = Session.newGlobalID() + val usage1 = ApiModel.Usage( + prompt_tokens = 10, + completion_tokens = 20, + cost = 30.0, + ) + val usage2 = ApiModel.Usage( + prompt_tokens = 5, + completion_tokens = 10, + cost = 15.0, + ) + log.info("Incrementing usage twice for model {}", model) + impl.incrementUsage(session, testUser, model, usage1) + impl.incrementUsage(session, testUser, model, usage2) + log.debug("Verifying accumulated usage") + val usageSummary = impl.getSessionUsageSummary(session) + val expectedUsage = ApiModel.Usage( + prompt_tokens = 15, + completion_tokens = 30, + cost = 45.0, + ) + Assertions.assertEquals(expectedUsage, usageSummary[model]) + val userUsageSummary = impl.getUserUsageSummary(testUser) + Assertions.assertEquals(expectedUsage, userUsageSummary[model]) + } - @Test - fun `getSessionUsageSummary should return empty map for unknown session`() { - log.debug("Starting test: getSessionUsageSummary should return empty map for unknown session") - val session = Session.newGlobalID() - log.info("Retrieving usage summary for unknown session {}", session) - val usageSummary = impl.getSessionUsageSummary(session) - Assertions.assertTrue(usageSummary.isEmpty()) - } + @Test + fun `getSessionUsageSummary should return empty map for unknown session`() { + log.debug("Starting test: getSessionUsageSummary should return empty map for unknown session") + val session = Session.newGlobalID() + log.info("Retrieving usage summary for unknown session {}", session) + val usageSummary = impl.getSessionUsageSummary(session) + Assertions.assertTrue(usageSummary.isEmpty()) + } - @Test - fun `getUserUsageSummary should return empty map for unknown user`() { - log.debug("Starting test: getUserUsageSummary should return empty map for unknown user") - val unknownUser = User( - email = "unknown@example.com", - name = "Unknown User", - id = Random.nextInt().toString() - ) - log.info("Retrieving usage summary for unknown user {}", unknownUser.email) - val userUsageSummary = impl.getUserUsageSummary(unknownUser) - Assertions.assertTrue(userUsageSummary.isEmpty()) - } + @Test + fun `getUserUsageSummary should return empty map for unknown user`() { + log.debug("Starting test: getUserUsageSummary should return empty map for unknown user") + val unknownUser = User( + email = "unknown@example.com", + name = "Unknown User", + id = Random.nextInt().toString() + ) + log.info("Retrieving usage summary for unknown user {}", unknownUser.email) + val userUsageSummary = impl.getUserUsageSummary(unknownUser) + Assertions.assertTrue(userUsageSummary.isEmpty()) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UserSettingsTest.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UserSettingsTest.kt index 9227c22d..b59b5921 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UserSettingsTest.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/platform/test/UserSettingsTest.kt @@ -8,56 +8,56 @@ import org.junit.jupiter.api.Test import java.util.* abstract class UserSettingsTest(private val userSettings: UserSettingsInterface) { - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(UserSettingsTest::class.java) - } - - - @Test - fun `updateUserSettings should store custom settings for user`() { - log.info("Starting test: updateUserSettings should store custom settings for user") - val id = UUID.randomUUID().toString() - val testUser = User( - email = "$id@example.com", - name = "Test User", - id = id - ) - log.debug("Created test user with id: {}", id) - - val newSettings = UserSettingsInterface.UserSettings(apiKeys = mapOf(APIProvider.OpenAI to "12345")) - log.debug("Updating user settings with new API key") - userSettings.updateUserSettings(testUser, newSettings) - - val settings = userSettings.getUserSettings(testUser) - log.debug("Retrieved user settings after update") - - Assertions.assertEquals("12345", settings.apiKeys[APIProvider.OpenAI]) - log.info("Test completed: updateUserSettings successfully stored custom settings for user") - } - - @Test - fun `getUserSettings should return updated settings after updateUserSettings is called`() { - log.info("Starting test: getUserSettings should return updated settings after updateUserSettings is called") - val id = UUID.randomUUID().toString() - val testUser = User( - email = "$id@example.com", - name = "Test User", - id = id - ) - log.debug("Created test user with id: {}", id) - - val initialSettings = userSettings.getUserSettings(testUser) - log.debug("Retrieved initial user settings") - Assertions.assertEquals("", initialSettings.apiKeys[APIProvider.OpenAI]) - - val updatedSettings = UserSettingsInterface.UserSettings(apiKeys = mapOf(APIProvider.OpenAI to "67890")) - log.debug("Updating user settings with new API key") - userSettings.updateUserSettings(testUser, updatedSettings) - - val settingsAfterUpdate = userSettings.getUserSettings(testUser) - log.debug("Retrieved user settings after update") - - Assertions.assertEquals("67890", settingsAfterUpdate.apiKeys[APIProvider.OpenAI]) - log.info("Test completed: getUserSettings successfully returned updated settings after updateUserSettings was called") - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(UserSettingsTest::class.java) + } + + + @Test + fun `updateUserSettings should store custom settings for user`() { + log.info("Starting test: updateUserSettings should store custom settings for user") + val id = UUID.randomUUID().toString() + val testUser = User( + email = "$id@example.com", + name = "Test User", + id = id + ) + log.debug("Created test user with id: {}", id) + + val newSettings = UserSettingsInterface.UserSettings(apiKeys = mapOf(APIProvider.OpenAI to "12345")) + log.debug("Updating user settings with new API key") + userSettings.updateUserSettings(testUser, newSettings) + + val settings = userSettings.getUserSettings(testUser) + log.debug("Retrieved user settings after update") + + Assertions.assertEquals("12345", settings.apiKeys[APIProvider.OpenAI]) + log.info("Test completed: updateUserSettings successfully stored custom settings for user") + } + + @Test + fun `getUserSettings should return updated settings after updateUserSettings is called`() { + log.info("Starting test: getUserSettings should return updated settings after updateUserSettings is called") + val id = UUID.randomUUID().toString() + val testUser = User( + email = "$id@example.com", + name = "Test User", + id = id + ) + log.debug("Created test user with id: {}", id) + + val initialSettings = userSettings.getUserSettings(testUser) + log.debug("Retrieved initial user settings") + Assertions.assertEquals("", initialSettings.apiKeys[APIProvider.OpenAI]) + + val updatedSettings = UserSettingsInterface.UserSettings(apiKeys = mapOf(APIProvider.OpenAI to "67890")) + log.debug("Updating user settings with new API key") + userSettings.updateUserSettings(testUser, updatedSettings) + + val settingsAfterUpdate = userSettings.getUserSettings(testUser) + log.debug("Retrieved user settings after update") + + Assertions.assertEquals("67890", settingsAfterUpdate.apiKeys[APIProvider.OpenAI]) + log.info("Test completed: getUserSettings successfully returned updated settings after updateUserSettings was called") + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/ClasspathRelationships.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/ClasspathRelationships.kt index 3c8776ac..a0116bfe 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/ClasspathRelationships.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/ClasspathRelationships.kt @@ -6,368 +6,368 @@ import org.objectweb.asm.signature.SignatureVisitor import java.util.jar.JarFile object ClasspathRelationships { - sealed class Relation { - open val from_method: String = "" - open val to_method: String = "" + sealed class Relation { + open val from_method: String = "" + open val to_method: String = "" + } + + data object INHERITANCE : Relation() // When a class extends another class + data object INTERFACE_IMPLEMENTATION : Relation() // When a class implements an interface + data class FIELD_TYPE(override val from_method: String) : + Relation() // When a class has a field of another class type + + data class METHOD_PARAMETER( + override val from_method: String + ) : Relation() // When a class has a method that takes another class as a parameter + + data object METHOD_RETURN_TYPE : Relation() // When a class has a method that returns another class + data class LOCAL_VARIABLE(override val from_method: String) : + Relation() // When a method within a class declares a local variable of another class + + data class EXCEPTION_TYPE(override val from_method: String) : + Relation() // When a method declares that it throws an exception of another class + + data class ANNOTATION(override val from_method: String) : + Relation() // When a class, method, or field is annotated with another class (annotation) + + data class INSTANCE_CREATION(override val from_method: String) : + Relation() // When a class creates an instance of another class + + data class METHOD_REFERENCE( + override val from_method: String, + override val to_method: String + ) : Relation() // When a method references another class's method + + data class METHOD_SIGNATURE( + override val from_method: String, + override val to_method: String + ) : Relation() // When a method signature references another class + + data class FIELD_REFERENCE(override val from_method: String) : + Relation() // When a method references another class's field + + data class DYNAMIC_BINDING(override val from_method: String) : + Relation() // When a class uses dynamic binding (e.g., invoke dynamic) related to another class + + + class DependencyClassVisitor( + val dependencies: MutableMap> = mutableMapOf(), + var access: Int = 0, + var methods: MutableMap = mutableMapOf(), + ) : ClassVisitor(Opcodes.ASM9) { + + override fun visit( + version: Int, + access: Int, + name: String, + signature: String?, + superName: String?, + interfaces: Array? + ) { + this.access = access + // Add superclass dependency + superName?.let { addDep(it, INHERITANCE) } + // Add interface dependencies + interfaces?.forEach { addDep(it, INTERFACE_IMPLEMENTATION) } + visitSignature(name, signature) + super.visit(version, access, name, signature, superName, interfaces) } - data object INHERITANCE : Relation() // When a class extends another class - data object INTERFACE_IMPLEMENTATION : Relation() // When a class implements an interface - data class FIELD_TYPE(override val from_method: String) : - Relation() // When a class has a field of another class type - - data class METHOD_PARAMETER( - override val from_method: String - ) : Relation() // When a class has a method that takes another class as a parameter - - data object METHOD_RETURN_TYPE : Relation() // When a class has a method that returns another class - data class LOCAL_VARIABLE(override val from_method: String) : - Relation() // When a method within a class declares a local variable of another class - - data class EXCEPTION_TYPE(override val from_method: String) : - Relation() // When a method declares that it throws an exception of another class - - data class ANNOTATION(override val from_method: String) : - Relation() // When a class, method, or field is annotated with another class (annotation) - - data class INSTANCE_CREATION(override val from_method: String) : - Relation() // When a class creates an instance of another class - - data class METHOD_REFERENCE( - override val from_method: String, - override val to_method: String - ) : Relation() // When a method references another class's method - - data class METHOD_SIGNATURE( - override val from_method: String, - override val to_method: String - ) : Relation() // When a method signature references another class - - data class FIELD_REFERENCE(override val from_method: String) : - Relation() // When a method references another class's field - - data class DYNAMIC_BINDING(override val from_method: String) : - Relation() // When a class uses dynamic binding (e.g., invoke dynamic) related to another class - - - class DependencyClassVisitor( - val dependencies: MutableMap> = mutableMapOf(), - var access: Int = 0, - var methods: MutableMap = mutableMapOf(), - ) : ClassVisitor(Opcodes.ASM9) { - - override fun visit( - version: Int, - access: Int, - name: String, - signature: String?, - superName: String?, - interfaces: Array? - ) { - this.access = access - // Add superclass dependency - superName?.let { addDep(it, INHERITANCE) } - // Add interface dependencies - interfaces?.forEach { addDep(it, INTERFACE_IMPLEMENTATION) } - visitSignature(name, signature) - super.visit(version, access, name, signature, superName, interfaces) - } - - override fun visitField( - access: Int, - name: String?, - desc: String?, - signature: String?, - value: Any? - ): FieldVisitor { - visitSignature(name, signature) - // Add field type dependency - addType(desc, FIELD_TYPE(from_method = "")) - return DependencyFieldVisitor(dependencies) - } + override fun visitField( + access: Int, + name: String?, + desc: String?, + signature: String?, + value: Any? + ): FieldVisitor { + visitSignature(name, signature) + // Add field type dependency + addType(desc, FIELD_TYPE(from_method = "")) + return DependencyFieldVisitor(dependencies) + } - override fun visitMethod( - access: Int, - name: String?, - desc: String?, - signature: String?, - exceptions: Array? - ): MethodVisitor { - visitSignature(name, signature) - // Add method return type and parameter types dependencies - addMethodDescriptor(desc, METHOD_PARAMETER(from_method = name ?: ""), METHOD_RETURN_TYPE) - // Add exception types dependencies - exceptions?.forEach { addDep(it, EXCEPTION_TYPE(from_method = name ?: "")) } - val methodVisitor = DependencyMethodVisitor(name ?: "", dependencies) - methods[methodVisitor.name] = methodVisitor - return methodVisitor - } + override fun visitMethod( + access: Int, + name: String?, + desc: String?, + signature: String?, + exceptions: Array? + ): MethodVisitor { + visitSignature(name, signature) + // Add method return type and parameter types dependencies + addMethodDescriptor(desc, METHOD_PARAMETER(from_method = name ?: ""), METHOD_RETURN_TYPE) + // Add exception types dependencies + exceptions?.forEach { addDep(it, EXCEPTION_TYPE(from_method = name ?: "")) } + val methodVisitor = DependencyMethodVisitor(name ?: "", dependencies) + methods[methodVisitor.name] = methodVisitor + return methodVisitor + } - private fun visitSignature(name: String?, signature: String?) { - // Check if the name indicates an inner class or property accessor - if (name?.contains("$") == true) { - // NOTE: This isn't a typically required dependency - // addDep(name.substringBefore("$"), OUTER_CLASS) - } - if (name?.contains("baseClassLoader") == true) { - signature?.let { - val signatureReader = SignatureReader(it) - signatureReader.accept(object : SignatureVisitor(Opcodes.ASM9) { - override fun visitClassType(name: String?) { - name?.let { addDep(it, METHOD_PARAMETER(from_method = "")) } - } - }) - } - return - } - signature?.let { - val signatureReader = SignatureReader(it) - signatureReader.accept(object : SignatureVisitor(Opcodes.ASM9) { - override fun visitClassType(name: String?) { - name?.let { addDep(it, METHOD_PARAMETER(from_method = "")) } - } - }) + private fun visitSignature(name: String?, signature: String?) { + // Check if the name indicates an inner class or property accessor + if (name?.contains("$") == true) { + // NOTE: This isn't a typically required dependency + // addDep(name.substringBefore("$"), OUTER_CLASS) + } + if (name?.contains("baseClassLoader") == true) { + signature?.let { + val signatureReader = SignatureReader(it) + signatureReader.accept(object : SignatureVisitor(Opcodes.ASM9) { + override fun visitClassType(name: String?) { + name?.let { addDep(it, METHOD_PARAMETER(from_method = "")) } } + }) } + return + } + signature?.let { + val signatureReader = SignatureReader(it) + signatureReader.accept(object : SignatureVisitor(Opcodes.ASM9) { + override fun visitClassType(name: String?) { + name?.let { addDep(it, METHOD_PARAMETER(from_method = "")) } + } + }) + } + } - override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { - // Add annotation type dependency - addType(descriptor, ANNOTATION(from_method = "")) - return super.visitAnnotation(descriptor, visible) - } + override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { + // Add annotation type dependency + addType(descriptor, ANNOTATION(from_method = "")) + return super.visitAnnotation(descriptor, visible) + } - private fun addDep(internalName: String, relationType: Relation) { - val typeName = internalName.replace('/', '.') - dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) - } + private fun addDep(internalName: String, relationType: Relation) { + val typeName = internalName.replace('/', '.') + dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) + } - private fun addType(type: String?, relationType: Relation) { - type?.let { - val typeName = Type.getType(it).className - addDep(typeName, relationType) - } - } + private fun addType(type: String?, relationType: Relation) { + type?.let { + val typeName = Type.getType(it).className + addDep(typeName, relationType) + } + } - private fun addMethodDescriptor( - descriptor: String?, - paramRelationType: Relation, - returnRelationType: Relation - ) { - descriptor?.let { - val methodType = Type.getMethodType(it) - // Add return type dependency - addType(methodType.returnType.descriptor, returnRelationType) - // Add parameter types dependencies - methodType.argumentTypes.forEach { argType -> - addType(argType.descriptor, paramRelationType) - } - } + private fun addMethodDescriptor( + descriptor: String?, + paramRelationType: Relation, + returnRelationType: Relation + ) { + descriptor?.let { + val methodType = Type.getMethodType(it) + // Add return type dependency + addType(methodType.returnType.descriptor, returnRelationType) + // Add parameter types dependencies + methodType.argumentTypes.forEach { argType -> + addType(argType.descriptor, paramRelationType) } - + } } - class DependencyFieldVisitor( - val dependencies: MutableMap> - ) : FieldVisitor(Opcodes.ASM9) { + } - override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { - descriptor?.let { addType(it, ANNOTATION(from_method = "")) } - return super.visitAnnotation(descriptor, visible) - } + class DependencyFieldVisitor( + val dependencies: MutableMap> + ) : FieldVisitor(Opcodes.ASM9) { - override fun visitAttribute(attribute: Attribute?) { - super.visitAttribute(attribute) - } - - override fun visitTypeAnnotation( - typeRef: Int, - typePath: TypePath?, - descriptor: String?, - visible: Boolean - ): AnnotationVisitor? { - descriptor?.let { addType(it, ANNOTATION(from_method = "")) } - return super.visitTypeAnnotation(typeRef, typePath, descriptor, visible) - } + override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { + descriptor?.let { addType(it, ANNOTATION(from_method = "")) } + return super.visitAnnotation(descriptor, visible) + } - private fun addDep(internalName: String, relationType: Relation) { - val typeName = internalName.replace('/', '.') - dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) - } + override fun visitAttribute(attribute: Attribute?) { + super.visitAttribute(attribute) + } - private fun addType(type: String, relationType: Relation) { - addDep(getTypeName(type) ?: return, relationType) - } + override fun visitTypeAnnotation( + typeRef: Int, + typePath: TypePath?, + descriptor: String?, + visible: Boolean + ): AnnotationVisitor? { + descriptor?.let { addType(it, ANNOTATION(from_method = "")) } + return super.visitTypeAnnotation(typeRef, typePath, descriptor, visible) + } + private fun addDep(internalName: String, relationType: Relation) { + val typeName = internalName.replace('/', '.') + dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) } - class DependencyMethodVisitor( - val name: String, - val dependencies: MutableMap>, - var access: Int = 0, - ) : MethodVisitor(Opcodes.ASM9) { - - - override fun visitMethodInsn( - opcode: Int, - owner: String?, - name: String?, - descriptor: String?, - isInterface: Boolean - ) { - access = opcode - // Add method reference dependency - owner?.let { addDep(it, METHOD_REFERENCE(from_method = this.name, to_method = name ?: "")) } - // Add method descriptor dependencies (for parameter and return types) - descriptor?.let { - addMethodDescriptor( - it, - METHOD_SIGNATURE(from_method = this.name, to_method = name ?: "") - ) - } - super.visitMethodInsn(opcode, owner, name, descriptor, isInterface) - } + private fun addType(type: String, relationType: Relation) { + addDep(getTypeName(type) ?: return, relationType) + } - override fun visitParameter(name: String?, access: Int) { - // Add method parameter type dependency - name?.let { addType(it, METHOD_PARAMETER(from_method = this.name)) } - super.visitParameter(name, access) - } + } + + class DependencyMethodVisitor( + val name: String, + val dependencies: MutableMap>, + var access: Int = 0, + ) : MethodVisitor(Opcodes.ASM9) { + + + override fun visitMethodInsn( + opcode: Int, + owner: String?, + name: String?, + descriptor: String?, + isInterface: Boolean + ) { + access = opcode + // Add method reference dependency + owner?.let { addDep(it, METHOD_REFERENCE(from_method = this.name, to_method = name ?: "")) } + // Add method descriptor dependencies (for parameter and return types) + descriptor?.let { + addMethodDescriptor( + it, + METHOD_SIGNATURE(from_method = this.name, to_method = name ?: "") + ) + } + super.visitMethodInsn(opcode, owner, name, descriptor, isInterface) + } + override fun visitParameter(name: String?, access: Int) { + // Add method parameter type dependency + name?.let { addType(it, METHOD_PARAMETER(from_method = this.name)) } + super.visitParameter(name, access) + } - override fun visitFieldInsn(opcode: Int, owner: String?, name: String?, descriptor: String?) { - // Add field reference dependency - owner?.let { addDep(it, FIELD_REFERENCE(from_method = this.name)) } - // Add field type dependency - descriptor?.let { addType(it, FIELD_TYPE(from_method = this.name)) } - super.visitFieldInsn(opcode, owner, name, descriptor) - } - override fun visitTypeInsn(opcode: Int, type: String?) { - // Add instance creation or local variable dependency based on opcode - type?.let { - val dependencyType = when (opcode) { - Opcodes.NEW -> INSTANCE_CREATION(from_method = this.name) - else -> LOCAL_VARIABLE(from_method = this.name) - } - addType(it, dependencyType) - } - super.visitTypeInsn(opcode, type) - } + override fun visitFieldInsn(opcode: Int, owner: String?, name: String?, descriptor: String?) { + // Add field reference dependency + owner?.let { addDep(it, FIELD_REFERENCE(from_method = this.name)) } + // Add field type dependency + descriptor?.let { addType(it, FIELD_TYPE(from_method = this.name)) } + super.visitFieldInsn(opcode, owner, name, descriptor) + } - override fun visitLdcInsn(value: Any?) { - // Add class literal dependency - if (value is Type) { - addType(value.descriptor, LOCAL_VARIABLE(from_method = this.name)) - } - super.visitLdcInsn(value) + override fun visitTypeInsn(opcode: Int, type: String?) { + // Add instance creation or local variable dependency based on opcode + type?.let { + val dependencyType = when (opcode) { + Opcodes.NEW -> INSTANCE_CREATION(from_method = this.name) + else -> LOCAL_VARIABLE(from_method = this.name) } + addType(it, dependencyType) + } + super.visitTypeInsn(opcode, type) + } - override fun visitMultiANewArrayInsn(descriptor: String?, numDimensions: Int) { - // Add local variable dependency for multi-dimensional arrays - descriptor?.let { addType(it, LOCAL_VARIABLE(from_method = this.name)) } - super.visitMultiANewArrayInsn(descriptor, numDimensions) - } + override fun visitLdcInsn(value: Any?) { + // Add class literal dependency + if (value is Type) { + addType(value.descriptor, LOCAL_VARIABLE(from_method = this.name)) + } + super.visitLdcInsn(value) + } - override fun visitInvokeDynamicInsn( - name: String?, - descriptor: String?, - bootstrapMethodHandle: Handle?, - vararg bootstrapMethodArguments: Any? - ) { - // Add dynamic binding dependency - descriptor?.let { addMethodDescriptor(it, DYNAMIC_BINDING(from_method = this.name)) } - super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, *bootstrapMethodArguments) - } + override fun visitMultiANewArrayInsn(descriptor: String?, numDimensions: Int) { + // Add local variable dependency for multi-dimensional arrays + descriptor?.let { addType(it, LOCAL_VARIABLE(from_method = this.name)) } + super.visitMultiANewArrayInsn(descriptor, numDimensions) + } - override fun visitLocalVariable( - name: String?, - descriptor: String?, - signature: String?, - start: Label?, - end: Label?, - index: Int - ) { - // Add local variable dependency - descriptor?.let { addType(it, LOCAL_VARIABLE(from_method = this.name)) } - super.visitLocalVariable(name, descriptor, signature, start, end, index) - } + override fun visitInvokeDynamicInsn( + name: String?, + descriptor: String?, + bootstrapMethodHandle: Handle?, + vararg bootstrapMethodArguments: Any? + ) { + // Add dynamic binding dependency + descriptor?.let { addMethodDescriptor(it, DYNAMIC_BINDING(from_method = this.name)) } + super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, *bootstrapMethodArguments) + } - override fun visitTryCatchBlock(start: Label?, end: Label?, handler: Label?, type: String?) { - // Add exception type dependency - type?.let { addType(it, EXCEPTION_TYPE(from_method = this.name)) } - super.visitTryCatchBlock(start, end, handler, type) - } + override fun visitLocalVariable( + name: String?, + descriptor: String?, + signature: String?, + start: Label?, + end: Label?, + index: Int + ) { + // Add local variable dependency + descriptor?.let { addType(it, LOCAL_VARIABLE(from_method = this.name)) } + super.visitLocalVariable(name, descriptor, signature, start, end, index) + } - override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { - // Add annotation type dependency - descriptor?.let { addType(it, ANNOTATION(from_method = this.name)) } - return super.visitAnnotation(descriptor, visible) - } + override fun visitTryCatchBlock(start: Label?, end: Label?, handler: Label?, type: String?) { + // Add exception type dependency + type?.let { addType(it, EXCEPTION_TYPE(from_method = this.name)) } + super.visitTryCatchBlock(start, end, handler, type) + } - private fun addDep(internalName: String, relationType: Relation) { - val typeName = internalName.replace('/', '.') - dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) - } + override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor? { + // Add annotation type dependency + descriptor?.let { addType(it, ANNOTATION(from_method = this.name)) } + return super.visitAnnotation(descriptor, visible) + } - private fun addType(type: String, relationType: Relation): Unit { - addDep(getTypeName(type) ?: return, relationType) - } + private fun addDep(internalName: String, relationType: Relation) { + val typeName = internalName.replace('/', '.') + dependencies.getOrPut(typeName) { mutableSetOf() }.add(relationType) + } - private fun addMethodDescriptor( - descriptor: String, - relationType: Relation - ) { - val methodType = Type.getMethodType(descriptor) - // Add return type dependency - addType(methodType.returnType.descriptor, relationType) - // Add parameter types dependencies - methodType.argumentTypes.forEach { addType(it.descriptor, relationType) } - } + private fun addType(type: String, relationType: Relation): Unit { + addDep(getTypeName(type) ?: return, relationType) } - private fun getTypeName(type: String): String? = try { - val name = when { - // For array types, get the class name - type.startsWith("L") && type.endsWith(";") -> getTypeName(type.substring(1, type.length - 1)) - // Handle the case where the descriptor appears to be a plain class name - !type.startsWith("[") && !type.startsWith("L") && !type.endsWith(";") -> Type.getObjectType(type.classToPath).className - // Handle the case where the descriptor is missing 'L' and ';' - type.contains("/") && !type.startsWith("L") && !type.endsWith(";") -> Type.getObjectType(type).className - // For primitive types, use the descriptor directly - type.length == 1 && "BCDFIJSZ".contains(type[0]) -> type - type.endsWith("$") -> type.substring(0, type.length - 1) - else -> Type.getType(type).className - } - name - } catch (e: Exception) { - println("Error adding type: $type (${e.message})") - null + private fun addMethodDescriptor( + descriptor: String, + relationType: Relation + ) { + val methodType = Type.getMethodType(descriptor) + // Add return type dependency + addType(methodType.returnType.descriptor, relationType) + // Add parameter types dependencies + methodType.argumentTypes.forEach { addType(it.descriptor, relationType) } } + } + + private fun getTypeName(type: String): String? = try { + val name = when { + // For array types, get the class name + type.startsWith("L") && type.endsWith(";") -> getTypeName(type.substring(1, type.length - 1)) + // Handle the case where the descriptor appears to be a plain class name + !type.startsWith("[") && !type.startsWith("L") && !type.endsWith(";") -> Type.getObjectType(type.classToPath).className + // Handle the case where the descriptor is missing 'L' and ';' + type.contains("/") && !type.startsWith("L") && !type.endsWith(";") -> Type.getObjectType(type).className + // For primitive types, use the descriptor directly + type.length == 1 && "BCDFIJSZ".contains(type[0]) -> type + type.endsWith("$") -> type.substring(0, type.length - 1) + else -> Type.getType(type).className + } + name + } catch (e: Exception) { + println("Error adding type: $type (${e.message})") + null + } - val String.classToPath - get() = removeSuffix(".class").replace('.', '/') + val String.classToPath + get() = removeSuffix(".class").replace('.', '/') - data class Reference( - val from: String, - val to: String, - val relation: Relation - ) + data class Reference( + val from: String, + val to: String, + val relation: Relation + ) - fun readJarClasses(jarPath: String) = JarFile(jarPath).use { jarFile -> - jarFile.entries().asSequence().filter { it.name.endsWith(".class") }.map { entry -> - val className = entry.name.replace('/', '.').removeSuffix(".class") - className to jarFile.getInputStream(entry)?.readBytes() - }.toMap() - } + fun readJarClasses(jarPath: String) = JarFile(jarPath).use { jarFile -> + jarFile.entries().asSequence().filter { it.name.endsWith(".class") }.map { entry -> + val className = entry.name.replace('/', '.').removeSuffix(".class") + className to jarFile.getInputStream(entry)?.readBytes() + }.toMap() + } - fun readJarFiles(jarPath: String) = JarFile(jarPath).use { jarFile -> - jarFile.entries().asSequence().map { it.name }.toList().toTypedArray() - } + fun readJarFiles(jarPath: String) = JarFile(jarPath).use { jarFile -> + jarFile.entries().asSequence().map { it.name }.toList().toTypedArray() + } - fun downstreamMap(dependencies: List) = - dependencies.groupBy { it.from } + fun downstreamMap(dependencies: List) = + dependencies.groupBy { it.from } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/CommonRoot.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/CommonRoot.kt index c2ef214c..4c9d854e 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/CommonRoot.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/CommonRoot.kt @@ -4,27 +4,27 @@ import java.io.File import java.nio.file.Path fun Array.commonRoot(): Path = when { - isEmpty() -> error("No paths") - size == 1 && first().toFile().isFile -> first().parent - size == 1 -> first() - else -> this.reduce { a, b -> - when { - a.startsWith(b) -> b - b.startsWith(a) -> a - else -> when (val common = a.commonPrefixWith(b)) { - a -> a - b -> b - else -> common.toAbsolutePath() - } - } + isEmpty() -> error("No paths") + size == 1 && first().toFile().isFile -> first().parent + size == 1 -> first() + else -> this.reduce { a, b -> + when { + a.startsWith(b) -> b + b.startsWith(a) -> a + else -> when (val common = a.commonPrefixWith(b)) { + a -> a + b -> b + else -> common.toAbsolutePath() + } } + } } private fun Path.commonPrefixWith(b: Path): Path { - val a = this - val aParts = a.toAbsolutePath().toString().split(File.separator) - val bParts = b.toAbsolutePath().toString().split(File.separator) - val common = aParts.zip(bParts).takeWhile { (a, b) -> a == b }.map { it.first } - return File(File.separator + common.joinToString(File.separator)).toPath() + val a = this + val aParts = a.toAbsolutePath().toString().split(File.separator) + val bParts = b.toAbsolutePath().toString().split(File.separator) + val common = aParts.zip(bParts).takeWhile { (a, b) -> a == b }.map { it.first } + return File(File.separator + common.joinToString(File.separator)).toPath() } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Ears.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Ears.kt index b649c3dc..6f6dd4da 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Ears.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Ears.kt @@ -18,111 +18,111 @@ import java.util.concurrent.atomic.AtomicInteger */ @Suppress("unused") open class Ears( - val api: ChatClient, - private val secondsPerAudioPacket: Double = 0.25, + val api: ChatClient, + private val secondsPerAudioPacket: Double = 0.25, ) { - interface CommandRecognizer { - fun listenForCommand(inputBuffer: DictationBuffer): CommandRecognized + interface CommandRecognizer { + fun listenForCommand(inputBuffer: DictationBuffer): CommandRecognized - data class DictationBuffer( - val text: String? = null, - ) + data class DictationBuffer( + val text: String? = null, + ) - data class CommandRecognized( - val commandRecognized: Boolean? = null, - val command: String? = null, - ) - } + data class CommandRecognized( + val commandRecognized: Boolean? = null, + val command: String? = null, + ) + } - open val commandRecognizer = ChatProxy( - clazz = CommandRecognizer::class.java, - api = api, - model = OpenAIModels.GPT4oMini, - ).create() + open val commandRecognizer = ChatProxy( + clazz = CommandRecognizer::class.java, + api = api, + model = OpenAIModels.GPT4oMini, + ).create() - open fun timeout(ms: Long): () -> Boolean { - val endTime = System.currentTimeMillis() + ms - return { System.currentTimeMillis() < endTime } - } + open fun timeout(ms: Long): () -> Boolean { + val endTime = System.currentTimeMillis() + ms + return { System.currentTimeMillis() < endTime } + } - open fun listenForCommand( - client: OpenAIClient, - minCaptureMs: Int = 1000, - continueFn: () -> Boolean = timeout(120, TimeUnit.SECONDS), - rawBuffer: Deque = startAudioCapture(continueFn), - commandHandler: (command: String) -> Unit, + open fun listenForCommand( + client: OpenAIClient, + minCaptureMs: Int = 1000, + continueFn: () -> Boolean = timeout(120, TimeUnit.SECONDS), + rawBuffer: Deque = startAudioCapture(continueFn), + commandHandler: (command: String) -> Unit, + ) { + val buffer = StringBuilder() + val commandsProcessed = AtomicInteger(0) + var lastCommandCheckTime = System.currentTimeMillis() + startDictationListener( + client, + continueFn = { continueFn() && 0 == commandsProcessed.get() }, + rawBuffer = rawBuffer ) { - val buffer = StringBuilder() - val commandsProcessed = AtomicInteger(0) - var lastCommandCheckTime = System.currentTimeMillis() - startDictationListener( - client, - continueFn = { continueFn() && 0 == commandsProcessed.get() }, - rawBuffer = rawBuffer - ) { - buffer.append(it) - if (System.currentTimeMillis() - lastCommandCheckTime > minCaptureMs) { - log.info("Checking for command: $buffer") - lastCommandCheckTime = System.currentTimeMillis() - val inputBuffer = CommandRecognizer.DictationBuffer(buffer.toString()) - commandRecognizer.listenForCommand(inputBuffer).let { result -> - if (result.commandRecognized == true) { - log.info("Command recognized: ${result.command}") - commandsProcessed.incrementAndGet() - buffer.clear() - if (null != result.command) commandHandler(result.command) - } - } - } + buffer.append(it) + if (System.currentTimeMillis() - lastCommandCheckTime > minCaptureMs) { + log.info("Checking for command: $buffer") + lastCommandCheckTime = System.currentTimeMillis() + val inputBuffer = CommandRecognizer.DictationBuffer(buffer.toString()) + commandRecognizer.listenForCommand(inputBuffer).let { result -> + if (result.commandRecognized == true) { + log.info("Command recognized: ${result.command}") + commandsProcessed.incrementAndGet() + buffer.clear() + if (null != result.command) commandHandler(result.command) + } } + } } + } - open fun startDictationListener( - client: OpenAIClient, - continueFn: () -> Boolean = timeout(60, TimeUnit.SECONDS), - rawBuffer: Deque = startAudioCapture(continueFn), - textAppend: (String) -> Unit, - ) { - val wavBuffer = ConcurrentLinkedDeque() - Thread({ - try { - LookbackLoudnessWindowBuffer(rawBuffer, wavBuffer, continueFn).run() - } catch (e: Throwable) { - e.printStackTrace() - } - }, "dictation-audio-processor").start() - val dictationProcessor = TranscriptionProcessor(client, wavBuffer, continueFn) { - log.info("Dictation: $it") - textAppend(it) - } - val dictationThread = Thread({ - try { - dictationProcessor.run() - } catch (e: Throwable) { - e.printStackTrace() - } - }, "dictation-api-processor") - dictationThread.start() - dictationThread.join() + open fun startDictationListener( + client: OpenAIClient, + continueFn: () -> Boolean = timeout(60, TimeUnit.SECONDS), + rawBuffer: Deque = startAudioCapture(continueFn), + textAppend: (String) -> Unit, + ) { + val wavBuffer = ConcurrentLinkedDeque() + Thread({ + try { + LookbackLoudnessWindowBuffer(rawBuffer, wavBuffer, continueFn).run() + } catch (e: Throwable) { + e.printStackTrace() + } + }, "dictation-audio-processor").start() + val dictationProcessor = TranscriptionProcessor(client, wavBuffer, continueFn) { + log.info("Dictation: $it") + textAppend(it) } + val dictationThread = Thread({ + try { + dictationProcessor.run() + } catch (e: Throwable) { + e.printStackTrace() + } + }, "dictation-api-processor") + dictationThread.start() + dictationThread.join() + } - open fun startAudioCapture(continueFn: () -> Boolean): ConcurrentLinkedDeque { - val rawBuffer = ConcurrentLinkedDeque() - Thread({ - try { - AudioRecorder(rawBuffer, secondsPerAudioPacket, continueFn).run() - } catch (e: Throwable) { - e.printStackTrace() - } - }, "dication-audio-recorder").start() - return rawBuffer - } + open fun startAudioCapture(continueFn: () -> Boolean): ConcurrentLinkedDeque { + val rawBuffer = ConcurrentLinkedDeque() + Thread({ + try { + AudioRecorder(rawBuffer, secondsPerAudioPacket, continueFn).run() + } catch (e: Throwable) { + e.printStackTrace() + } + }, "dication-audio-recorder").start() + return rawBuffer + } - private fun timeout(count: Long, timeUnit: TimeUnit): () -> Boolean = timeout(timeUnit.toMillis(count)) + private fun timeout(count: Long, timeUnit: TimeUnit): () -> Boolean = timeout(timeUnit.toMillis(count)) - companion object { - private val log = LoggerFactory.getLogger(Ears::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(Ears::class.java) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/FunctionWrapper.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/FunctionWrapper.kt index f5b75c94..6763943e 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/FunctionWrapper.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/FunctionWrapper.kt @@ -14,170 +14,170 @@ import java.util.concurrent.atomic.AtomicInteger import javax.imageio.ImageIO class FunctionWrapper(val inner: FunctionInterceptor) : FunctionInterceptor { - inline fun wrap(crossinline fn: () -> T) = inner.intercept(T::class.java) { fn() } - inline fun

wrap(p: P, crossinline fn: (P) -> T) = - inner.intercept(p, T::class.java) { fn(it) } - - inline fun wrap(p1: P1, p2: P2, crossinline fn: (P1, P2) -> T) = - inner.intercept(p1, p2, T::class.java) { p1, p2 -> fn(p1, p2) } - - inline fun wrap( - p1: P1, - p2: P2, - p3: P3, - crossinline fn: (P1, P2, P3) -> T - ) = - inner.intercept(p1, p2, p3, T::class.java) { p1, p2, p3 -> fn(p1, p2, p3) } - - inline fun wrap( - p1: P1, - p2: P2, - p3: P3, - p4: P4, - crossinline fn: (P1, P2, P3, P4) -> T - ) = - inner.intercept(p1, p2, p3, p4, T::class.java) { p1, p2, p3, p4 -> fn(p1, p2, p3, p4) } - - override fun intercept(returnClazz: Class, fn: () -> T) = inner.intercept(returnClazz, fn) - - override fun

intercept(params: P, returnClazz: Class, fn: (P) -> T) = - inner.intercept(params, returnClazz, fn) - - override fun intercept( - p1: P1, - p2: P2, - returnClazz: Class, - fn: (P1, P2) -> T - ) = inner.intercept(p1, p2, returnClazz, fn) + inline fun wrap(crossinline fn: () -> T) = inner.intercept(T::class.java) { fn() } + inline fun

wrap(p: P, crossinline fn: (P) -> T) = + inner.intercept(p, T::class.java) { fn(it) } + + inline fun wrap(p1: P1, p2: P2, crossinline fn: (P1, P2) -> T) = + inner.intercept(p1, p2, T::class.java) { p1, p2 -> fn(p1, p2) } + + inline fun wrap( + p1: P1, + p2: P2, + p3: P3, + crossinline fn: (P1, P2, P3) -> T + ) = + inner.intercept(p1, p2, p3, T::class.java) { p1, p2, p3 -> fn(p1, p2, p3) } + + inline fun wrap( + p1: P1, + p2: P2, + p3: P3, + p4: P4, + crossinline fn: (P1, P2, P3, P4) -> T + ) = + inner.intercept(p1, p2, p3, p4, T::class.java) { p1, p2, p3, p4 -> fn(p1, p2, p3, p4) } + + override fun intercept(returnClazz: Class, fn: () -> T) = inner.intercept(returnClazz, fn) + + override fun

intercept(params: P, returnClazz: Class, fn: (P) -> T) = + inner.intercept(params, returnClazz, fn) + + override fun intercept( + p1: P1, + p2: P2, + returnClazz: Class, + fn: (P1, P2) -> T + ) = inner.intercept(p1, p2, returnClazz, fn) } interface FunctionInterceptor { - fun intercept(returnClazz: Class, fn: () -> T) = fn() - fun

intercept(params: P, returnClazz: Class, fn: (P) -> T) = fn(params) - fun intercept(p1: P1, p2: P2, returnClazz: Class, fn: (P1, P2) -> T) = - intercept(listOf(p1, p2), returnClazz) { - @Suppress("UNCHECKED_CAST") - fn(it[0] as P1, it[1] as P2) - } - - fun intercept( - p1: P1, - p2: P2, - p3: P3, - returnClazz: Class, - fn: (P1, P2, P3) -> T - ) = - intercept(listOf(p1, p2, p3), returnClazz) { - @Suppress("UNCHECKED_CAST") - fn(it[0] as P1, it[1] as P2, it[2] as P3) - } - - fun intercept( - p1: P1, - p2: P2, - p3: P3, - p4: P4, - returnClazz: Class, - fn: (P1, P2, P3, P4) -> T - ) = - intercept(listOf(p1, p2, p3, p4), returnClazz) { - @Suppress("UNCHECKED_CAST") - fn(it[0] as P1, it[1] as P2, it[2] as P3, it[3] as P4) - } + fun intercept(returnClazz: Class, fn: () -> T) = fn() + fun

intercept(params: P, returnClazz: Class, fn: (P) -> T) = fn(params) + fun intercept(p1: P1, p2: P2, returnClazz: Class, fn: (P1, P2) -> T) = + intercept(listOf(p1, p2), returnClazz) { + @Suppress("UNCHECKED_CAST") + fn(it[0] as P1, it[1] as P2) + } + + fun intercept( + p1: P1, + p2: P2, + p3: P3, + returnClazz: Class, + fn: (P1, P2, P3) -> T + ) = + intercept(listOf(p1, p2, p3), returnClazz) { + @Suppress("UNCHECKED_CAST") + fn(it[0] as P1, it[1] as P2, it[2] as P3) + } + + fun intercept( + p1: P1, + p2: P2, + p3: P3, + p4: P4, + returnClazz: Class, + fn: (P1, P2, P3, P4) -> T + ) = + intercept(listOf(p1, p2, p3, p4), returnClazz) { + @Suppress("UNCHECKED_CAST") + fn(it[0] as P1, it[1] as P2, it[2] as P3, it[3] as P4) + } } class NoopFunctionInterceptor : FunctionInterceptor { - override fun intercept(returnClazz: Class, fn: () -> T) = fn() - override fun

intercept(params: P, returnClazz: Class, fn: (P) -> T) = fn(params) + override fun intercept(returnClazz: Class, fn: () -> T) = fn() + override fun

intercept(params: P, returnClazz: Class, fn: (P) -> T) = fn(params) } class JsonFunctionRecorder(baseDir: File) : FunctionInterceptor, Closeable { - private val baseDirectory = baseDir.apply { - if (exists()) { - throw IllegalStateException("File already exists: $this") - } - mkdirs() + private val baseDirectory = baseDir.apply { + if (exists()) { + throw IllegalStateException("File already exists: $this") } - private val sequenceId = AtomicInteger(0) - - override fun close() { - // No resources to close in this implementation + mkdirs() + } + private val sequenceId = AtomicInteger(0) + + override fun close() { + // No resources to close in this implementation + } + + override fun intercept(returnClazz: Class, fn: () -> T): T { + val dir = operationDir() + try { + val result = fn() + if (result is BufferedImage) { + ImageIO.write(result, "png", File(dir, "output.png")) + } else { + File(dir, "output.json").writeText(JsonUtil.toJson(result)) + } + return result + } catch (e: Throwable) { + try { + File(dir, "error.json").writeText(JsonUtil.toJson(e)) + } catch (e: Throwable) { + log.warn("Error writing error file", e) + } + throw e } - - override fun intercept(returnClazz: Class, fn: () -> T): T { - val dir = operationDir() - try { - val result = fn() - if (result is BufferedImage) { - ImageIO.write(result, "png", File(dir, "output.png")) - } else { - File(dir, "output.json").writeText(JsonUtil.toJson(result)) - } - return result - } catch (e: Throwable) { - try { - File(dir, "error.json").writeText(JsonUtil.toJson(e)) - } catch (e: Throwable) { - log.warn("Error writing error file", e) - } - throw e - } + } + + override fun

intercept(params: P, returnClazz: Class, fn: (P) -> T): T { + val dir = operationDir() + File(dir, "input.json").writeText(JsonUtil.toJson(params)) + try { + val result = fn(params) + if (result is BufferedImage) { + ImageIO.write(result, "png", File(dir, "output.png")) + } else { + File(dir, "output.json").writeText(JsonUtil.toJson(result)) + } + return result + } catch (e: Throwable) { + try { + File(dir, "error.json").writeText(JsonUtil.toJson(e)) + } catch (e: Throwable) { + log.warn("Error writing error file", e) + } + throw e } - - override fun

intercept(params: P, returnClazz: Class, fn: (P) -> T): T { - val dir = operationDir() - File(dir, "input.json").writeText(JsonUtil.toJson(params)) - try { - val result = fn(params) - if (result is BufferedImage) { - ImageIO.write(result, "png", File(dir, "output.png")) - } else { - File(dir, "output.json").writeText(JsonUtil.toJson(result)) - } - return result - } catch (e: Throwable) { - try { - File(dir, "error.json").writeText(JsonUtil.toJson(e)) - } catch (e: Throwable) { - log.warn("Error writing error file", e) - } - throw e - } + } + + private fun operationDir(): File { + val id = sequenceId.incrementAndGet().toString().padStart(3, '0') + val yyyyMMddHHmmss = + java.time.format.DateTimeFormatter.ofPattern("yyyyMMddHHmmss").format(java.time.LocalDateTime.now()) + val internalClassList = listOf( + java.lang.Thread::class.java, + JsonFunctionRecorder::class.java, + FunctionWrapper::class.java, + FunctionInterceptor::class.java, + NoopFunctionInterceptor::class.java, + ) + // Get the caller method name from the stack trace (first caller not in internalClassList) + val caller = Thread.currentThread().stackTrace + .filter { !internalClassList.contains(Class.forName(it.className)) } + .filter { it.methodName != "intercept" } + .firstOrNull() + val methodName = caller?.methodName ?: "unknown" + val file = File(baseDirectory, "$id-$yyyyMMddHHmmss-$methodName") + if (file.exists()) { + throw IllegalStateException("File already exists: $file") } + file.mkdirs() + return file + } - private fun operationDir(): File { - val id = sequenceId.incrementAndGet().toString().padStart(3, '0') - val yyyyMMddHHmmss = - java.time.format.DateTimeFormatter.ofPattern("yyyyMMddHHmmss").format(java.time.LocalDateTime.now()) - val internalClassList = listOf( - java.lang.Thread::class.java, - JsonFunctionRecorder::class.java, - FunctionWrapper::class.java, - FunctionInterceptor::class.java, - NoopFunctionInterceptor::class.java, - ) - // Get the caller method name from the stack trace (first caller not in internalClassList) - val caller = Thread.currentThread().stackTrace - .filter { !internalClassList.contains(Class.forName(it.className)) } - .filter { it.methodName != "intercept" } - .firstOrNull() - val methodName = caller?.methodName ?: "unknown" - val file = File(baseDirectory, "$id-$yyyyMMddHHmmss-$methodName") - if (file.exists()) { - throw IllegalStateException("File already exists: $file") - } - file.mkdirs() - return file - } - - companion object { - val log = org.slf4j.LoggerFactory.getLogger(JsonFunctionRecorder::class.java) - } + companion object { + val log = org.slf4j.LoggerFactory.getLogger(JsonFunctionRecorder::class.java) + } } fun getModel(modelName: String?): OpenAIModel? = ChatModel.values().values.find { it.modelName == modelName } - ?: EmbeddingModels.values().values.find { it.modelName == modelName } - ?: ImageModels.values().find { it.modelName == modelName } + ?: EmbeddingModels.values().values.find { it.modelName == modelName } + ?: ImageModels.values().find { it.modelName == modelName } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/GetModuleRootForFile.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/GetModuleRootForFile.kt index 84926e62..dc0abdc3 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/GetModuleRootForFile.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/GetModuleRootForFile.kt @@ -3,15 +3,15 @@ package com.simiacryptus.skyenet.core.util import java.io.File fun getModuleRootForFile(file: File): File { - if (file.isFile) { - return getModuleRootForFile(file.parentFile) + if (file.isFile) { + return getModuleRootForFile(file.parentFile) + } + var current = file + do { + if (current.resolve(".git").exists()) { + return current } - var current = file - do { - if (current.resolve(".git").exists()) { - return current - } - current = current.parentFile ?: break - } while (true) - return file + current = current.parentFile ?: break + } while (true) + return file } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/LoggingInterceptor.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/LoggingInterceptor.kt index 37310c00..16ba6244 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/LoggingInterceptor.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/LoggingInterceptor.kt @@ -8,68 +8,68 @@ import org.slf4j.LoggerFactory @Suppress("unused") class LoggingInterceptor( - private val stringBuffer: StringBuffer = StringBuffer(), + private val stringBuffer: StringBuffer = StringBuffer(), ) : AppenderBase() { - companion object { - fun withIntercept( - stringBuffer: StringBuffer = StringBuffer(), - vararg loggerPrefixes: String, - fn: () -> T, - ): T { - val loggerContext = LoggerFactory.getILoggerFactory() as LoggerContext - val loggers = loggerPrefixes.flatMap { loggerPrefix -> - loggerContext.loggerList.filter { it.name.startsWith(loggerPrefix) } - } - return withIntercept( - stringBuffer = stringBuffer, - loggers = loggers.toTypedArray(), - fn = fn - ) - } + companion object { + fun withIntercept( + stringBuffer: StringBuffer = StringBuffer(), + vararg loggerPrefixes: String, + fn: () -> T, + ): T { + val loggerContext = LoggerFactory.getILoggerFactory() as LoggerContext + val loggers = loggerPrefixes.flatMap { loggerPrefix -> + loggerContext.loggerList.filter { it.name.startsWith(loggerPrefix) } + } + return withIntercept( + stringBuffer = stringBuffer, + loggers = loggers.toTypedArray(), + fn = fn + ) + } - private fun withIntercept( - stringBuffer: StringBuffer, - vararg loggers: Logger, - fn: () -> T, - ): T { - // Save the original logger level and appender list - val originalLevels = loggers.map { it.level } - val originalAppenders = loggers.map { it.iteratorForAppenders().asSequence().toList() } + private fun withIntercept( + stringBuffer: StringBuffer, + vararg loggers: Logger, + fn: () -> T, + ): T { + // Save the original logger level and appender list + val originalLevels = loggers.map { it.level } + val originalAppenders = loggers.map { it.iteratorForAppenders().asSequence().toList() } - // Create and attach the custom StringBufferAppender - val stringBufferAppender = LoggingInterceptor(stringBuffer) - stringBufferAppender.context = LoggerFactory.getILoggerFactory() as LoggerContext - stringBufferAppender.start() - loggers.forEach { it.detachAndStopAllAppenders() } - loggers.forEach { it.addAppender(stringBufferAppender) } + // Create and attach the custom StringBufferAppender + val stringBufferAppender = LoggingInterceptor(stringBuffer) + stringBufferAppender.context = LoggerFactory.getILoggerFactory() as LoggerContext + stringBufferAppender.start() + loggers.forEach { it.detachAndStopAllAppenders() } + loggers.forEach { it.addAppender(stringBufferAppender) } - try { - return fn() - } finally { - // Restore the original logger level and appender list - loggers.zip(originalLevels.zip(originalAppenders)).forEach { (jsEngineLogger, t) -> - val (originalLevel, originalAppender) = t - jsEngineLogger.level = originalLevel - jsEngineLogger.detachAndStopAllAppenders() - originalAppender.forEach { jsEngineLogger.addAppender(it) } - } - } + try { + return fn() + } finally { + // Restore the original logger level and appender list + loggers.zip(originalLevels.zip(originalAppenders)).forEach { (jsEngineLogger, t) -> + val (originalLevel, originalAppender) = t + jsEngineLogger.level = originalLevel + jsEngineLogger.detachAndStopAllAppenders() + originalAppender.forEach { jsEngineLogger.addAppender(it) } } + } } + } - override fun addInfo(msg: String?, ex: Throwable?) { - super.addInfo(msg, ex) - } + override fun addInfo(msg: String?, ex: Throwable?) { + super.addInfo(msg, ex) + } - override fun append(event: ILoggingEvent) { - stringBuffer.append(event.formattedMessage) - event.throwableProxy?.let { - stringBuffer.append(System.lineSeparator()) - stringBuffer.append(ch.qos.logback.classic.pattern.ThrowableProxyConverter().convert(event)) - } - stringBuffer.append(System.lineSeparator()) + override fun append(event: ILoggingEvent) { + stringBuffer.append(event.formattedMessage) + event.throwableProxy?.let { + stringBuffer.append(System.lineSeparator()) + stringBuffer.append(ch.qos.logback.classic.pattern.ThrowableProxyConverter().convert(event)) } + stringBuffer.append(System.lineSeparator()) + } - fun getStringBuffer(): StringBuffer = stringBuffer + fun getStringBuffer(): StringBuffer = stringBuffer } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/MultiExeption.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/MultiExeption.kt index ece5ae42..910ae5a2 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/MultiExeption.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/MultiExeption.kt @@ -1,5 +1,5 @@ package com.simiacryptus.skyenet.core.util class MultiExeption(exceptions: Collection) : RuntimeException( - exceptions.joinToString("\n\n") { "```text\n${/*escapeHtml4*/(it.stackTraceToString())/*.indent(" ")*/}\n```" } + exceptions.joinToString("\n\n") { "```text\n${/*escapeHtml4*/(it.stackTraceToString())/*.indent(" ")*/}\n```" } ) diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilder.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilder.kt index c58a4b08..c289df92 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilder.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilder.kt @@ -6,53 +6,53 @@ import java.util.stream.Collectors object RuleTreeBuilder { - val String.escape get() = replace("$", "\\$") - - fun String.safeSubstring(from: Int, to: Int?) = when { - to == null -> "" - from >= to -> "" - from < 0 -> "" - to > length -> "" - else -> substring(from, to) - } - - @Language("kotlin") - fun getRuleExpression( - toMatch: Set, - doNotMatch: SortedSet, - result: Boolean - ): String = if (doNotMatch.size < toMatch.size) { - getRuleExpression(doNotMatch, toMatch.toSortedSet(), !result) - } else """ + val String.escape get() = replace("$", "\\$") + + fun String.safeSubstring(from: Int, to: Int?) = when { + to == null -> "" + from >= to -> "" + from < 0 -> "" + to > length -> "" + else -> substring(from, to) + } + + @Language("kotlin") + fun getRuleExpression( + toMatch: Set, + doNotMatch: SortedSet, + result: Boolean + ): String = if (doNotMatch.size < toMatch.size) { + getRuleExpression(doNotMatch, toMatch.toSortedSet(), !result) + } else """ when { ${getRules(toMatch.toSet(), doNotMatch.toSortedSet(), result).replace("\n", "\n ")} else -> ${!result} } """.trimIndent().trim() - private fun getRules( - toMatch: Set, - doNotMatch: SortedSet, - result: Boolean - ): String { - if (doNotMatch.isEmpty()) return "true -> $result\n" - val sb: StringBuilder = StringBuilder() - val remainingItems = toMatch.toMutableSet() - fun String.bestPrefix(): String { - val pfx = allowedPrefixes(setOf(this), doNotMatch).firstOrNull() ?: this - require(pfx.isNotBlank()) - //require(doNotMatch.none { it.startsWith(pfx) }) - return pfx - } - while (remainingItems.isNotEmpty()) { + private fun getRules( + toMatch: Set, + doNotMatch: SortedSet, + result: Boolean + ): String { + if (doNotMatch.isEmpty()) return "true -> $result\n" + val sb: StringBuilder = StringBuilder() + val remainingItems = toMatch.toMutableSet() + fun String.bestPrefix(): String { + val pfx = allowedPrefixes(setOf(this), doNotMatch).firstOrNull() ?: this + require(pfx.isNotBlank()) + //require(doNotMatch.none { it.startsWith(pfx) }) + return pfx + } + while (remainingItems.isNotEmpty()) { - val bestNextPrefix = bestPrefix(remainingItems.toSortedSet(), doNotMatch) + val bestNextPrefix = bestPrefix(remainingItems.toSortedSet(), doNotMatch) // val doNotMatchReversed = remainingItems.map { it.reversed() }.toSortedSet() // fun String.bestSuffix() = allowedPrefixes(setOf(this).map { it.reversed() }, doNotMatchReversed).firstOrNull()?.reversed() ?: this // val bestNextSuffix = bestNextSuffix(remainingItems, doNotMatchReversed, sortedItems) - when { + when { // bestNextSuffix != null && bestNextSuffix.second > (bestNextPrefix?.second ?: 0) -> { // val matchedItems = remainingItems.filter { it.endsWith(bestNextSuffix.first) }.toSet() // val matchedSuffixes = matchedItems.map { it.bestSuffix() }.toSet() @@ -78,87 +78,87 @@ object RuleTreeBuilder { // } // remainingItems.removeAll(matchedItems) // } - bestNextPrefix == null -> break - else -> { - val matchedItems = remainingItems.filter { it.startsWith(bestNextPrefix) }.toSet() - val matchedBlacklist = doNotMatch.filter { it.startsWith(bestNextPrefix) } - when { - matchedBlacklist.isEmpty() -> sb.append("""path.startsWith("${bestNextPrefix.bestPrefix().escape}") -> $result""" + "\n") - matchedItems.isEmpty() -> sb.append("""path.startsWith("${bestNextPrefix.bestPrefix().escape}") -> ${!result}""" + "\n") - (matchedItems + matchedBlacklist).map { it.bestPrefix() }.distinct().size < 3 -> break - else -> { - val subRules = getRuleExpression( - matchedItems.map { it.removePrefix(bestNextPrefix) }.toSet(), - matchedBlacklist.map { it.removePrefix(bestNextPrefix) }.toSortedSet(), - result - ) - sb.append( - """ + bestNextPrefix == null -> break + else -> { + val matchedItems = remainingItems.filter { it.startsWith(bestNextPrefix) }.toSet() + val matchedBlacklist = doNotMatch.filter { it.startsWith(bestNextPrefix) } + when { + matchedBlacklist.isEmpty() -> sb.append("""path.startsWith("${bestNextPrefix.bestPrefix().escape}") -> $result""" + "\n") + matchedItems.isEmpty() -> sb.append("""path.startsWith("${bestNextPrefix.bestPrefix().escape}") -> ${!result}""" + "\n") + (matchedItems + matchedBlacklist).map { it.bestPrefix() }.distinct().size < 3 -> break + else -> { + val subRules = getRuleExpression( + matchedItems.map { it.removePrefix(bestNextPrefix) }.toSet(), + matchedBlacklist.map { it.removePrefix(bestNextPrefix) }.toSortedSet(), + result + ) + sb.append( + """ path.startsWith("${bestNextPrefix.escape}") -> { val path = path.substring(${bestNextPrefix.length}) ${subRules.replace("\n", "\n ")} } """.trimIndent() + "\n" - ) - } - } - remainingItems.removeAll(matchedItems) - } + ) } + } + remainingItems.removeAll(matchedItems) } - remainingItems.map { it.bestPrefix() }.toSortedSet().forEach { - require(doNotMatch.none { prefix -> prefix.startsWith(it) }) - sb.append("""path.startsWith("${it.escape}") -> $result""" + "\n") - } - return sb.toString() + } } - - private fun bestPrefix( - positiveSet: SortedSet, - negativeSet: SortedSet - ) = allowedPrefixes(positiveSet, negativeSet) - .parallelStream() - .flatMap { prefixExpand(listOf(it)).stream() } - .filter { it.isNotBlank() } - .map { prefix -> - val goodCnt = positiveSet.subSet(prefix, prefix + "\uFFFF").size - val badCnt = negativeSet.subSet(prefix, prefix + "\uFFFF").size - if (badCnt == 0) return@map prefix to (goodCnt - 1).toDouble() * prefix.length - //if (goodCnt == 0) return@map prefix to (badCnt - 1).toDouble() * prefix.length - val totalCnt = goodCnt + badCnt - val goodFactor = goodCnt.toDouble() / totalCnt - val badFactor = badCnt.toDouble() / totalCnt - val entropy = goodFactor * Math.log(goodFactor) + badFactor * Math.log(badFactor) - prefix to entropy - }.reduce({ a, b -> if (a.second >= b.second) a else b }).orElse(null)?.first - - fun prefixExpand(allowedPrefixes: Collection) = - allowedPrefixes.filter { allowedPrefixes.none { prefix -> prefix != it && prefix.startsWith(it) } } - .flatMap { prefixExpand(it) }.toSet() - - private fun prefixExpand(it: String) = (1..it.length).map { i -> it.substring(0, i) } - - fun allowedPrefixes( - items: Collection, - doNotMatch: SortedSet - ) = items.toList().parallelStream().map { item -> - val list = listOf( - item.safeSubstring( - 0, - longestCommonPrefix(doNotMatch.tailSet(item).firstOrNull(), item)?.length?.let { it + 1 }), - item.safeSubstring( - 0, - longestCommonPrefix(doNotMatch.headSet(item).lastOrNull(), item)?.length?.let { it + 1 }), - ) - list.maxByOrNull { it.length } ?: list.firstOrNull() - }.distinct().collect(Collectors.toSet()).filterNotNull().filter { it.isNotBlank() }.toSortedSet() - - fun longestCommonPrefix(a: String?, b: String?): String? { - if (a == null || b == null) return null - var i = 0 - while (i < a.length && i < b.length && a[i] == b[i]) i++ - return a.substring(0, i) + remainingItems.map { it.bestPrefix() }.toSortedSet().forEach { + require(doNotMatch.none { prefix -> prefix.startsWith(it) }) + sb.append("""path.startsWith("${it.escape}") -> $result""" + "\n") } + return sb.toString() + } + + private fun bestPrefix( + positiveSet: SortedSet, + negativeSet: SortedSet + ) = allowedPrefixes(positiveSet, negativeSet) + .parallelStream() + .flatMap { prefixExpand(listOf(it)).stream() } + .filter { it.isNotBlank() } + .map { prefix -> + val goodCnt = positiveSet.subSet(prefix, prefix + "\uFFFF").size + val badCnt = negativeSet.subSet(prefix, prefix + "\uFFFF").size + if (badCnt == 0) return@map prefix to (goodCnt - 1).toDouble() * prefix.length + //if (goodCnt == 0) return@map prefix to (badCnt - 1).toDouble() * prefix.length + val totalCnt = goodCnt + badCnt + val goodFactor = goodCnt.toDouble() / totalCnt + val badFactor = badCnt.toDouble() / totalCnt + val entropy = goodFactor * Math.log(goodFactor) + badFactor * Math.log(badFactor) + prefix to entropy + }.reduce({ a, b -> if (a.second >= b.second) a else b }).orElse(null)?.first + + fun prefixExpand(allowedPrefixes: Collection) = + allowedPrefixes.filter { allowedPrefixes.none { prefix -> prefix != it && prefix.startsWith(it) } } + .flatMap { prefixExpand(it) }.toSet() + + private fun prefixExpand(it: String) = (1..it.length).map { i -> it.substring(0, i) } + + fun allowedPrefixes( + items: Collection, + doNotMatch: SortedSet + ) = items.toList().parallelStream().map { item -> + val list = listOf( + item.safeSubstring( + 0, + longestCommonPrefix(doNotMatch.tailSet(item).firstOrNull(), item)?.length?.let { it + 1 }), + item.safeSubstring( + 0, + longestCommonPrefix(doNotMatch.headSet(item).lastOrNull(), item)?.length?.let { it + 1 }), + ) + list.maxByOrNull { it.length } ?: list.firstOrNull() + }.distinct().collect(Collectors.toSet()).filterNotNull().filter { it.isNotBlank() }.toSortedSet() + + fun longestCommonPrefix(a: String?, b: String?): String? { + if (a == null || b == null) return null + var i = 0 + while (i < a.length && i < b.length && a[i] == b[i]) i++ + return a.substring(0, i) + } } diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Selenium.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Selenium.kt index 421ea883..88e25fb3 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Selenium.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/Selenium.kt @@ -3,11 +3,11 @@ package com.simiacryptus.skyenet.core.util import java.net.URL interface Selenium : AutoCloseable { - fun save( - url: URL, - currentFilename: String?, - saveRoot: String - ) + fun save( + url: URL, + currentFilename: String?, + saveRoot: String + ) // // open fun setCookies( // driver: WebDriver, diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/StringSplitter.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/StringSplitter.kt index efe32533..39dee84a 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/StringSplitter.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/core/util/StringSplitter.kt @@ -1,32 +1,32 @@ package com.simiacryptus.skyenet.core.util object StringSplitter { - fun split(text: String, seperators: Map): Pair { - val splitAt = seperators.entries.map { (sep, weight) -> - val splitPoint = (0 until (text.length - sep.length)).filter { i -> - text.substring(i, i + sep.length) == sep - }.map { i -> - val a = i.toDouble() / text.length - val b = 1.0 - a - i to b * Math.log(a) + a * Math.log(b) - }.maxByOrNull { it.second } - if (null == splitPoint) null - else sep to ((splitPoint.first + sep.length) to splitPoint.second / weight) - }.filterNotNull().maxByOrNull { it.second.second }?.second?.first ?: (text.length / 2) - return text.substring(0, splitAt) to text.substring(splitAt) - } + fun split(text: String, seperators: Map): Pair { + val splitAt = seperators.entries.map { (sep, weight) -> + val splitPoint = (0 until (text.length - sep.length)).filter { i -> + text.substring(i, i + sep.length) == sep + }.map { i -> + val a = i.toDouble() / text.length + val b = 1.0 - a + i to b * Math.log(a) + a * Math.log(b) + }.maxByOrNull { it.second } + if (null == splitPoint) null + else sep to ((splitPoint.first + sep.length) to splitPoint.second / weight) + }.filterNotNull().maxByOrNull { it.second.second }?.second?.first ?: (text.length / 2) + return text.substring(0, splitAt) to text.substring(splitAt) + } - @JvmStatic - fun main(args: Array) { - println( - split( - text = "This is a test. This is only a test. If this were a real emergency, you would be instructed to panic.", - seperators = mapOf( - "." to 2.0, - " " to 1.0, - ", " to 2.0, - ) - ).toList().joinToString("\n") + @JvmStatic + fun main(args: Array) { + println( + split( + text = "This is a test. This is only a test. If this were a real emergency, you would be instructed to panic.", + seperators = mapOf( + "." to 2.0, + " " to 1.0, + ", " to 2.0, ) - } + ).toList().joinToString("\n") + ) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/Interpreter.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/Interpreter.kt index ecc7d076..ffce7a31 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/Interpreter.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/Interpreter.kt @@ -2,42 +2,42 @@ package com.simiacryptus.skyenet.interpreter interface Interpreter { - fun getLanguage(): String - fun getSymbols(): Map - fun run(code: String): Any? - fun validate(code: String): Throwable? + fun getLanguage(): String + fun getSymbols(): Map + fun run(code: String): Any? + fun validate(code: String): Throwable? - fun wrapCode(code: String): String = code - fun wrapExecution(fn: java.util.function.Supplier): T? = fn.get() + fun wrapCode(code: String): String = code + fun wrapExecution(fn: java.util.function.Supplier): T? = fn.get() - companion object { - private class TestObject { - @Suppress("unused") - fun square(x: Int): Int = x * x - } + companion object { + private class TestObject { + @Suppress("unused") + fun square(x: Int): Int = x * x + } - private interface TestInterface { - fun square(x: Int): Int - } + private interface TestInterface { + fun square(x: Int): Int + } - @JvmStatic - fun test(factory: java.util.function.Function, Interpreter>) { - val testImpl = object : TestInterface { - override fun square(x: Int): Int = x * x - } - with(factory.apply(mapOf("message" to "hello"))) { - test("hello", run("message")) - } - with(factory.apply(mapOf("function" to TestObject()))) { - test(25, run("function.square(5)")) - } - with(factory.apply(mapOf("function" to testImpl))) { - test(25, run("function.square(5)")) - } - } + @JvmStatic + fun test(factory: java.util.function.Function, Interpreter>) { + val testImpl = object : TestInterface { + override fun square(x: Int): Int = x * x + } + with(factory.apply(mapOf("message" to "hello"))) { + test("hello", run("message")) + } + with(factory.apply(mapOf("function" to TestObject()))) { + test(25, run("function.square(5)")) + } + with(factory.apply(mapOf("function" to testImpl))) { + test(25, run("function.square(5)")) + } + } - private fun test(expected: T, actual: T?) { - require(expected == actual) { actual.toString() } - } + private fun test(expected: T, actual: T?) { + require(expected == actual) { actual.toString() } } + } } \ No newline at end of file diff --git a/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/InterpreterTestBase.kt b/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/InterpreterTestBase.kt index 25ca0e3f..d8df01a4 100644 --- a/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/InterpreterTestBase.kt +++ b/core/src/main/kotlin/com/simiacryptus/skyenet/interpreter/InterpreterTestBase.kt @@ -6,82 +6,82 @@ import org.junit.jupiter.api.assertThrows abstract class InterpreterTestBase { - @Test - fun `test run with valid code`() { - val interpreter = newInterpreter(mapOf()) - val result = interpreter.run("2 + 2") - Assertions.assertEquals(4, result) - } - - @Test - fun `test run with invalid code`() { - val interpreter = newInterpreter(mapOf()) - assertThrows { interpreter.run("2 +") } - } - - @Test - fun `test validate with valid code`() { - val interpreter = newInterpreter(mapOf()) - val result = interpreter.validate("2 + 2") - Assertions.assertEquals(null, result) - } - - @Test - fun `test validate with invalid code`() { - val interpreter = newInterpreter(mapOf()) - assertThrows { with(interpreter.validate("2 +")) { throw this!! } } - } - - @Test - open fun `test run with variables`() { - val interpreter = newInterpreter(mapOf("x" to (2 as Any), "y" to (3 as Any))) - val result = interpreter.run("x * y") - Assertions.assertEquals(6, result) - } - - @Test - open fun `test validate with variables`() { - val interpreter = newInterpreter(mapOf("x" to (2 as Any), "y" to (3 as Any))) - val result = interpreter.validate("x * y") - Assertions.assertEquals(null, result) - } - - class FooBar { - fun bar() = "Foo says Hello World" - } - - @Test - fun `test run with tool Any`() { - val interpreter = newInterpreter(mapOf("tool" to (FooBar() as Any))) - val result = interpreter.run("tool.bar()") - Assertions.assertEquals("Foo says Hello World", result) - } - - @Test - fun `test validate with tool Any`() { - val interpreter = newInterpreter(mapOf("tool" to (FooBar() as Any))) - val result = interpreter.validate("tool.bar()") - Assertions.assertEquals(null, result) - } - - - @Test - fun `test run with tool Any and invalid code`() { - val interpreter = newInterpreter(mapOf("tool" to (FooBar() as Any))) - assertThrows { interpreter.run("tool.baz()") } - } - - @Test - open fun `test validate with tool Any and invalid code`() { - val interpreter = newInterpreter(mapOf("tool" to (FooBar() as Any))) - assertThrows { with(interpreter.validate("tool.baz()")) { throw this!! } } - } - - @Test - open fun `test validate with undefined variable`() { - val interpreter = newInterpreter(mapOf()) - assertThrows { with(interpreter.validate("x * y")) { throw this!! } } - } - - abstract fun newInterpreter(map: Map): Interpreter + @Test + fun `test run with valid code`() { + val interpreter = newInterpreter(mapOf()) + val result = interpreter.run("2 + 2") + Assertions.assertEquals(4, result) + } + + @Test + fun `test run with invalid code`() { + val interpreter = newInterpreter(mapOf()) + assertThrows { interpreter.run("2 +") } + } + + @Test + fun `test validate with valid code`() { + val interpreter = newInterpreter(mapOf()) + val result = interpreter.validate("2 + 2") + Assertions.assertEquals(null, result) + } + + @Test + fun `test validate with invalid code`() { + val interpreter = newInterpreter(mapOf()) + assertThrows { with(interpreter.validate("2 +")) { throw this!! } } + } + + @Test + open fun `test run with variables`() { + val interpreter = newInterpreter(mapOf("x" to (2 as Any), "y" to (3 as Any))) + val result = interpreter.run("x * y") + Assertions.assertEquals(6, result) + } + + @Test + open fun `test validate with variables`() { + val interpreter = newInterpreter(mapOf("x" to (2 as Any), "y" to (3 as Any))) + val result = interpreter.validate("x * y") + Assertions.assertEquals(null, result) + } + + class FooBar { + fun bar() = "Foo says Hello World" + } + + @Test + fun `test run with tool Any`() { + val interpreter = newInterpreter(mapOf("tool" to (FooBar() as Any))) + val result = interpreter.run("tool.bar()") + Assertions.assertEquals("Foo says Hello World", result) + } + + @Test + fun `test validate with tool Any`() { + val interpreter = newInterpreter(mapOf("tool" to (FooBar() as Any))) + val result = interpreter.validate("tool.bar()") + Assertions.assertEquals(null, result) + } + + + @Test + fun `test run with tool Any and invalid code`() { + val interpreter = newInterpreter(mapOf("tool" to (FooBar() as Any))) + assertThrows { interpreter.run("tool.baz()") } + } + + @Test + open fun `test validate with tool Any and invalid code`() { + val interpreter = newInterpreter(mapOf("tool" to (FooBar() as Any))) + assertThrows { with(interpreter.validate("tool.baz()")) { throw this!! } } + } + + @Test + open fun `test validate with undefined variable`() { + val interpreter = newInterpreter(mapOf()) + assertThrows { with(interpreter.validate("x * y")) { throw this!! } } + } + + abstract fun newInterpreter(map: Map): Interpreter } \ No newline at end of file diff --git a/core/src/test/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilderTest.kt b/core/src/test/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilderTest.kt index 639a7085..60c7c80d 100644 --- a/core/src/test/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilderTest.kt +++ b/core/src/test/kotlin/com/simiacryptus/skyenet/core/util/RuleTreeBuilderTest.kt @@ -7,23 +7,23 @@ import org.junit.jupiter.api.Test class RuleTreeBuilderTest { - @Test - fun testEscape() { - Assertions.assertEquals("\\$100", "$100".escape) - Assertions.assertEquals("NoSpecialCharacters", "NoSpecialCharacters".escape) - Assertions.assertEquals("\\$\\$", "$$".escape) - } + @Test + fun testEscape() { + Assertions.assertEquals("\\$100", "$100".escape) + Assertions.assertEquals("NoSpecialCharacters", "NoSpecialCharacters".escape) + Assertions.assertEquals("\\$\\$", "$$".escape) + } - @Test - fun testSafeSubstring() { - val testString = "HelloWorld" - Assertions.assertEquals("", testString.safeSubstring(-1, 5)) - Assertions.assertEquals("", testString.safeSubstring(0, 11)) - Assertions.assertEquals("", testString.safeSubstring(5, 5)) - Assertions.assertEquals("", testString.safeSubstring(0, null)) - Assertions.assertEquals("Hello", testString.safeSubstring(0, 5)) - Assertions.assertEquals("World", testString.safeSubstring(5, 10)) - } + @Test + fun testSafeSubstring() { + val testString = "HelloWorld" + Assertions.assertEquals("", testString.safeSubstring(-1, 5)) + Assertions.assertEquals("", testString.safeSubstring(0, 11)) + Assertions.assertEquals("", testString.safeSubstring(5, 5)) + Assertions.assertEquals("", testString.safeSubstring(0, null)) + Assertions.assertEquals("Hello", testString.safeSubstring(0, 5)) + Assertions.assertEquals("World", testString.safeSubstring(5, 10)) + } // @Test // fun testBestNextPrefix() { @@ -44,27 +44,27 @@ class RuleTreeBuilderTest { // Assertions.assertEquals("e", bestNextSuffix?.first) // } - @Test - fun testPrefixExpand() { - val allowedPrefixes = setOf("app", "ban") - val expandedPrefixes = RuleTreeBuilder.prefixExpand(allowedPrefixes) - Assertions.assertTrue(expandedPrefixes.containsAll(setOf("a", "ap", "app", "b", "ba", "ban"))) - } + @Test + fun testPrefixExpand() { + val allowedPrefixes = setOf("app", "ban") + val expandedPrefixes = RuleTreeBuilder.prefixExpand(allowedPrefixes) + Assertions.assertTrue(expandedPrefixes.containsAll(setOf("a", "ap", "app", "b", "ba", "ban"))) + } - @Test - fun testAllowedPrefixes() { - val items = listOf("apple", "apricot") - val doNotMatch = sortedSetOf("application", "appetizer") - val allowedPrefixes = RuleTreeBuilder.allowedPrefixes(items, doNotMatch) - Assertions.assertEquals(sortedSetOf("apple", "apr"), allowedPrefixes) - } + @Test + fun testAllowedPrefixes() { + val items = listOf("apple", "apricot") + val doNotMatch = sortedSetOf("application", "appetizer") + val allowedPrefixes = RuleTreeBuilder.allowedPrefixes(items, doNotMatch) + Assertions.assertEquals(sortedSetOf("apple", "apr"), allowedPrefixes) + } - @Test - fun testLongestCommonPrefix() { - Assertions.assertNull(RuleTreeBuilder.longestCommonPrefix(null, "test")) - Assertions.assertNull(RuleTreeBuilder.longestCommonPrefix("test", null)) - Assertions.assertEquals("", RuleTreeBuilder.longestCommonPrefix("", "test")) - Assertions.assertEquals("te", RuleTreeBuilder.longestCommonPrefix("test", "teapot")) - Assertions.assertEquals("", RuleTreeBuilder.longestCommonPrefix("test", "best")) - } + @Test + fun testLongestCommonPrefix() { + Assertions.assertNull(RuleTreeBuilder.longestCommonPrefix(null, "test")) + Assertions.assertNull(RuleTreeBuilder.longestCommonPrefix("test", null)) + Assertions.assertEquals("", RuleTreeBuilder.longestCommonPrefix("", "test")) + Assertions.assertEquals("te", RuleTreeBuilder.longestCommonPrefix("test", "teapot")) + Assertions.assertEquals("", RuleTreeBuilder.longestCommonPrefix("test", "best")) + } } \ No newline at end of file diff --git a/gradle.properties b/gradle.properties index 71477b60..110c5f0e 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,5 +1,5 @@ # Gradle Releases -> https://github.com/gradle/gradle/releases libraryGroup=com.simiacryptus.skyenet -libraryVersion=1.2.15 +libraryVersion=1.2.16 gradleVersion=7.6.1 kotlin.daemon.jvmargs=-Xmx4g diff --git a/groovy/build.gradle.kts b/groovy/build.gradle.kts index bcf75ffa..4538f8ca 100644 --- a/groovy/build.gradle.kts +++ b/groovy/build.gradle.kts @@ -6,147 +6,147 @@ group = properties("libraryGroup") version = properties("libraryVersion") plugins { - java - `java-library` - id("org.jetbrains.kotlin.jvm") version "2.0.20" - `maven-publish` - id("signing") + java + `java-library` + id("org.jetbrains.kotlin.jvm") version "2.0.20" + `maven-publish` + id("signing") } repositories { - mavenCentral { - metadataSources { - mavenPom() - artifact() - } + mavenCentral { + metadataSources { + mavenPom() + artifact() } + } } kotlin { - jvmToolchain(11) + jvmToolchain(11) } val kotlin_version = "2.0.20" dependencies { - implementation(project(":core")) + implementation(project(":core")) - implementation(group = "org.apache.groovy", name = "groovy-all", version = "4.0.11") + implementation(group = "org.apache.groovy", name = "groovy-all", version = "4.0.11") - compileOnly(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.8.0-RC") - compileOnly(group = "org.jetbrains.kotlin", name = "kotlin-stdlib", version = kotlin_version) - compileOnly(group = "org.jetbrains.kotlin", name = "kotlin-stdlib-jdk8", version = kotlin_version) + compileOnly(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.8.0-RC") + compileOnly(group = "org.jetbrains.kotlin", name = "kotlin-stdlib", version = kotlin_version) + compileOnly(group = "org.jetbrains.kotlin", name = "kotlin-stdlib-jdk8", version = kotlin_version) - implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.16") - implementation(group = "commons-io", name = "commons-io", version = "2.15.0") + implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.16") + implementation(group = "commons-io", name = "commons-io", version = "2.15.0") - testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") - testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1") + testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") + testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1") } tasks { - compileKotlin { - compilerOptions { - javaParameters.set(true) - } + compileKotlin { + compilerOptions { + javaParameters.set(true) } - compileTestKotlin { - compilerOptions { - javaParameters.set(true) - } + } + compileTestKotlin { + compilerOptions { + javaParameters.set(true) } - test { - useJUnitPlatform() - systemProperty("surefire.useManifestOnlyJar", "false") - testLogging { - events(TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) - exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL - } - jvmArgs( - "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", - "--add-opens", "java.base/java.util=ALL-UNNAMED", - "--add-opens", "java.base/java.lang=ALL-UNNAMED" - ) + } + test { + useJUnitPlatform() + systemProperty("surefire.useManifestOnlyJar", "false") + testLogging { + events(TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) + exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL } + jvmArgs( + "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", + "--add-opens", "java.base/java.util=ALL-UNNAMED", + "--add-opens", "java.base/java.lang=ALL-UNNAMED" + ) + } } val javadocJar by tasks.registering(Jar::class) { - archiveClassifier.set("javadoc") - from(tasks.javadoc) + archiveClassifier.set("javadoc") + from(tasks.javadoc) } val sourcesJar by tasks.registering(Jar::class) { - archiveClassifier.set("sources") - from(sourceSets.main.get().allSource) + archiveClassifier.set("sources") + from(sourceSets.main.get().allSource) } publishing { - publications { - create("mavenJava") { - artifactId = "groovy" - from(components["java"]) - artifact(sourcesJar.get()) - artifact(javadocJar.get()) - versionMapping { - usage("java-api") { - fromResolutionOf("runtimeClasspath") - } - usage("java-runtime") { - fromResolutionResult() - } - } - pom { - name.set("SkyeNet Groovy Interpreter") - description.set("A very helpful puppy") - url.set("https://github.com/SimiaCryptus/SkyeNet") - licenses { - license { - name.set("The Apache License, Version 2.0") - url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") - } - } - developers { - developer { - id.set("acharneski") - name.set("Andrew Charneski") - email.set("acharneski@gmail.com") - } - } - scm { - connection.set("scm:git:git://git@github.com/SimiaCryptus/SkyeNet.git") - developerConnection.set("scm:git:ssh://git@github.com/SimiaCryptus/SkyeNet.git") - url.set("https://github.com/SimiaCryptus/SkyeNet") - } - } + publications { + create("mavenJava") { + artifactId = "groovy" + from(components["java"]) + artifact(sourcesJar.get()) + artifact(javadocJar.get()) + versionMapping { + usage("java-api") { + fromResolutionOf("runtimeClasspath") } - } - repositories { - maven { - val releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" - val snapshotsRepoUrl = "https://oss.sonatype.org/mask/repositories/snapshots" - url = URI(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) - credentials { - username = System.getenv("OSSRH_USERNAME") ?: System.getProperty("ossrhUsername") - ?: properties("ossrhUsername") - password = System.getenv("OSSRH_PASSWORD") ?: System.getProperty("ossrhPassword") - ?: properties("ossrhPassword") - } + usage("java-runtime") { + fromResolutionResult() } - } - if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) afterEvaluate { - signing { - sign(publications["mavenJava"]) + } + pom { + name.set("SkyeNet Groovy Interpreter") + description.set("A very helpful puppy") + url.set("https://github.com/SimiaCryptus/SkyeNet") + licenses { + license { + name.set("The Apache License, Version 2.0") + url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") + } } + developers { + developer { + id.set("acharneski") + name.set("Andrew Charneski") + email.set("acharneski@gmail.com") + } + } + scm { + connection.set("scm:git:git://git@github.com/SimiaCryptus/SkyeNet.git") + developerConnection.set("scm:git:ssh://git@github.com/SimiaCryptus/SkyeNet.git") + url.set("https://github.com/SimiaCryptus/SkyeNet") + } + } + } + } + repositories { + maven { + val releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" + val snapshotsRepoUrl = "https://oss.sonatype.org/mask/repositories/snapshots" + url = URI(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) + credentials { + username = System.getenv("OSSRH_USERNAME") ?: System.getProperty("ossrhUsername") + ?: properties("ossrhUsername") + password = System.getenv("OSSRH_PASSWORD") ?: System.getProperty("ossrhPassword") + ?: properties("ossrhPassword") + } } + } + if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) afterEvaluate { + signing { + sign(publications["mavenJava"]) + } + } } if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) { - apply() - configure { - useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) - sign(configurations.archives.get()) - } + apply() + configure { + useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) + sign(configurations.archives.get()) + } } diff --git a/groovy/src/main/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreter.kt b/groovy/src/main/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreter.kt index a482d8ce..07a62c46 100644 --- a/groovy/src/main/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreter.kt +++ b/groovy/src/main/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreter.kt @@ -8,38 +8,38 @@ import org.codehaus.groovy.control.CompilerConfiguration open class GroovyInterpreter(private val defs: java.util.Map) : Interpreter { - private val shell: GroovyShell - - init { - val compilerConfiguration = CompilerConfiguration() - shell = GroovyShell(compilerConfiguration) - defs.forEach { key, value -> - shell.setVariable(key, value) - } - } + private val shell: GroovyShell - override fun getLanguage(): String { - return "groovy" + init { + val compilerConfiguration = CompilerConfiguration() + shell = GroovyShell(compilerConfiguration) + defs.forEach { key, value -> + shell.setVariable(key, value) } + } - override fun getSymbols() = defs as Map + override fun getLanguage(): String { + return "groovy" + } + override fun getSymbols() = defs as Map - override fun run(code: String): Any? { - val wrapExecution = wrapExecution { - try { - val script: Script = shell.parse(wrapCode(code)) - script.run() - } catch (e: CompilationFailedException) { - throw e - } - } - return wrapExecution - } - override fun validate(code: String): Exception? { - shell.parse(wrapCode(code)) - return null + override fun run(code: String): Any? { + val wrapExecution = wrapExecution { + try { + val script: Script = shell.parse(wrapCode(code)) + script.run() + } catch (e: CompilationFailedException) { + throw e + } } + return wrapExecution + } + + override fun validate(code: String): Exception? { + shell.parse(wrapCode(code)) + return null + } } diff --git a/groovy/src/test/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreterTest.kt b/groovy/src/test/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreterTest.kt index 07376453..96ff69ac 100644 --- a/groovy/src/test/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreterTest.kt +++ b/groovy/src/test/kotlin/com/simiacryptus/skyenet/groovy/GroovyInterpreterTest.kt @@ -5,8 +5,8 @@ package com.simiacryptus.skyenet.groovy import com.simiacryptus.skyenet.interpreter.InterpreterTestBase class GroovyInterpreterTest : InterpreterTestBase() { - override fun newInterpreter(map: Map) = - GroovyInterpreter(map.map { it.key to it.value as Object }.toMap().toJavaMap()) + override fun newInterpreter(map: Map) = + GroovyInterpreter(map.map { it.key to it.value as Object }.toMap().toJavaMap()) } diff --git a/kotlin/build.gradle.kts b/kotlin/build.gradle.kts index f6a54db8..5def8f8e 100644 --- a/kotlin/build.gradle.kts +++ b/kotlin/build.gradle.kts @@ -6,158 +6,158 @@ group = properties("libraryGroup") version = properties("libraryVersion") plugins { - java - `java-library` - id("org.jetbrains.kotlin.jvm") version "2.0.20" - `maven-publish` - id("signing") + java + `java-library` + id("org.jetbrains.kotlin.jvm") version "2.0.20" + `maven-publish` + id("signing") } repositories { - mavenCentral { - metadataSources { - mavenPom() - artifact() - } + mavenCentral { + metadataSources { + mavenPom() + artifact() } + } } kotlin { - jvmToolchain(11) + jvmToolchain(11) } dependencies { - implementation(project(":core")) + implementation(project(":core")) - //compileOnly(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.8.0-RC")// https://mvnrepository.com/artifact/org.jetbrains.kotlinx/kotlinx-coroutines-core - implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.3") + //compileOnly(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.8.0-RC")// https://mvnrepository.com/artifact/org.jetbrains.kotlinx/kotlinx-coroutines-core + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.3") - implementation(kotlin("stdlib")) - implementation(kotlin("scripting-jsr223")) - implementation(kotlin("scripting-jvm")) - implementation(kotlin("scripting-jvm-host")) - implementation(kotlin("script-runtime")) - implementation(kotlin("scripting-compiler-embeddable")) - implementation(kotlin("compiler-embeddable")) + implementation(kotlin("stdlib")) + implementation(kotlin("scripting-jsr223")) + implementation(kotlin("scripting-jvm")) + implementation(kotlin("scripting-jvm-host")) + implementation(kotlin("script-runtime")) + implementation(kotlin("scripting-compiler-embeddable")) + implementation(kotlin("compiler-embeddable")) - implementation(group = "commons-io", name = "commons-io", version = "2.15.0") + implementation(group = "commons-io", name = "commons-io", version = "2.15.0") - testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") - testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1") + testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") + testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1") - implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.16") - testImplementation(group = "ch.qos.logback", name = "logback-classic", version = "1.5.8") - testImplementation(group = "ch.qos.logback", name = "logback-core", version = "1.5.8") - testImplementation("org.ow2.asm:asm:9.6") + implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.16") + testImplementation(group = "ch.qos.logback", name = "logback-classic", version = "1.5.8") + testImplementation(group = "ch.qos.logback", name = "logback-core", version = "1.5.8") + testImplementation("org.ow2.asm:asm:9.6") } tasks { - compileKotlin { - compilerOptions { - javaParameters.set(true) - } + compileKotlin { + compilerOptions { + javaParameters.set(true) } - compileTestKotlin { - compilerOptions { - javaParameters.set(true) - } + } + compileTestKotlin { + compilerOptions { + javaParameters.set(true) } - test { - useJUnitPlatform() - systemProperty("surefire.useManifestOnlyJar", "false") - testLogging { - events(TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) - exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL - } - jvmArgs( - "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", - "--add-opens", "java.base/java.util=ALL-UNNAMED", - "--add-opens", "java.base/java.lang=ALL-UNNAMED" - ) + } + test { + useJUnitPlatform() + systemProperty("surefire.useManifestOnlyJar", "false") + testLogging { + events(TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) + exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL } + jvmArgs( + "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", + "--add-opens", "java.base/java.util=ALL-UNNAMED", + "--add-opens", "java.base/java.lang=ALL-UNNAMED" + ) + } } val javadocJar by tasks.registering(Jar::class) { - archiveClassifier.set("javadoc") - from(tasks.javadoc) + archiveClassifier.set("javadoc") + from(tasks.javadoc) } val sourcesJar by tasks.registering(Jar::class) { - archiveClassifier.set("sources") - from(sourceSets.main.get().allSource) + archiveClassifier.set("sources") + from(sourceSets.main.get().allSource) } publishing { - publications { - create("mavenJava") { - artifactId = "kotlin" - from(components["java"]) - artifact(sourcesJar.get()) - artifact(javadocJar.get()) - versionMapping { - usage("java-api") { - fromResolutionOf("runtimeClasspath") - } - usage("java-runtime") { - fromResolutionResult() - } - } - pom { - name.set("SkyeNet Kotlin Interpreter") - description.set("A very helpful puppy") - url.set("https://github.com/SimiaCryptus/SkyeNet") - licenses { - license { - name.set("The Apache License, Version 2.0") - url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") - } - } - developers { - developer { - id.set("acharneski") - name.set("Andrew Charneski") - email.set("acharneski@gmail.com") - } - } - scm { - connection.set("scm:git:git://git@github.com/SimiaCryptus/SkyeNet.git") - developerConnection.set("scm:git:ssh://git@github.com/SimiaCryptus/SkyeNet.git") - url.set("https://github.com/SimiaCryptus/SkyeNet") - } - } + publications { + create("mavenJava") { + artifactId = "kotlin" + from(components["java"]) + artifact(sourcesJar.get()) + artifact(javadocJar.get()) + versionMapping { + usage("java-api") { + fromResolutionOf("runtimeClasspath") } - } - repositories { - maven { - val releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" - val snapshotsRepoUrl = "https://oss.sonatype.org/mask/repositories/snapshots" - url = URI(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) - credentials { - username = System.getenv("OSSRH_USERNAME") ?: System.getProperty("ossrhUsername") - ?: properties("ossrhUsername") - password = System.getenv("OSSRH_PASSWORD") ?: System.getProperty("ossrhPassword") - ?: properties("ossrhPassword") - } + usage("java-runtime") { + fromResolutionResult() } - } - if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) afterEvaluate { - signing { - sign(publications["mavenJava"]) + } + pom { + name.set("SkyeNet Kotlin Interpreter") + description.set("A very helpful puppy") + url.set("https://github.com/SimiaCryptus/SkyeNet") + licenses { + license { + name.set("The Apache License, Version 2.0") + url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") + } } + developers { + developer { + id.set("acharneski") + name.set("Andrew Charneski") + email.set("acharneski@gmail.com") + } + } + scm { + connection.set("scm:git:git://git@github.com/SimiaCryptus/SkyeNet.git") + developerConnection.set("scm:git:ssh://git@github.com/SimiaCryptus/SkyeNet.git") + url.set("https://github.com/SimiaCryptus/SkyeNet") + } + } + } + } + repositories { + maven { + val releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" + val snapshotsRepoUrl = "https://oss.sonatype.org/mask/repositories/snapshots" + url = URI(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) + credentials { + username = System.getenv("OSSRH_USERNAME") ?: System.getProperty("ossrhUsername") + ?: properties("ossrhUsername") + password = System.getenv("OSSRH_PASSWORD") ?: System.getProperty("ossrhPassword") + ?: properties("ossrhPassword") + } } + } + if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) afterEvaluate { + signing { + sign(publications["mavenJava"]) + } + } } if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) { - apply() - configure { - useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) - sign(configurations.archives.get()) - } + apply() + configure { + useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) + sign(configurations.archives.get()) + } } diff --git a/kotlin/src/main/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreter.kt b/kotlin/src/main/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreter.kt index d0036699..0f8f08f9 100644 --- a/kotlin/src/main/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreter.kt +++ b/kotlin/src/main/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreter.kt @@ -20,130 +20,130 @@ import kotlin.script.experimental.jvm.util.scriptCompilationClasspathFromContext import kotlin.script.experimental.jvmhost.jsr223.KotlinJsr223ScriptEngineImpl open class KotlinInterpreter( - val defs: Map = mapOf(), + val defs: Map = mapOf(), ) : Interpreter { - final override fun getLanguage(): String = "Kotlin" - override fun getSymbols() = defs + final override fun getLanguage(): String = "Kotlin" + override fun getSymbols() = defs - open val scriptEngine: KotlinJsr223JvmScriptEngineBase - get() = object : KotlinJsr223JvmScriptEngineFactoryBase() { - override fun getScriptEngine() = KotlinJsr223ScriptEngineImpl( - this, - KotlinJsr223DefaultScriptCompilationConfiguration.with { - classLoader?.also { classLoader -> - jvm { - updateClasspath( - scriptCompilationClasspathFromContext( - classLoader = classLoader, - wholeClasspath = true, - unpackJarCollections = false - ) - ) - } - } - }, - KotlinJsr223DefaultScriptEvaluationConfiguration.with { - this.enableScriptsInstancesSharing() - } - ) { - ScriptArgsWithTypes( - arrayOf(), - arrayOf() + open val scriptEngine: KotlinJsr223JvmScriptEngineBase + get() = object : KotlinJsr223JvmScriptEngineFactoryBase() { + override fun getScriptEngine() = KotlinJsr223ScriptEngineImpl( + this, + KotlinJsr223DefaultScriptCompilationConfiguration.with { + classLoader?.also { classLoader -> + jvm { + updateClasspath( + scriptCompilationClasspathFromContext( + classLoader = classLoader, + wholeClasspath = true, + unpackJarCollections = false ) - }.apply { - getBindings(ScriptContext.ENGINE_SCOPE).putAll(getSymbols()) + ) } - }.scriptEngine - - override fun validate(code: String): Throwable? { - val wrappedCode = wrapCode(code) - return try { - scriptEngine.compile(wrappedCode) - null - } catch (ex: ScriptException) { - wrapException(ex, wrappedCode, code) - } catch (ex: Throwable) { - CodingActor.FailedToImplementException( - cause = ex, - language = "Kotlin", - code = code, - ) + } + }, + KotlinJsr223DefaultScriptEvaluationConfiguration.with { + this.enableScriptsInstancesSharing() } + ) { + ScriptArgsWithTypes( + arrayOf(), + arrayOf() + ) + }.apply { + getBindings(ScriptContext.ENGINE_SCOPE).putAll(getSymbols()) + } + }.scriptEngine + + override fun validate(code: String): Throwable? { + val wrappedCode = wrapCode(code) + return try { + scriptEngine.compile(wrappedCode) + null + } catch (ex: ScriptException) { + wrapException(ex, wrappedCode, code) + } catch (ex: Throwable) { + CodingActor.FailedToImplementException( + cause = ex, + language = "Kotlin", + code = code, + ) } + } - override fun run(code: String): Any? { - val wrappedCode = wrapCode(code) - log.debug( - """ + override fun run(code: String): Any? { + val wrappedCode = wrapCode(code) + log.debug( + """ |Running: | ${wrappedCode.trimIndent().replace("\n", "\n\t")} |""".trimMargin().trim() - ) - val bindings: Bindings? - val compile: CompiledScript - val scriptEngine: KotlinJsr223JvmScriptEngineBase - try { - scriptEngine = this.scriptEngine - compile = scriptEngine.compile(wrappedCode) - bindings = scriptEngine.getBindings(ScriptContext.ENGINE_SCOPE) - return kotlinx.coroutines.runBlocking { compile.eval(bindings) } - } catch (ex: ScriptException) { - throw wrapException(ex, wrappedCode, code) - } catch (ex: Throwable) { - throw CodingActor.FailedToImplementException( - cause = ex, - language = "Kotlin", - code = code, - ) - } + ) + val bindings: Bindings? + val compile: CompiledScript + val scriptEngine: KotlinJsr223JvmScriptEngineBase + try { + scriptEngine = this.scriptEngine + compile = scriptEngine.compile(wrappedCode) + bindings = scriptEngine.getBindings(ScriptContext.ENGINE_SCOPE) + return kotlinx.coroutines.runBlocking { compile.eval(bindings) } + } catch (ex: ScriptException) { + throw wrapException(ex, wrappedCode, code) + } catch (ex: Throwable) { + throw CodingActor.FailedToImplementException( + cause = ex, + language = "Kotlin", + code = code, + ) } + } - protected open fun wrapException( - cause: ScriptException, - wrappedCode: String, - code: String - ): CodingActor.FailedToImplementException { - var lineNumber = cause.lineNumber - var column = cause.columnNumber - if (lineNumber == -1 && column == -1) { - val match = Regex("\\(.*:(\\d+):(\\d+)\\)").find(cause.message ?: "") - if (match != null) { - lineNumber = match.groupValues[1].toInt() - column = match.groupValues[2].toInt() - } - } - return CodingActor.FailedToImplementException( - cause = cause, - message = errorMessage( - code = wrappedCode, - line = lineNumber, - column = column, - message = cause.message ?: "" - ), - language = "Kotlin", - code = code, - ) + protected open fun wrapException( + cause: ScriptException, + wrappedCode: String, + code: String + ): CodingActor.FailedToImplementException { + var lineNumber = cause.lineNumber + var column = cause.columnNumber + if (lineNumber == -1 && column == -1) { + val match = Regex("\\(.*:(\\d+):(\\d+)\\)").find(cause.message ?: "") + if (match != null) { + lineNumber = match.groupValues[1].toInt() + column = match.groupValues[2].toInt() + } } + return CodingActor.FailedToImplementException( + cause = cause, + message = errorMessage( + code = wrappedCode, + line = lineNumber, + column = column, + message = cause.message ?: "" + ), + language = "Kotlin", + code = code, + ) + } - override fun wrapCode(code: String): String { - val out = ArrayList() - val (imports, otherCode) = code.split("\n").partition { it.trim().startsWith("import ") } - out.addAll(imports) - out.addAll(otherCode) - return out.joinToString("\n") - } + override fun wrapCode(code: String): String { + val out = ArrayList() + val (imports, otherCode) = code.split("\n").partition { it.trim().startsWith("import ") } + out.addAll(imports) + out.addAll(otherCode) + return out.joinToString("\n") + } - companion object { - private val log = LoggerFactory.getLogger(KotlinInterpreter::class.java) + companion object { + private val log = LoggerFactory.getLogger(KotlinInterpreter::class.java) - fun errorMessage( - code: String, - line: Int, - column: Int, - message: String - ) = """ + fun errorMessage( + code: String, + line: Int, + column: Int, + message: String + ) = """ |```text |$message at line ${line} column ${column} | ${if (line < 0) "" else code.split("\n")[line - 1]} @@ -151,8 +151,8 @@ open class KotlinInterpreter( |``` """.trimMargin().trim() - // TODO: Make this threadlocal with wrapper methods - var classLoader: ClassLoader? = KotlinInterpreter::class.java.classLoader + // TODO: Make this threadlocal with wrapper methods + var classLoader: ClassLoader? = KotlinInterpreter::class.java.classLoader - } + } } \ No newline at end of file diff --git a/kotlin/src/test/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreterTest.kt b/kotlin/src/test/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreterTest.kt index 6ed7a84b..413e2e55 100644 --- a/kotlin/src/test/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreterTest.kt +++ b/kotlin/src/test/kotlin/com/simiacryptus/skyenet/kotlin/KotlinInterpreterTest.kt @@ -10,39 +10,39 @@ import org.junit.jupiter.api.Test class KotlinInterpreterTest : InterpreterTestBase() { - override fun newInterpreter(map: Map) = KotlinInterpreter(map) - - @Test - fun `test run with kotlin println`() { - val interpreter = newInterpreter(mapOf()) - val result = interpreter.run("""println("Hello World")""") - Assertions.assertEquals(null, result) - } - - @Test - fun `test validate with kotlin println`() { - val interpreter = newInterpreter(mapOf()) - val result = interpreter.validate("""println("Hello World")""") - Assertions.assertEquals(null, result) - } - - @Test - fun `test validate with invalid function`() { - val interpreter = newInterpreter(mapOf()) - @Language("kotlin") val code = """ + override fun newInterpreter(map: Map) = KotlinInterpreter(map) + + @Test + fun `test run with kotlin println`() { + val interpreter = newInterpreter(mapOf()) + val result = interpreter.run("""println("Hello World")""") + Assertions.assertEquals(null, result) + } + + @Test + fun `test validate with kotlin println`() { + val interpreter = newInterpreter(mapOf()) + val result = interpreter.validate("""println("Hello World")""") + Assertions.assertEquals(null, result) + } + + @Test + fun `test validate with invalid function`() { + val interpreter = newInterpreter(mapOf()) + @Language("kotlin") val code = """ fun foo() { functionNotDefined() } """.trimIndent() - // This should fail because functionNotDefined is not defined... - val result = interpreter.validate(code) - Assertions.assertTrue(result is CodingActor.FailedToImplementException) - try { - interpreter.run(code) - Assertions.fail("Expected exception") - } catch (e: Exception) { - Assertions.assertTrue(e is CodingActor.FailedToImplementException) - } + // This should fail because functionNotDefined is not defined... + val result = interpreter.validate(code) + Assertions.assertTrue(result is CodingActor.FailedToImplementException) + try { + interpreter.run(code) + Assertions.fail("Expected exception") + } catch (e: Exception) { + Assertions.assertTrue(e is CodingActor.FailedToImplementException) } + } } \ No newline at end of file diff --git a/scala/build.gradle.kts b/scala/build.gradle.kts index be6dcebd..d3f37a6e 100644 --- a/scala/build.gradle.kts +++ b/scala/build.gradle.kts @@ -6,136 +6,136 @@ group = properties("libraryGroup") version = properties("libraryVersion") plugins { - java - `java-library` - `scala` - `maven-publish` - id("org.jetbrains.kotlin.jvm") version "2.0.20" - id("signing") + java + `java-library` + `scala` + `maven-publish` + id("org.jetbrains.kotlin.jvm") version "2.0.20" + id("signing") } repositories { - mavenCentral { - metadataSources { - mavenPom() - artifact() - } + mavenCentral { + metadataSources { + mavenPom() + artifact() } + } } java { - toolchain { - languageVersion.set(JavaLanguageVersion.of(11)) - } + toolchain { + languageVersion.set(JavaLanguageVersion.of(11)) + } } val scala_version = "2.13.12" dependencies { - implementation(project(":core")) - implementation(group = "org.scala-lang", name = "scala-library", version = scala_version) - implementation(group = "org.scala-lang", name = "scala-compiler", version = scala_version) - implementation(group = "org.scala-lang", name = "scala-reflect", version = scala_version) - implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.16") + implementation(project(":core")) + implementation(group = "org.scala-lang", name = "scala-library", version = scala_version) + implementation(group = "org.scala-lang", name = "scala-compiler", version = scala_version) + implementation(group = "org.scala-lang", name = "scala-reflect", version = scala_version) + implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.16") - testImplementation(group = "org.slf4j", name = "slf4j-simple", version = "2.0.16") - testImplementation(group = "org.junit.jupiter", name = "junit-jupiter", version = "5.8.1") - testImplementation(group = "org.scala-lang.modules", name = "scala-java8-compat_2.13", version = "0.9.1") + testImplementation(group = "org.slf4j", name = "slf4j-simple", version = "2.0.16") + testImplementation(group = "org.junit.jupiter", name = "junit-jupiter", version = "5.8.1") + testImplementation(group = "org.scala-lang.modules", name = "scala-java8-compat_2.13", version = "0.9.1") } tasks { - test { - useJUnitPlatform() - systemProperty("surefire.useManifestOnlyJar", "false") - testLogging { - events(TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) - exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL - } - jvmArgs( - "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", - "--add-opens", "java.base/java.util=ALL-UNNAMED", - "--add-opens", "java.base/java.lang=ALL-UNNAMED" - ) + test { + useJUnitPlatform() + systemProperty("surefire.useManifestOnlyJar", "false") + testLogging { + events(TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) + exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL } + jvmArgs( + "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", + "--add-opens", "java.base/java.util=ALL-UNNAMED", + "--add-opens", "java.base/java.lang=ALL-UNNAMED" + ) + } } val javadocJar by tasks.registering(Jar::class) { - archiveClassifier.set("javadoc") - from(tasks.javadoc) + archiveClassifier.set("javadoc") + from(tasks.javadoc) } val sourcesJar by tasks.registering(Jar::class) { - archiveClassifier.set("sources") - from(sourceSets.main.get().allSource) + archiveClassifier.set("sources") + from(sourceSets.main.get().allSource) } publishing { - publications { - create("mavenJava") { - artifactId = "scala" - from(components["java"]) - artifact(sourcesJar.get()) - artifact(javadocJar.get()) - versionMapping { - usage("java-api") { - fromResolutionOf("runtimeClasspath") - } - usage("java-runtime") { - fromResolutionResult() - } - } - pom { - name.set("SkyeNet Scala Interpreter") - description.set("A very helpful puppy") - url.set("https://github.com/SimiaCryptus/SkyeNet") - licenses { - license { - name.set("The Apache License, Version 2.0") - url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") - } - } - developers { - developer { - id.set("acharneski") - name.set("Andrew Charneski") - email.set("acharneski@gmail.com") - } - } - scm { - connection.set("scm:git:git://git@github.com/SimiaCryptus/SkyeNet.git") - developerConnection.set("scm:git:ssh://git@github.com/SimiaCryptus/SkyeNet.git") - url.set("https://github.com/SimiaCryptus/SkyeNet") - } - } + publications { + create("mavenJava") { + artifactId = "scala" + from(components["java"]) + artifact(sourcesJar.get()) + artifact(javadocJar.get()) + versionMapping { + usage("java-api") { + fromResolutionOf("runtimeClasspath") } - } - repositories { - maven { - val releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" - val snapshotsRepoUrl = "https://oss.sonatype.org/mask/repositories/snapshots" - url = URI(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) - credentials { - username = System.getenv("OSSRH_USERNAME") ?: System.getProperty("ossrhUsername") - ?: properties("ossrhUsername") - password = System.getenv("OSSRH_PASSWORD") ?: System.getProperty("ossrhPassword") - ?: properties("ossrhPassword") - } + usage("java-runtime") { + fromResolutionResult() } - } - if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) afterEvaluate { - signing { - sign(publications["mavenJava"]) + } + pom { + name.set("SkyeNet Scala Interpreter") + description.set("A very helpful puppy") + url.set("https://github.com/SimiaCryptus/SkyeNet") + licenses { + license { + name.set("The Apache License, Version 2.0") + url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") + } } + developers { + developer { + id.set("acharneski") + name.set("Andrew Charneski") + email.set("acharneski@gmail.com") + } + } + scm { + connection.set("scm:git:git://git@github.com/SimiaCryptus/SkyeNet.git") + developerConnection.set("scm:git:ssh://git@github.com/SimiaCryptus/SkyeNet.git") + url.set("https://github.com/SimiaCryptus/SkyeNet") + } + } + } + } + repositories { + maven { + val releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" + val snapshotsRepoUrl = "https://oss.sonatype.org/mask/repositories/snapshots" + url = URI(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) + credentials { + username = System.getenv("OSSRH_USERNAME") ?: System.getProperty("ossrhUsername") + ?: properties("ossrhUsername") + password = System.getenv("OSSRH_PASSWORD") ?: System.getProperty("ossrhPassword") + ?: properties("ossrhPassword") + } } + } + if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) afterEvaluate { + signing { + sign(publications["mavenJava"]) + } + } } if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) { - apply() - configure { - useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) - sign(configurations.archives.get()) - } + apply() + configure { + useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) + sign(configurations.archives.get()) + } } diff --git a/webui/build.gradle.kts b/webui/build.gradle.kts index 3f0ec4ff..b87647d6 100644 --- a/webui/build.gradle.kts +++ b/webui/build.gradle.kts @@ -7,26 +7,26 @@ group = properties("libraryGroup") version = properties("libraryVersion") plugins { - java - `java-library` - id("org.jetbrains.kotlin.jvm") version "2.0.20" - `maven-publish` - id("signing") - id("io.freefair.sass-base") version "8.10.2" - id("io.freefair.sass-java") version "8.10.2" + java + `java-library` + id("org.jetbrains.kotlin.jvm") version "2.0.20" + `maven-publish` + id("signing") + id("io.freefair.sass-base") version "8.10.2" + id("io.freefair.sass-java") version "8.10.2" } repositories { - mavenCentral { - metadataSources { - mavenPom() - artifact() - } + mavenCentral { + metadataSources { + mavenPom() + artifact() } + } } kotlin { - jvmToolchain(11) + jvmToolchain(11) // jvmToolchain(17) } @@ -36,187 +36,187 @@ val jackson_version = "2.17.2" dependencies { - implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.1.12") { - exclude(group = "org.slf4j") - } + implementation(group = "com.simiacryptus", name = "jo-penai", version = "1.1.12") { + exclude(group = "org.slf4j") + } - implementation(project(":core")) - implementation(project(":kotlin")) + implementation(project(":core")) + implementation(project(":kotlin")) - implementation("org.apache.pdfbox:pdfbox:2.0.27") - compileOnly("org.seleniumhq.selenium:selenium-chrome-driver:4.16.1") - implementation("org.jsoup:jsoup:1.18.1") + implementation("org.apache.pdfbox:pdfbox:2.0.27") + compileOnly("org.seleniumhq.selenium:selenium-chrome-driver:4.16.1") + implementation("org.jsoup:jsoup:1.18.1") - implementation("com.google.zxing:core:3.5.3") - implementation("com.google.zxing:javase:3.5.3") + implementation("com.google.zxing:core:3.5.3") + implementation("com.google.zxing:javase:3.5.3") - compileOnly(group = "software.amazon.awssdk", name = "aws-sdk-java", version = "2.27.23") + compileOnly(group = "software.amazon.awssdk", name = "aws-sdk-java", version = "2.27.23") - compileOnly("org.openapitools:openapi-generator:7.3.0") { - exclude(group = "org.slf4j") - } - compileOnly("org.openapitools:openapi-generator-cli:7.3.0") { - exclude(group = "org.slf4j") - } - testRuntimeOnly("org.openapitools:openapi-generator-cli:7.3.0") + compileOnly("org.openapitools:openapi-generator:7.3.0") { + exclude(group = "org.slf4j") + } + compileOnly("org.openapitools:openapi-generator-cli:7.3.0") { + exclude(group = "org.slf4j") + } + testRuntimeOnly("org.openapitools:openapi-generator-cli:7.3.0") - implementation(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version) - implementation(group = "org.eclipse.jetty", name = "jetty-servlet", version = jetty_version) - implementation(group = "org.eclipse.jetty", name = "jetty-annotations", version = jetty_version) - implementation(group = "org.eclipse.jetty.websocket", name = "websocket-jetty-server", version = jetty_version) - implementation(group = "org.eclipse.jetty.websocket", name = "websocket-jetty-client", version = jetty_version) - implementation(group = "org.eclipse.jetty.websocket", name = "websocket-servlet", version = jetty_version) - implementation(group = "org.eclipse.jetty", name = "jetty-webapp", version = jetty_version) + implementation(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version) + implementation(group = "org.eclipse.jetty", name = "jetty-servlet", version = jetty_version) + implementation(group = "org.eclipse.jetty", name = "jetty-annotations", version = jetty_version) + implementation(group = "org.eclipse.jetty.websocket", name = "websocket-jetty-server", version = jetty_version) + implementation(group = "org.eclipse.jetty.websocket", name = "websocket-jetty-client", version = jetty_version) + implementation(group = "org.eclipse.jetty.websocket", name = "websocket-servlet", version = jetty_version) + implementation(group = "org.eclipse.jetty", name = "jetty-webapp", version = jetty_version) - implementation(group = "com.vladsch.flexmark", name = "flexmark", version = "0.64.8") - implementation(group = "com.vladsch.flexmark", name = "flexmark-ext-tables", version = "0.64.8") + implementation(group = "com.vladsch.flexmark", name = "flexmark", version = "0.64.8") + implementation(group = "com.vladsch.flexmark", name = "flexmark-ext-tables", version = "0.64.8") - compileOnly(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.8.0-RC") + compileOnly(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.8.0-RC") - compileOnly(kotlin("stdlib")) - testImplementation(kotlin("stdlib")) + compileOnly(kotlin("stdlib")) + testImplementation(kotlin("stdlib")) - testImplementation(project(":groovy")) - testImplementation(project(":kotlin")) - testImplementation(project(":scala")) + testImplementation(project(":groovy")) + testImplementation(project(":kotlin")) + testImplementation(project(":scala")) - implementation(group = "org.apache.httpcomponents.client5", name = "httpclient5", version = "5.3.1") { - exclude(group = "org.slf4j", module = "slf4j-api") - } + implementation(group = "org.apache.httpcomponents.client5", name = "httpclient5", version = "5.3.1") { + exclude(group = "org.slf4j", module = "slf4j-api") + } - implementation(group = "com.fasterxml.jackson.core", name = "jackson-core", version = jackson_version) - implementation(group = "com.fasterxml.jackson.core", name = "jackson-databind", version = jackson_version) - implementation(group = "com.fasterxml.jackson.core", name = "jackson-annotations", version = jackson_version) - implementation(group = "com.fasterxml.jackson.module", name = "jackson-module-kotlin", version = jackson_version) + implementation(group = "com.fasterxml.jackson.core", name = "jackson-core", version = jackson_version) + implementation(group = "com.fasterxml.jackson.core", name = "jackson-databind", version = jackson_version) + implementation(group = "com.fasterxml.jackson.core", name = "jackson-annotations", version = jackson_version) + implementation(group = "com.fasterxml.jackson.module", name = "jackson-module-kotlin", version = jackson_version) - implementation(group = "com.google.guava", name = "guava", version = "32.1.3-jre") + implementation(group = "com.google.guava", name = "guava", version = "32.1.3-jre") // implementation(group = "com.google.apis", name = "google-api-services-customsearch", version = "v1-rev20230702-2.0.0") - compileOnly(group = "com.google.api-client", name = "google-api-client", version = "1.35.2" /*"1.35.2"*/) - compileOnly(group = "com.google.oauth-client", name = "google-oauth-client-jetty", version = "1.34.1") - compileOnly(group = "com.google.apis", name = "google-api-services-oauth2", version = "v2-rev157-1.25.0") - implementation(group = "commons-io", name = "commons-io", version = "2.15.0") - implementation(group = "commons-codec", name = "commons-codec", version = "1.16.0") - - implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.16") - runtimeOnly(group = "ch.qos.logback", name = "logback-classic", version = "1.5.8") - runtimeOnly(group = "ch.qos.logback", name = "logback-core", version = "1.5.8") - - testImplementation(kotlin("script-runtime")) - testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") - testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1") + compileOnly(group = "com.google.api-client", name = "google-api-client", version = "1.35.2" /*"1.35.2"*/) + compileOnly(group = "com.google.oauth-client", name = "google-oauth-client-jetty", version = "1.34.1") + compileOnly(group = "com.google.apis", name = "google-api-services-oauth2", version = "v2-rev157-1.25.0") + implementation(group = "commons-io", name = "commons-io", version = "2.15.0") + implementation(group = "commons-codec", name = "commons-codec", version = "1.16.0") + + implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.16") + runtimeOnly(group = "ch.qos.logback", name = "logback-classic", version = "1.5.8") + runtimeOnly(group = "ch.qos.logback", name = "logback-core", version = "1.5.8") + + testImplementation(kotlin("script-runtime")) + testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = "5.10.1") + testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = "5.10.1") } sass { - omitSourceMapUrl.set(false) - outputStyle.set(OutputStyle.EXPANDED) - sourceMapContents.set(false) - sourceMapEmbed.set(false) - sourceMapEnabled.set(true) + omitSourceMapUrl.set(false) + outputStyle.set(OutputStyle.EXPANDED) + sourceMapContents.set(false) + sourceMapEmbed.set(false) + sourceMapEnabled.set(true) } tasks { - compileKotlin { - compilerOptions { - javaParameters.set(true) - } + compileKotlin { + compilerOptions { + javaParameters.set(true) } - compileTestKotlin { - compilerOptions { - javaParameters.set(true) - } + } + compileTestKotlin { + compilerOptions { + javaParameters.set(true) } - test { - useJUnitPlatform() - systemProperty("surefire.useManifestOnlyJar", "false") - testLogging { - events(TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) - exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL - } - jvmArgs( - "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", - "--add-opens", "java.base/java.util=ALL-UNNAMED", - "--add-opens", "java.base/java.lang=ALL-UNNAMED" - ) + } + test { + useJUnitPlatform() + systemProperty("surefire.useManifestOnlyJar", "false") + testLogging { + events(TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) + exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL } + jvmArgs( + "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", + "--add-opens", "java.base/java.util=ALL-UNNAMED", + "--add-opens", "java.base/java.lang=ALL-UNNAMED" + ) + } } val javadocJar by tasks.registering(Jar::class) { - archiveClassifier.set("javadoc") - from(tasks.javadoc) + archiveClassifier.set("javadoc") + from(tasks.javadoc) } val sourcesJar by tasks.registering(Jar::class) { - archiveClassifier.set("sources") - from(sourceSets.main.get().allSource) + archiveClassifier.set("sources") + from(sourceSets.main.get().allSource) } publishing { - publications { - create("mavenJava") { - artifactId = "webui" - from(components["java"]) - artifact(sourcesJar.get()) - artifact(javadocJar.get()) - versionMapping { - usage("java-api") { - fromResolutionOf("runtimeClasspath") - } - usage("java-runtime") { - fromResolutionResult() - } - } - pom { - name.set("SkyeNet Web Interface") - description.set("A very helpful puppy") - url.set("https://github.com/SimiaCryptus/SkyeNet") - licenses { - license { - name.set("The Apache License, Version 2.0") - url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") - } - } - developers { - developer { - id.set("acharneski") - name.set("Andrew Charneski") - email.set("acharneski@gmail.com") - } - } - scm { - connection.set("scm:git:git://git@github.com/SimiaCryptus/SkyeNet.git") - developerConnection.set("scm:git:ssh://git@github.com/SimiaCryptus/SkyeNet.git") - url.set("https://github.com/SimiaCryptus/SkyeNet") - } - } + publications { + create("mavenJava") { + artifactId = "webui" + from(components["java"]) + artifact(sourcesJar.get()) + artifact(javadocJar.get()) + versionMapping { + usage("java-api") { + fromResolutionOf("runtimeClasspath") } - } - repositories { - maven { - val releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" - val snapshotsRepoUrl = "https://oss.sonatype.org/mask/repositories/snapshots" - url = URI(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) - credentials { - username = System.getenv("OSSRH_USERNAME") ?: System.getProperty("ossrhUsername") - ?: properties("ossrhUsername") - password = System.getenv("OSSRH_PASSWORD") ?: System.getProperty("ossrhPassword") - ?: properties("ossrhPassword") - } + usage("java-runtime") { + fromResolutionResult() } - } - if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) afterEvaluate { - signing { - sign(publications["mavenJava"]) + } + pom { + name.set("SkyeNet Web Interface") + description.set("A very helpful puppy") + url.set("https://github.com/SimiaCryptus/SkyeNet") + licenses { + license { + name.set("The Apache License, Version 2.0") + url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") + } + } + developers { + developer { + id.set("acharneski") + name.set("Andrew Charneski") + email.set("acharneski@gmail.com") + } + } + scm { + connection.set("scm:git:git://git@github.com/SimiaCryptus/SkyeNet.git") + developerConnection.set("scm:git:ssh://git@github.com/SimiaCryptus/SkyeNet.git") + url.set("https://github.com/SimiaCryptus/SkyeNet") } + } } + } + repositories { + maven { + val releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" + val snapshotsRepoUrl = "https://oss.sonatype.org/mask/repositories/snapshots" + url = URI(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) + credentials { + username = System.getenv("OSSRH_USERNAME") ?: System.getProperty("ossrhUsername") + ?: properties("ossrhUsername") + password = System.getenv("OSSRH_PASSWORD") ?: System.getProperty("ossrhPassword") + ?: properties("ossrhPassword") + } + } + } + if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) afterEvaluate { + signing { + sign(publications["mavenJava"]) + } + } } if (System.getenv("GPG_PRIVATE_KEY") != null && System.getenv("GPG_PASSPHRASE") != null) { - apply() - configure { - useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) - sign(configurations.archives.get()) - } + apply() + configure { + useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) + sign(configurations.archives.get()) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyDiffLinks.kt b/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyDiffLinks.kt index 7123ff9e..3e7f550b 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyDiffLinks.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyDiffLinks.kt @@ -8,100 +8,100 @@ import com.simiacryptus.skyenet.webui.session.SessionTask import com.simiacryptus.skyenet.webui.session.SocketManagerBase fun SocketManagerBase.addApplyDiffLinks( - code: () -> String, - response: String, - handle: (String) -> Unit, - task: SessionTask, - ui: ApplicationInterface, + code: () -> String, + response: String, + handle: (String) -> Unit, + task: SessionTask, + ui: ApplicationInterface, ): String { - val patch = { code: String, diff: String -> - val isCurlyBalanced = FileValidationUtils.isCurlyBalanced(code) - val isSquareBalanced = FileValidationUtils.isSquareBalanced(code) - val isParenthesisBalanced = FileValidationUtils.isParenthesisBalanced(code) - val isQuoteBalanced = FileValidationUtils.isQuoteBalanced(code) - val isSingleQuoteBalanced = FileValidationUtils.isSingleQuoteBalanced(code) - var newCode = IterativePatchUtil.applyPatch(code, diff) - newCode = newCode.replace("\r", "") - val isCurlyBalancedNew = FileValidationUtils.isCurlyBalanced(newCode) - val isSquareBalancedNew = FileValidationUtils.isSquareBalanced(newCode) - val isParenthesisBalancedNew = FileValidationUtils.isParenthesisBalanced(newCode) - val isQuoteBalancedNew = FileValidationUtils.isQuoteBalanced(newCode) - val isSingleQuoteBalancedNew = FileValidationUtils.isSingleQuoteBalanced(newCode) - val isError = ((isCurlyBalanced && !isCurlyBalancedNew) || - (isSquareBalanced && !isSquareBalancedNew) || - (isParenthesisBalanced && !isParenthesisBalancedNew) || - (isQuoteBalanced && !isQuoteBalancedNew) || - (isSingleQuoteBalanced && !isSingleQuoteBalancedNew)) - PatchResult(newCode, !isError) - } + val patch = { code: String, diff: String -> + val isCurlyBalanced = FileValidationUtils.isCurlyBalanced(code) + val isSquareBalanced = FileValidationUtils.isSquareBalanced(code) + val isParenthesisBalanced = FileValidationUtils.isParenthesisBalanced(code) + val isQuoteBalanced = FileValidationUtils.isQuoteBalanced(code) + val isSingleQuoteBalanced = FileValidationUtils.isSingleQuoteBalanced(code) + var newCode = IterativePatchUtil.applyPatch(code, diff) + newCode = newCode.replace("\r", "") + val isCurlyBalancedNew = FileValidationUtils.isCurlyBalanced(newCode) + val isSquareBalancedNew = FileValidationUtils.isSquareBalanced(newCode) + val isParenthesisBalancedNew = FileValidationUtils.isParenthesisBalanced(newCode) + val isQuoteBalancedNew = FileValidationUtils.isQuoteBalanced(newCode) + val isSingleQuoteBalancedNew = FileValidationUtils.isSingleQuoteBalanced(newCode) + val isError = ((isCurlyBalanced && !isCurlyBalancedNew) || + (isSquareBalanced && !isSquareBalancedNew) || + (isParenthesisBalanced && !isParenthesisBalancedNew) || + (isQuoteBalanced && !isQuoteBalancedNew) || + (isSingleQuoteBalanced && !isSingleQuoteBalancedNew)) + PatchResult(newCode, !isError) + } - val diffPattern = """(?s)(? - val diffVal: String = diffBlock.groupValues[1] - val buttons = ui.newTask(false) - lateinit var hrefLink: StringBuilder - var reverseHrefLink: StringBuilder? = null - hrefLink = buttons.complete(hrefLink("Apply Diff", classname = "href-link cmd-button") { - try { - val newCode = patch(code(), diffVal) - handle(newCode.newCode) - hrefLink.set("""

Diff Applied
""") - buttons.complete() - reverseHrefLink?.clear() - } catch (e: Throwable) { - hrefLink.append("""
Error: ${e.message}
""") - buttons.complete() - task.error(ui, e) - } - })!! - val patch = patch(code(), diffVal).newCode - val test1 = IterativePatchUtil.generatePatch(code().replace("\r", ""), patch) - val patchRev = patch( - code().lines().reversed().joinToString("\n"), - diffVal.lines().reversed().joinToString("\n") - ).newCode - if (patchRev != patch) { - reverseHrefLink = buttons.complete(hrefLink("(Bottom to Top)", classname = "href-link cmd-button") { - try { - val reversedCode = code().lines().reversed().joinToString("\n") - val reversedDiff = diffVal.lines().reversed().joinToString("\n") - val newReversedCode = patch(reversedCode, reversedDiff).newCode - val newCode = newReversedCode.lines().reversed().joinToString("\n") - handle(newCode) - reverseHrefLink!!.set("""
Diff Applied (Bottom to Top)
""") - buttons.complete() - hrefLink.clear() - } catch (e: Throwable) { - task.error(ui, e) - } - })!! - } - val test2 = DiffUtil.formatDiff( - DiffUtil.generateDiff( - code().lines(), - patchRev.lines().reversed() - ) - ) - val newValue = if (patchRev == patch) { - displayMapInTabs( - mapOf( - "Diff" to renderMarkdown("```diff\n$diffVal\n```", ui = ui, tabs = true), - "Verify" to renderMarkdown("```diff\n$test1\n```", ui = ui, tabs = true), - ), ui = ui, split = true - ) + "\n" + buttons.placeholder - } else { - displayMapInTabs( - mapOf( - "Diff" to renderMarkdown("```diff\n$diffVal\n```", ui = ui, tabs = true), - "Verify" to renderMarkdown("```diff\n$test1\n```", ui = ui, tabs = true), - "Reverse" to renderMarkdown("```diff\n$test2\n```", ui = ui, tabs = true), - ), ui = ui, split = true - ) + "\n" + buttons.placeholder + val diffPattern = """(?s)(? + val diffVal: String = diffBlock.groupValues[1] + val buttons = ui.newTask(false) + lateinit var hrefLink: StringBuilder + var reverseHrefLink: StringBuilder? = null + hrefLink = buttons.complete(hrefLink("Apply Diff", classname = "href-link cmd-button") { + try { + val newCode = patch(code(), diffVal) + handle(newCode.newCode) + hrefLink.set("""
Diff Applied
""") + buttons.complete() + reverseHrefLink?.clear() + } catch (e: Throwable) { + hrefLink.append("""
Error: ${e.message}
""") + buttons.complete() + task.error(ui, e) + } + })!! + val patch = patch(code(), diffVal).newCode + val test1 = IterativePatchUtil.generatePatch(code().replace("\r", ""), patch) + val patchRev = patch( + code().lines().reversed().joinToString("\n"), + diffVal.lines().reversed().joinToString("\n") + ).newCode + if (patchRev != patch) { + reverseHrefLink = buttons.complete(hrefLink("(Bottom to Top)", classname = "href-link cmd-button") { + try { + val reversedCode = code().lines().reversed().joinToString("\n") + val reversedDiff = diffVal.lines().reversed().joinToString("\n") + val newReversedCode = patch(reversedCode, reversedDiff).newCode + val newCode = newReversedCode.lines().reversed().joinToString("\n") + handle(newCode) + reverseHrefLink!!.set("""
Diff Applied (Bottom to Top)
""") + buttons.complete() + hrefLink.clear() + } catch (e: Throwable) { + task.error(ui, e) } - markdown.replace(diffBlock.value, newValue) + })!! + } + val test2 = DiffUtil.formatDiff( + DiffUtil.generateDiff( + code().lines(), + patchRev.lines().reversed() + ) + ) + val newValue = if (patchRev == patch) { + displayMapInTabs( + mapOf( + "Diff" to renderMarkdown("```diff\n$diffVal\n```", ui = ui, tabs = true), + "Verify" to renderMarkdown("```diff\n$test1\n```", ui = ui, tabs = true), + ), ui = ui, split = true + ) + "\n" + buttons.placeholder + } else { + displayMapInTabs( + mapOf( + "Diff" to renderMarkdown("```diff\n$diffVal\n```", ui = ui, tabs = true), + "Verify" to renderMarkdown("```diff\n$test1\n```", ui = ui, tabs = true), + "Reverse" to renderMarkdown("```diff\n$test2\n```", ui = ui, tabs = true), + ), ui = ui, split = true + ) + "\n" + buttons.placeholder } - return withLinks + markdown.replace(diffBlock.value, newValue) + } + return withLinks } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyFileDiffLinks.kt b/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyFileDiffLinks.kt index 60bf87e6..d0a3d152 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyFileDiffLinks.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/AddApplyFileDiffLinks.kt @@ -24,158 +24,162 @@ private fun String.reverseLines(): String = lines().reversed().joinToString("\n" // Main function to add apply file diff links to the response fun SocketManagerBase.addApplyFileDiffLinks( - root: Path, - response: String, - handle: (Map) -> Unit = {}, - ui: ApplicationInterface, - api: API, - shouldAutoApply: (Path) -> Boolean = { false }, - model: ChatModel? = null, + root: Path, + response: String, + handle: (Map) -> Unit = {}, + ui: ApplicationInterface, + api: API, + shouldAutoApply: (Path) -> Boolean = { false }, + model: ChatModel? = null, ): String { - // Check if there's an unclosed code block and close it if necessary - val initiator = "(?s)```\\w*\n".toRegex() - if (response.contains(initiator) && !response.split(initiator, 2)[1].contains("\n```(?![^\n])".toRegex())) { - // Single diff block without the closing ``` due to LLM limitations... add it back and recurse - return addApplyFileDiffLinks( - root, - response + "\n```\n", - handle, - ui, - api, - model = model, - ) + // Check if there's an unclosed code block and close it if necessary + val initiator = "(?s)```\\w*\n".toRegex() + if (response.contains(initiator) && !response.split(initiator, 2)[1].contains("\n```(?![^\n])".toRegex())) { + // Single diff block without the closing ``` due to LLM limitations... add it back and recurse + return addApplyFileDiffLinks( + root, + response + "\n```\n", + handle, + ui, + api, + model = model, + ) + } + val headerPattern = """(? + try { + val header = headers.lastOrNull { it.first.last <= block.range.first } + if (header == null) { + return@filter false + } + val filename = resolve(root, header.second) + !root.resolve(filename).toFile().exists() + } catch (e: Throwable) { + log.info("Error processing code block", e) + false } - val headerPattern = """(? - try { - val header = headers.lastOrNull { it.first.last <= block.range.first } - if (header == null) { - return@filter false - } - val filename = resolve(root, header.second) - !root.resolve(filename).toFile().exists() - } catch (e: Throwable) { - log.info("Error processing code block", e) - false - } - }.map { it.range to it }.toList() - val patchBlocks = findAll.filter { block -> - try { - val header = headers.lastOrNull { it.first.last <= block.range.first } - if (header == null) { - return@filter false - } - val filename = resolve(root, header.second) - root.resolve(filename).toFile().exists() - } catch (e: Throwable) { - log.info("Error processing code block", e) - false - } - }.map { it.range to it }.toList() - - // Get - val changes = patchBlocks.mapIndexed { index, it -> - PatchOrCode( - id = "patch_" + index.toString(), - type = "patch", - data = it.second.groupValues[2] - ) - } + codeblocks.mapIndexed { index, it -> - PatchOrCode( - id = "code_" + index.toString(), - type = "code", - data = it.second.groupValues[2] - ) + }.map { it.range to it }.toList() + val patchBlocks = findAll.filter { block -> + try { + val header = headers.lastOrNull { it.first.last <= block.range.first } + if (header == null) { + return@filter false + } + val filename = resolve(root, header.second) + root.resolve(filename).toFile().exists() + } catch (e: Throwable) { + log.info("Error processing code block", e) + false } - val corrections = if(model == null) null else try { - ParsedActor( - resultClass = CorrectedPatchAndCodeList::class.java, - exampleInstance = CorrectedPatchAndCodeList(listOf( - CorrectedPatchOrCode("patch_0", "src/utils/exampleUtils.js"), - CorrectedPatchOrCode("code_0", "src/utils/exampleUtils.js"), - CorrectedPatchOrCode("patch_1", "tests/exampleUtils.test.js"), - CorrectedPatchOrCode("code_1", "tests/exampleUtils.test.js"), - )), - prompt = """ + }.map { it.range to it }.toList() + + // Get + val changes = patchBlocks.mapIndexed { index, it -> + PatchOrCode( + id = "patch_" + index.toString(), + type = "patch", + data = it.second.groupValues[2] + ) + } + codeblocks.mapIndexed { index, it -> + PatchOrCode( + id = "code_" + index.toString(), + type = "code", + data = it.second.groupValues[2] + ) + } + val corrections = if (model == null) null else try { + ParsedActor( + resultClass = CorrectedPatchAndCodeList::class.java, + exampleInstance = CorrectedPatchAndCodeList( + listOf( + CorrectedPatchOrCode("patch_0", "src/utils/exampleUtils.js"), + CorrectedPatchOrCode("code_0", "src/utils/exampleUtils.js"), + CorrectedPatchOrCode("patch_1", "tests/exampleUtils.test.js"), + CorrectedPatchOrCode("code_1", "tests/exampleUtils.test.js"), + ) + ), + prompt = """ Review and correct the file path assignments for the following patches and code blocks. """.trimIndent(), - model = model, - temperature = 0.0, - parsingModel = model, - ).getParser(api).apply(listOf( - response, - JsonUtil.toJson( - PatchAndCodeList( - changes = changes - ) - ) - ).joinToString("\n\n")).changes?.associateBy { it.id }?.mapValues { it.value.filename } ?: emptyMap() - } catch (e: Throwable) { - log.error("Error consulting AI for corrections", e) - null - } - - // Process diff blocks and add patch links - val withPatchLinks: String = patchBlocks.foldIndexed(response) { index, markdown, diffBlock -> - val value = diffBlock.second.groupValues[2].trim() - var header = headers.lastOrNull { it.first.last < diffBlock.first.first }?.second ?: "Unknown" - header = corrections?.get("patch_$index") ?: header - val filename = resolve(root, header) - val newValue = renderDiffBlock(root, filename, value, handle, ui, api, shouldAutoApply) - markdown.replace(diffBlock.second.value, newValue) - } - // Process code blocks and add save links - val withSaveLinks = codeblocks.foldIndexed(withPatchLinks) { index, markdown, codeBlock -> - val lang = codeBlock.second.groupValues[1] - val value = codeBlock.second.groupValues[2].trim() - var header = headers.lastOrNull { it.first.last < codeBlock.first.first }?.second - header = corrections?.get("code_$index") ?: header - val newMarkdown = renderNewFile(header, root, ui, shouldAutoApply, value, handle, lang) - markdown.replace(codeBlock.second.value, newMarkdown) - } - return withSaveLinks + model = model, + temperature = 0.0, + parsingModel = model, + ).getParser(api).apply( + listOf( + response, + JsonUtil.toJson( + PatchAndCodeList( + changes = changes + ) + ) + ).joinToString("\n\n") + ).changes?.associateBy { it.id }?.mapValues { it.value.filename } ?: emptyMap() + } catch (e: Throwable) { + log.error("Error consulting AI for corrections", e) + null + } + + // Process diff blocks and add patch links + val withPatchLinks: String = patchBlocks.foldIndexed(response) { index, markdown, diffBlock -> + val value = diffBlock.second.groupValues[2].trim() + var header = headers.lastOrNull { it.first.last < diffBlock.first.first }?.second ?: "Unknown" + header = corrections?.get("patch_$index") ?: header + val filename = resolve(root, header) + val newValue = renderDiffBlock(root, filename, value, handle, ui, api, shouldAutoApply) + markdown.replace(diffBlock.second.value, newValue) + } + // Process code blocks and add save links + val withSaveLinks = codeblocks.foldIndexed(withPatchLinks) { index, markdown, codeBlock -> + val lang = codeBlock.second.groupValues[1] + val value = codeBlock.second.groupValues[2].trim() + var header = headers.lastOrNull { it.first.last < codeBlock.first.first }?.second + header = corrections?.get("code_$index") ?: header + val newMarkdown = renderNewFile(header, root, ui, shouldAutoApply, value, handle, lang) + markdown.replace(codeBlock.second.value, newMarkdown) + } + return withSaveLinks } data class PatchAndCodeList( - val changes: List, + val changes: List, ) data class PatchOrCode( - val id: String? = null, - val type: String? = null, - val filename: String? = null, - val data: String? = null, + val id: String? = null, + val type: String? = null, + val filename: String? = null, + val data: String? = null, ) data class CorrectedPatchAndCodeList( - val changes: List? = null, + val changes: List? = null, ) data class CorrectedPatchOrCode( - val id: String? = null, - val filename: String? = null, + val id: String? = null, + val filename: String? = null, ) private fun SocketManagerBase.renderNewFile( - header: String?, - root: Path, - ui: ApplicationInterface, - shouldAutoApply: (Path) -> Boolean, - codeValue: String, - handle: (Map) -> Unit, - codeLang: String + header: String?, + root: Path, + ui: ApplicationInterface, + shouldAutoApply: (Path) -> Boolean, + codeValue: String, + handle: (Map) -> Unit, + codeLang: String ): String { - val filename = resolve(root, header ?: "Unknown") - val filepath = root.resolve(filename) - if (shouldAutoApply(filepath) && !filepath.toFile().exists()) { - try { - filepath.parent?.toFile()?.mkdirs() - filepath.toFile().writeText(codeValue, Charsets.UTF_8) - handle(mapOf(File(filename).toPath() to codeValue)) - return """ + val filename = resolve(root, header ?: "Unknown") + val filepath = root.resolve(filename) + if (shouldAutoApply(filepath) && !filepath.toFile().exists()) { + try { + filepath.parent?.toFile()?.mkdirs() + filepath.toFile().writeText(codeValue, Charsets.UTF_8) + handle(mapOf(File(filename).toPath() to codeValue)) + return """ ```${codeLang} ${codeValue} ``` @@ -183,226 +187,226 @@ ${codeValue}
Automatically Saved ${filepath}
""" - } catch (e: Throwable) { - return """ + } catch (e: Throwable) { + return """ ```${codeLang} ${codeValue} ```
Error Auto-Saving ${filename}: ${e.message}
""" - } - } else { - val commandTask = ui.newTask(false) - lateinit var hrefLink: StringBuilder - hrefLink = commandTask.complete(hrefLink("Save File", classname = "href-link cmd-button") { - try { - filepath.parent?.toFile()?.mkdirs() - filepath.toFile().writeText(codeValue, Charsets.UTF_8) - handle(mapOf(File(filename).toPath() to codeValue)) - hrefLink.set("""
Saved ${filepath}
""") - commandTask.complete() - } catch (e: Throwable) { - hrefLink.append("""
Error: ${e.message}
""") - commandTask.error(null, e) - } - })!! - return """ + } + } else { + val commandTask = ui.newTask(false) + lateinit var hrefLink: StringBuilder + hrefLink = commandTask.complete(hrefLink("Save File", classname = "href-link cmd-button") { + try { + filepath.parent?.toFile()?.mkdirs() + filepath.toFile().writeText(codeValue, Charsets.UTF_8) + handle(mapOf(File(filename).toPath() to codeValue)) + hrefLink.set("""
Saved ${filepath}
""") + commandTask.complete() + } catch (e: Throwable) { + hrefLink.append("""
Error: ${e.message}
""") + commandTask.error(null, e) + } + })!! + return """ ```${codeLang} ${codeValue} ``` ${commandTask.placeholder} """ - } + } } private val pattern_backticks = "`(.*)`".toRegex() fun resolve(root: Path, filename: String): String { - var filename = filename.trim() - - filename = if (pattern_backticks.containsMatchIn(filename)) { - pattern_backticks.find(filename)!!.groupValues[1] - } else { - filename - } - - filename = try { - val path = File(filename).toPath() - if (root.contains(path)) path.toString().relativizeFrom(root) else filename - } catch (e: Throwable) { - filename - } - - try { - if (!root.resolve(filename).toFile().exists()) { - root.toFile().listFilesRecursively().find { it.toString().replace("\\", "/").endsWith(filename.replace("\\", "/")) } - ?.toString()?.apply { - filename = relativizeFrom(root) - } + var filename = filename.trim() + + filename = if (pattern_backticks.containsMatchIn(filename)) { + pattern_backticks.find(filename)!!.groupValues[1] + } else { + filename + } + + filename = try { + val path = File(filename).toPath() + if (root.contains(path)) path.toString().relativizeFrom(root) else filename + } catch (e: Throwable) { + filename + } + + try { + if (!root.resolve(filename).toFile().exists()) { + root.toFile().listFilesRecursively().find { it.toString().replace("\\", "/").endsWith(filename.replace("\\", "/")) } + ?.toString()?.apply { + filename = relativizeFrom(root) } - } catch (e: Throwable) { - log.error("Error resolving filename", e) } + } catch (e: Throwable) { + log.error("Error resolving filename", e) + } - return filename + return filename } private fun String.relativizeFrom(root: Path) = try { - root.relativize(File(this).toPath()).toString() + root.relativize(File(this).toPath()).toString() } catch (e: Throwable) { - this + this } private fun File.listFilesRecursively(): List { - val files = mutableListOf() - this.listFiles()?.filter { !isGitignore(it.toPath()) }?.forEach { - files.add(it.absoluteFile) - if (it.isDirectory) { - files.addAll(it.listFilesRecursively()) - } + val files = mutableListOf() + this.listFiles()?.filter { !isGitignore(it.toPath()) }?.forEach { + files.add(it.absoluteFile) + if (it.isDirectory) { + files.addAll(it.listFilesRecursively()) } - return files.toTypedArray().toList() + } + return files.toTypedArray().toList() } // Function to render a diff block with apply and revert options private fun SocketManagerBase.renderDiffBlock( - root: Path, - filename: String, - diffVal: String, - handle: (Map) -> Unit, - ui: ApplicationInterface, - api: API?, - shouldAutoApply: (Path) -> Boolean, - model: ChatModel? = null, + root: Path, + filename: String, + diffVal: String, + handle: (Map) -> Unit, + ui: ApplicationInterface, + api: API?, + shouldAutoApply: (Path) -> Boolean, + model: ChatModel? = null, ): String { - val filepath = root.resolve(filename) - val prevCode = load(filepath) + val filepath = root.resolve(filename) + val prevCode = load(filepath) + val relativize = try { + root.relativize(filepath) + } catch (e: Throwable) { + filepath + } + val applydiffTask = ui.newTask(false) + lateinit var hrefLink: StringBuilder + + var newCode = patch(prevCode, diffVal) + val echoDiff = try { + IterativePatchUtil.generatePatch(prevCode, newCode.newCode) + } catch (e: Throwable) { + renderMarkdown("```\n${e.stackTraceToString()}\n```\n", ui = ui) + } + + // Function to create a revert button + fun createRevertButton(filepath: Path, originalCode: String, handle: (Map) -> Unit): String { val relativize = try { - root.relativize(filepath) + root.relativize(filepath) } catch (e: Throwable) { - filepath + filepath } - val applydiffTask = ui.newTask(false) - lateinit var hrefLink: StringBuilder - - var newCode = patch(prevCode, diffVal) - val echoDiff = try { - IterativePatchUtil.generatePatch(prevCode, newCode.newCode) + val revertTask = ui.newTask(false) + lateinit var revertButton: StringBuilder + revertButton = revertTask.complete(hrefLink("Revert", classname = "href-link cmd-button") { + try { + filepath.toFile().writeText(originalCode, Charsets.UTF_8) + handle(mapOf(relativize to originalCode)) + revertButton.set("""
Reverted
""") + revertTask.complete() + } catch (e: Throwable) { + revertButton.append("""
Error: ${e.message}
""") + revertTask.error(null, e) + } + })!! + return revertTask.placeholder + } + + if (echoDiff.isNotBlank() && newCode.isValid && shouldAutoApply(filepath ?: root.resolve(filename))) { + try { + filepath.toFile().writeText(newCode.newCode, Charsets.UTF_8) + val originalCode = AtomicReference(prevCode) + handle(mapOf(relativize to newCode.newCode)) + val revertButton = createRevertButton(filepath, originalCode.get(), handle) + return "```diff\n$diffVal\n```\n" + """
Diff Automatically Applied to ${filepath}
""" + revertButton } catch (e: Throwable) { - renderMarkdown("```\n${e.stackTraceToString()}\n```\n", ui = ui) + log.error("Error auto-applying diff", e) + return "```diff\n$diffVal\n```\n" + """
Error Auto-Applying Diff to ${filepath}: ${e.message}
""" } - - // Function to create a revert button - fun createRevertButton(filepath: Path, originalCode: String, handle: (Map) -> Unit): String { - val relativize = try { - root.relativize(filepath) - } catch (e: Throwable) { - filepath - } - val revertTask = ui.newTask(false) - lateinit var revertButton: StringBuilder - revertButton = revertTask.complete(hrefLink("Revert", classname = "href-link cmd-button") { - try { - filepath.toFile().writeText(originalCode, Charsets.UTF_8) - handle(mapOf(relativize to originalCode)) - revertButton.set("""
Reverted
""") - revertTask.complete() - } catch (e: Throwable) { - revertButton.append("""
Error: ${e.message}
""") - revertTask.error(null, e) - } - })!! - return revertTask.placeholder - } - - if (echoDiff.isNotBlank() && newCode.isValid && shouldAutoApply(filepath ?: root.resolve(filename))) { - try { - filepath.toFile().writeText(newCode.newCode, Charsets.UTF_8) - val originalCode = AtomicReference(prevCode) - handle(mapOf(relativize to newCode.newCode)) - val revertButton = createRevertButton(filepath, originalCode.get(), handle) - return "```diff\n$diffVal\n```\n" + """
Diff Automatically Applied to ${filepath}
""" + revertButton - } catch (e: Throwable) { - log.error("Error auto-applying diff", e) - return "```diff\n$diffVal\n```\n" + """
Error Auto-Applying Diff to ${filepath}: ${e.message}
""" - } - } - - - val diffTask = ui.newTask(root = false) - diffTask.complete(renderMarkdown("```diff\n$diffVal\n```\n", ui = ui)) - - // Create tasks for displaying code and patch information - val prevCodeTask = ui.newTask(root = false) - val prevCodeTaskSB = prevCodeTask.add("") - val newCodeTask = ui.newTask(root = false) - val newCodeTaskSB = newCodeTask.add("") - val patchTask = ui.newTask(root = false) - val patchTaskSB = patchTask.add("") - val fixTask = ui.newTask(root = false) - val verifyFwdTabs = if (!newCode.isValid) displayMapInTabs( - mapOf( - "Code" to prevCodeTask.placeholder, - "Preview" to newCodeTask.placeholder, - "Echo" to patchTask.placeholder, - "Fix" to fixTask.placeholder, - ) - ) else displayMapInTabs( - mapOf( - "Code" to prevCodeTask.placeholder, - "Preview" to newCodeTask.placeholder, - "Echo" to patchTask.placeholder, - ) + } + + + val diffTask = ui.newTask(root = false) + diffTask.complete(renderMarkdown("```diff\n$diffVal\n```\n", ui = ui)) + + // Create tasks for displaying code and patch information + val prevCodeTask = ui.newTask(root = false) + val prevCodeTaskSB = prevCodeTask.add("") + val newCodeTask = ui.newTask(root = false) + val newCodeTaskSB = newCodeTask.add("") + val patchTask = ui.newTask(root = false) + val patchTaskSB = patchTask.add("") + val fixTask = ui.newTask(root = false) + val verifyFwdTabs = if (!newCode.isValid) displayMapInTabs( + mapOf( + "Code" to prevCodeTask.placeholder, + "Preview" to newCodeTask.placeholder, + "Echo" to patchTask.placeholder, + "Fix" to fixTask.placeholder, ) - - - val prevCode2Task = ui.newTask(root = false) - val prevCode2TaskSB = prevCode2Task.add("") - val newCode2Task = ui.newTask(root = false) - val newCode2TaskSB = newCode2Task.add("") - val patch2Task = ui.newTask(root = false) - val patch2TaskSB = patch2Task.add("") - val verifyRevTabs = displayMapInTabs( - mapOf( - "Code" to prevCode2Task.placeholder, - "Preview" to newCode2Task.placeholder, - "Echo" to patch2Task.placeholder, - ) + ) else displayMapInTabs( + mapOf( + "Code" to prevCodeTask.placeholder, + "Preview" to newCodeTask.placeholder, + "Echo" to patchTask.placeholder, + ) + ) + + + val prevCode2Task = ui.newTask(root = false) + val prevCode2TaskSB = prevCode2Task.add("") + val newCode2Task = ui.newTask(root = false) + val newCode2TaskSB = newCode2Task.add("") + val patch2Task = ui.newTask(root = false) + val patch2TaskSB = patch2Task.add("") + val verifyRevTabs = displayMapInTabs( + mapOf( + "Code" to prevCode2Task.placeholder, + "Preview" to newCode2Task.placeholder, + "Echo" to patch2Task.placeholder, ) + ) - lateinit var revert: String + lateinit var revert: String - var originalCode = prevCode // For reverting changes - // Create "Apply Diff" button - val apply1 = hrefLink("Apply Diff", classname = "href-link cmd-button") { - try { - originalCode = load(filepath) - newCode = patch(originalCode, diffVal) - filepath.toFile().writeText(newCode.newCode, Charsets.UTF_8) ?: log.warn("File not found: $filepath") - handle(mapOf(relativize to newCode.newCode)) - hrefLink.set("
Diff Applied
$revert") - applydiffTask.complete() - } catch (e: Throwable) { - hrefLink.append("""
Error: ${e.message}
""") - applydiffTask.error(null, e) - } + var originalCode = prevCode // For reverting changes + // Create "Apply Diff" button + val apply1 = hrefLink("Apply Diff", classname = "href-link cmd-button") { + try { + originalCode = load(filepath) + newCode = patch(originalCode, diffVal) + filepath.toFile().writeText(newCode.newCode, Charsets.UTF_8) ?: log.warn("File not found: $filepath") + handle(mapOf(relativize to newCode.newCode)) + hrefLink.set("
Diff Applied
$revert") + applydiffTask.complete() + } catch (e: Throwable) { + hrefLink.append("""
Error: ${e.message}
""") + applydiffTask.error(null, e) } + } - if (echoDiff.isNotBlank()) { + if (echoDiff.isNotBlank()) { - // Add "Fix Patch" button if the patch is not valid - if (!newCode.isValid) { - val fixPatchLink = hrefLink("Fix Patch", classname = "href-link cmd-button") { - try { - val header = fixTask.header("Attempting to fix patch...") + // Add "Fix Patch" button if the patch is not valid + if (!newCode.isValid) { + val fixPatchLink = hrefLink("Fix Patch", classname = "href-link cmd-button") { + try { + val header = fixTask.header("Attempting to fix patch...") - val patchFixer = SimpleActor( - prompt = """ + val patchFixer = SimpleActor( + prompt = """ You are a helpful AI that helps people with coding. Response should use one or more code patches in diff format within ```diff code blocks. @@ -438,19 +442,19 @@ Here are the patches: }); ``` """, - model = OpenAIModels.GPT4o, - temperature = 0.3 - ) - - val echoDiff = try { - IterativePatchUtil.generatePatch(prevCode, newCode.newCode) - } catch (e: Throwable) { - renderMarkdown("```\n${e.stackTraceToString()}\n```\n", ui = ui) - } - - var answer = patchFixer.answer( - listOf( - """ + model = OpenAIModels.GPT4o, + temperature = 0.3 + ) + + val echoDiff = try { + IterativePatchUtil.generatePatch(prevCode, newCode.newCode) + } catch (e: Throwable) { + renderMarkdown("```\n${e.stackTraceToString()}\n```\n", ui = ui) + } + + var answer = patchFixer.answer( + listOf( + """ Code: ```${filename.split('.').lastOrNull() ?: ""} $prevCode @@ -468,255 +472,255 @@ $echoDiff Please provide a fix for the diff above in the form of a diff patch. """ - ), api as OpenAIClient - ) - answer = ui.socketManager?.addApplyFileDiffLinks(root, answer, handle, ui, api, model=model) ?: answer - header?.clear() - fixTask.complete(renderMarkdown(answer)) - } catch (e: Throwable) { - log.error("Error in fix patch", e) - } - } - //apply1 += fixPatchLink - fixTask.complete(fixPatchLink) + ), api as OpenAIClient + ) + answer = ui.socketManager?.addApplyFileDiffLinks(root, answer, handle, ui, api, model = model) ?: answer + header?.clear() + fixTask.complete(renderMarkdown(answer)) + } catch (e: Throwable) { + log.error("Error in fix patch", e) } + } + //apply1 += fixPatchLink + fixTask.complete(fixPatchLink) + } - // Create "Apply Diff (Bottom to Top)" button - val apply2 = hrefLink("(Bottom to Top)", classname = "href-link cmd-button") { - try { - originalCode = load(filepath) - val originalLines = originalCode.reverseLines() - val diffLines = diffVal.reverseLines() - val patch1 = patch(originalLines, diffLines) - val newCode2 = patch1.newCode.reverseLines() - filepath.toFile()?.writeText(newCode2, Charsets.UTF_8) ?: log.warn("File not found: $filepath") - handle(mapOf(relativize to newCode2)) - hrefLink.set("""
Diff Applied (Bottom to Top)
""" + revert) - applydiffTask.complete() - } catch (e: Throwable) { - hrefLink.append("""
Error: ${e.message}
""") - applydiffTask.error(null, e) - } - } - // Create "Revert" button - revert = hrefLink("Revert", classname = "href-link cmd-button") { - try { - filepath.toFile()?.writeText(originalCode, Charsets.UTF_8) - handle(mapOf(relativize to originalCode)) - hrefLink.set("""
Reverted
""" + apply1 + apply2) - applydiffTask.complete() - } catch (e: Throwable) { - hrefLink.append("""
Error: ${e.message}
""") - applydiffTask.error(null, e) - } - } - hrefLink = applydiffTask.complete(apply1 + "\n" + apply2)!! + // Create "Apply Diff (Bottom to Top)" button + val apply2 = hrefLink("(Bottom to Top)", classname = "href-link cmd-button") { + try { + originalCode = load(filepath) + val originalLines = originalCode.reverseLines() + val diffLines = diffVal.reverseLines() + val patch1 = patch(originalLines, diffLines) + val newCode2 = patch1.newCode.reverseLines() + filepath.toFile()?.writeText(newCode2, Charsets.UTF_8) ?: log.warn("File not found: $filepath") + handle(mapOf(relativize to newCode2)) + hrefLink.set("""
Diff Applied (Bottom to Top)
""" + revert) + applydiffTask.complete() + } catch (e: Throwable) { + hrefLink.append("""
Error: ${e.message}
""") + applydiffTask.error(null, e) + } } + // Create "Revert" button + revert = hrefLink("Revert", classname = "href-link cmd-button") { + try { + filepath.toFile()?.writeText(originalCode, Charsets.UTF_8) + handle(mapOf(relativize to originalCode)) + hrefLink.set("""
Reverted
""" + apply1 + apply2) + applydiffTask.complete() + } catch (e: Throwable) { + hrefLink.append("""
Error: ${e.message}
""") + applydiffTask.error(null, e) + } + } + hrefLink = applydiffTask.complete(apply1 + "\n" + apply2)!! + } - val lang = filename.split('.').lastOrNull() ?: "" - newCodeTaskSB?.set( - renderMarkdown( - """# $filename + val lang = filename.split('.').lastOrNull() ?: "" + newCodeTaskSB?.set( + renderMarkdown( + """# $filename ```$lang ${newCode} ``` """, - ui = ui, tabs = false - ) + ui = ui, tabs = false ) - newCodeTask.complete("") - prevCodeTaskSB?.set( - renderMarkdown( - """# $filename + ) + newCodeTask.complete("") + prevCodeTaskSB?.set( + renderMarkdown( + """# $filename ```$lang ${prevCode} ``` """, - ui = ui, tabs = false - ) + ui = ui, tabs = false ) - prevCodeTask.complete("") - patchTaskSB?.set( - renderMarkdown( - """ + ) + prevCodeTask.complete("") + patchTaskSB?.set( + renderMarkdown( + """ # $filename ```diff ${echoDiff} ``` """, - ui = ui, - tabs = false - ) + ui = ui, + tabs = false ) - patchTask.complete("") - val newCode2 = patch( - load(filepath).reverseLines(), - diffVal.reverseLines() - ).newCode.lines().reversed().joinToString("\n") - val echoDiff2 = try { - IterativePatchUtil.generatePatch(prevCode, newCode2) - } catch (e: Throwable) { - renderMarkdown( - """ + ) + patchTask.complete("") + val newCode2 = patch( + load(filepath).reverseLines(), + diffVal.reverseLines() + ).newCode.lines().reversed().joinToString("\n") + val echoDiff2 = try { + IterativePatchUtil.generatePatch(prevCode, newCode2) + } catch (e: Throwable) { + renderMarkdown( + """ ``` ${e.stackTraceToString()} ``` """, ui = ui - ) - } - newCode2TaskSB?.set( - renderMarkdown( - """ + ) + } + newCode2TaskSB?.set( + renderMarkdown( + """ # $filename ```${filename.split('.').lastOrNull() ?: ""} ${newCode2} ``` """, - ui = ui, tabs = false - ) + ui = ui, tabs = false ) - newCode2Task.complete("") - prevCode2TaskSB?.set( - renderMarkdown( - """ + ) + newCode2Task.complete("") + prevCode2TaskSB?.set( + renderMarkdown( + """ # $filename ```${filename.split('.').lastOrNull() ?: ""} ${prevCode} ``` """, - ui = ui, tabs = false - ) + ui = ui, tabs = false ) - prevCode2Task.complete("") - patch2TaskSB?.set( - renderMarkdown( - """ + ) + prevCode2Task.complete("") + patch2TaskSB?.set( + renderMarkdown( + """ # $filename ```diff ${echoDiff2} ``` """, - ui = ui, - tabs = false - ) + ui = ui, + tabs = false ) - patch2Task.complete("") + ) + patch2Task.complete("") - // Create main tabs for displaying diff and verification information - val mainTabs = displayMapInTabs( + // Create main tabs for displaying diff and verification information + val mainTabs = displayMapInTabs( + mapOf( + "Diff" to diffTask.placeholder, + "Verify" to displayMapInTabs( mapOf( - "Diff" to diffTask.placeholder, - "Verify" to displayMapInTabs( - mapOf( - "Forward" to verifyFwdTabs, - "Reverse" to verifyRevTabs, - ) - ), + "Forward" to verifyFwdTabs, + "Reverse" to verifyRevTabs, ) + ), ) - val newValue = if (newCode.isValid) { - mainTabs + "\n" + applydiffTask.placeholder - } else { - mainTabs + """
Warning: The patch is not valid. Please fix the patch before applying.
""" + applydiffTask.placeholder - } - return newValue + ) + val newValue = if (newCode.isValid) { + mainTabs + "\n" + applydiffTask.placeholder + } else { + mainTabs + """
Warning: The patch is not valid. Please fix the patch before applying.
""" + applydiffTask.placeholder + } + return newValue } // Function to apply a patch to a code string private val patch = { code: String, diff: String -> - val isCurlyBalanced = FileValidationUtils.isCurlyBalanced(code) - val isSquareBalanced = FileValidationUtils.isSquareBalanced(code) - val isParenthesisBalanced = FileValidationUtils.isParenthesisBalanced(code) - val isQuoteBalanced = FileValidationUtils.isQuoteBalanced(code) - val isSingleQuoteBalanced = FileValidationUtils.isSingleQuoteBalanced(code) - var newCode = IterativePatchUtil.applyPatch(code, diff) - newCode = newCode.replace("\r", "") - val isCurlyBalancedNew = FileValidationUtils.isCurlyBalanced(newCode) - val isSquareBalancedNew = FileValidationUtils.isSquareBalanced(newCode) - val isParenthesisBalancedNew = FileValidationUtils.isParenthesisBalanced(newCode) - val isQuoteBalancedNew = FileValidationUtils.isQuoteBalanced(newCode) - val isSingleQuoteBalancedNew = FileValidationUtils.isSingleQuoteBalanced(newCode) - val isError = ((isCurlyBalanced && !isCurlyBalancedNew) || - (isSquareBalanced && !isSquareBalancedNew) || - (isParenthesisBalanced && !isParenthesisBalancedNew) || - (isQuoteBalanced && !isQuoteBalancedNew) || - (isSingleQuoteBalanced && !isSingleQuoteBalancedNew)) - PatchResult(newCode, !isError) + val isCurlyBalanced = FileValidationUtils.isCurlyBalanced(code) + val isSquareBalanced = FileValidationUtils.isSquareBalanced(code) + val isParenthesisBalanced = FileValidationUtils.isParenthesisBalanced(code) + val isQuoteBalanced = FileValidationUtils.isQuoteBalanced(code) + val isSingleQuoteBalanced = FileValidationUtils.isSingleQuoteBalanced(code) + var newCode = IterativePatchUtil.applyPatch(code, diff) + newCode = newCode.replace("\r", "") + val isCurlyBalancedNew = FileValidationUtils.isCurlyBalanced(newCode) + val isSquareBalancedNew = FileValidationUtils.isSquareBalanced(newCode) + val isParenthesisBalancedNew = FileValidationUtils.isParenthesisBalanced(newCode) + val isQuoteBalancedNew = FileValidationUtils.isQuoteBalanced(newCode) + val isSingleQuoteBalancedNew = FileValidationUtils.isSingleQuoteBalanced(newCode) + val isError = ((isCurlyBalanced && !isCurlyBalancedNew) || + (isSquareBalanced && !isSquareBalancedNew) || + (isParenthesisBalanced && !isParenthesisBalancedNew) || + (isQuoteBalanced && !isQuoteBalancedNew) || + (isSingleQuoteBalanced && !isSingleQuoteBalancedNew)) + PatchResult(newCode, !isError) } // Function to load file contents private fun load( - filepath: Path? + filepath: Path? ) = try { - if (true != filepath?.toFile()?.exists()) { - log.warn("File not found: $filepath") - "" - } else { - filepath.readText(Charsets.UTF_8) - } -} catch (e: Throwable) { - log.error("Error reading file: $filepath", e) + if (true != filepath?.toFile()?.exists()) { + log.warn("File not found: $filepath") "" + } else { + filepath.readText(Charsets.UTF_8) + } +} catch (e: Throwable) { + log.error("Error reading file: $filepath", e) + "" } // Function to apply file diffs from a response string @Suppress("unused") fun applyFileDiffs( - root: Path, - response: String, + root: Path, + response: String, ): String { - val headerPattern = """(?s)(?> = - diffPattern.findAll(response).map { it.range to it.groupValues[1] }.toList() - val codeblocks = codeblockPattern.findAll(response).filter { - when (it.groupValues[1]) { - "diff" -> false - else -> true - } - }.map { it.range to it }.toList() - diffs.forEach { diffBlock -> - val header = headers.lastOrNull { it.first.last < diffBlock.first.first } - val filename = resolve(root, header?.second ?: "Unknown") - val diffVal = diffBlock.second - val filepath = root.resolve(filename) - try { - val originalCode = filepath.readText(Charsets.UTF_8) - val newCode = patch(originalCode, diffVal) - filepath.toFile().writeText(newCode.newCode, Charsets.UTF_8) - } catch (e: Throwable) { - log.warn("Error", e) - } + val headerPattern = """(?s)(?> = + diffPattern.findAll(response).map { it.range to it.groupValues[1] }.toList() + val codeblocks = codeblockPattern.findAll(response).filter { + when (it.groupValues[1]) { + "diff" -> false + else -> true } - codeblocks.forEach { codeBlock -> - val header = headers.lastOrNull { it.first.last < codeBlock.first.first } - val filename = resolve(root, header?.second ?: "Unknown") - val filepath: Path? = root.resolve(filename) - val codeValue = codeBlock.second.groupValues[2].trim() - lateinit var hrefLink: StringBuilder - try { - try { - filepath?.toFile()?.writeText(codeValue, Charsets.UTF_8) - } catch (e: Throwable) { - log.error("Error writing file: $filepath", e) - } - hrefLink.set("""
Saved ${filename}
""") - } catch (e: Throwable) { - log.error("Error", e) - } + }.map { it.range to it }.toList() + diffs.forEach { diffBlock -> + val header = headers.lastOrNull { it.first.last < diffBlock.first.first } + val filename = resolve(root, header?.second ?: "Unknown") + val diffVal = diffBlock.second + val filepath = root.resolve(filename) + try { + val originalCode = filepath.readText(Charsets.UTF_8) + val newCode = patch(originalCode, diffVal) + filepath.toFile().writeText(newCode.newCode, Charsets.UTF_8) + } catch (e: Throwable) { + log.warn("Error", e) + } + } + codeblocks.forEach { codeBlock -> + val header = headers.lastOrNull { it.first.last < codeBlock.first.first } + val filename = resolve(root, header?.second ?: "Unknown") + val filepath: Path? = root.resolve(filename) + val codeValue = codeBlock.second.groupValues[2].trim() + lateinit var hrefLink: StringBuilder + try { + try { + filepath?.toFile()?.writeText(codeValue, Charsets.UTF_8) + } catch (e: Throwable) { + log.error("Error writing file: $filepath", e) + } + hrefLink.set("""
Saved ${filename}
""") + } catch (e: Throwable) { + log.error("Error", e) } - return response + } + return response } diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/AddShellExecutionLinks.kt b/webui/src/main/kotlin/com/simiacryptus/diff/AddShellExecutionLinks.kt index 54d1a440..353a0293 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/AddShellExecutionLinks.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/AddShellExecutionLinks.kt @@ -8,35 +8,35 @@ import java.io.InputStreamReader import java.util.* fun SocketManagerBase.addShellExecutionLinks( - response: String, - ui: ApplicationInterface + response: String, + ui: ApplicationInterface ): String { - val shellCodePattern = """(?s)(? - val shellCode = matchResult.groupValues[1] - val executionId = UUID.randomUUID().toString() - val executionTask = ui.newTask(false) - val executeButton = hrefLink("Execute", classname = "href-link cmd-button") { - try { - val process = Runtime.getRuntime().exec(arrayOf("sh", "-c", shellCode)) - val reader = BufferedReader(InputStreamReader(process.inputStream)) - val errorReader = BufferedReader(InputStreamReader(process.errorStream)) - val output = StringBuilder() - var line: String? - while (reader.readLine().also { line = it } != null) { - output.append(line).append("\n") - } - while (errorReader.readLine().also { line = it } != null) { - output.append("Error: ").append(line).append("\n") - } - val exitCode = process.waitFor() - output.append("Exit code: $exitCode") - executionTask.complete(MarkdownUtil.renderMarkdown("```\n$output\n```", ui = ui)) - } catch (e: Throwable) { - executionTask.error(null, e) - } + val shellCodePattern = """(?s)(? + val shellCode = matchResult.groupValues[1] + val executionId = UUID.randomUUID().toString() + val executionTask = ui.newTask(false) + val executeButton = hrefLink("Execute", classname = "href-link cmd-button") { + try { + val process = Runtime.getRuntime().exec(arrayOf("sh", "-c", shellCode)) + val reader = BufferedReader(InputStreamReader(process.inputStream)) + val errorReader = BufferedReader(InputStreamReader(process.errorStream)) + val output = StringBuilder() + var line: String? + while (reader.readLine().also { line = it } != null) { + output.append(line).append("\n") } - """ + while (errorReader.readLine().also { line = it } != null) { + output.append("Error: ").append(line).append("\n") + } + val exitCode = process.waitFor() + output.append("Exit code: $exitCode") + executionTask.complete(MarkdownUtil.renderMarkdown("```\n$output\n```", ui = ui)) + } catch (e: Throwable) { + executionTask.error(null, e) + } + } + """ ```shell $shellCode ``` @@ -45,5 +45,5 @@ $shellCode ${executionTask.placeholder} """ - } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/ApxPatchUtil.kt b/webui/src/main/kotlin/com/simiacryptus/diff/ApxPatchUtil.kt index 2416888f..2bfc7f8b 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/ApxPatchUtil.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/ApxPatchUtil.kt @@ -5,120 +5,120 @@ import org.apache.commons.text.similarity.LevenshteinDistance object ApxPatchUtil { - fun patch(source: String, patch: String): String { - val sourceLines = source.lines() - val patchLines = patch.lines() - - // This will hold the final result - val result = mutableListOf() - - // This will keep track of the current line in the source file - var sourceIndex = 0 - - // Process each line in the patch - for (patchLine in patchLines.map { it.trim() }) { - when { - // If the line starts with "---" or "+++", it's a file indicator line, skip it - patchLine.startsWith("---") || patchLine.startsWith("+++") -> continue - - // If the line starts with "@@", it's a hunk header - patchLine.startsWith("@@") -> continue - - // If the line starts with "-", it's a deletion, skip the corresponding source line but otherwise treat it as a context line - patchLine.startsWith("-") -> { - sourceIndex = onDelete(patchLine, sourceIndex, sourceLines, result) - } - - // If the line starts with "+", it's an addition, add it to the result - patchLine.startsWith("+") -> { - result.add(patchLine.substring(1)) - } - - // \d+\: ___ is a line number, strip it - patchLine.matches(Regex("\\d+:.*")) -> { - sourceIndex = onContextLine(patchLine.substringAfter(":"), sourceIndex, sourceLines, result) - } - - // it's a context line, advance the source cursor - else -> { - sourceIndex = onContextLine(patchLine, sourceIndex, sourceLines, result) - } - } - } + fun patch(source: String, patch: String): String { + val sourceLines = source.lines() + val patchLines = patch.lines() + + // This will hold the final result + val result = mutableListOf() + + // This will keep track of the current line in the source file + var sourceIndex = 0 + + // Process each line in the patch + for (patchLine in patchLines.map { it.trim() }) { + when { + // If the line starts with "---" or "+++", it's a file indicator line, skip it + patchLine.startsWith("---") || patchLine.startsWith("+++") -> continue - // Append any remaining lines from the source file - while (sourceIndex < sourceLines.size) { - result.add(sourceLines[sourceIndex]) - sourceIndex++ + // If the line starts with "@@", it's a hunk header + patchLine.startsWith("@@") -> continue + + // If the line starts with "-", it's a deletion, skip the corresponding source line but otherwise treat it as a context line + patchLine.startsWith("-") -> { + sourceIndex = onDelete(patchLine, sourceIndex, sourceLines, result) } - return result.joinToString("\n") - } + // If the line starts with "+", it's an addition, add it to the result + patchLine.startsWith("+") -> { + result.add(patchLine.substring(1)) + } - private fun onDelete( - patchLine: String, - sourceIndex: Int, - sourceLines: List, - result: MutableList - ): Int { - var sourceIndex1 = sourceIndex - val delLine = patchLine.substring(1) - val sourceIndexSearch = lookAheadFor(sourceIndex1, sourceLines, delLine) - if (sourceIndexSearch > 0 && sourceIndexSearch + 1 < sourceLines.size) { - val contextChunk = sourceLines.subList(sourceIndex1, sourceIndexSearch) - result.addAll(contextChunk) - sourceIndex1 = sourceIndexSearch + 1 - } else { - println("Deletion line not found in source file: $delLine") - // Ignore + // \d+\: ___ is a line number, strip it + patchLine.matches(Regex("\\d+:.*")) -> { + sourceIndex = onContextLine(patchLine.substringAfter(":"), sourceIndex, sourceLines, result) } - return sourceIndex1 - } - private fun onContextLine( - patchLine: String, - sourceIndex: Int, - sourceLines: List, - result: MutableList - ): Int { - var sourceIndex1 = sourceIndex - val sourceIndexSearch = lookAheadFor(sourceIndex1, sourceLines, patchLine) - if (sourceIndexSearch > 0 && sourceIndexSearch + 1 < sourceLines.size) { - val contextChunk = sourceLines.subList(sourceIndex1, sourceIndexSearch + 1) - result.addAll(contextChunk) - sourceIndex1 = sourceIndexSearch + 1 - } else { - println("Context line not found in source file: $patchLine") - // Ignore + // it's a context line, advance the source cursor + else -> { + sourceIndex = onContextLine(patchLine, sourceIndex, sourceLines, result) } - return sourceIndex1 + } } - private fun lookAheadFor( - sourceIndex: Int, - sourceLines: List, - patchLine: String - ): Int { - var sourceIndexSearch = sourceIndex - while (sourceIndexSearch < sourceLines.size) { - if (lineMatches(patchLine, sourceLines[sourceIndexSearch++])) return sourceIndexSearch - 1 - } - return -1 + // Append any remaining lines from the source file + while (sourceIndex < sourceLines.size) { + result.add(sourceLines[sourceIndex]) + sourceIndex++ } - private fun lineMatches( - a: String, - b: String, - factor: Double = 0.1, - ): Boolean { - val threshold = (Math.max(a.trim().length, b.trim().length) * factor).toInt() - val levenshteinDistance = LevenshteinDistance(threshold + 1) - val dist = levenshteinDistance.apply(a.trim(), b.trim()) - return if (dist >= 0) { - dist <= threshold - } else { - false - } + return result.joinToString("\n") + } + + private fun onDelete( + patchLine: String, + sourceIndex: Int, + sourceLines: List, + result: MutableList + ): Int { + var sourceIndex1 = sourceIndex + val delLine = patchLine.substring(1) + val sourceIndexSearch = lookAheadFor(sourceIndex1, sourceLines, delLine) + if (sourceIndexSearch > 0 && sourceIndexSearch + 1 < sourceLines.size) { + val contextChunk = sourceLines.subList(sourceIndex1, sourceIndexSearch) + result.addAll(contextChunk) + sourceIndex1 = sourceIndexSearch + 1 + } else { + println("Deletion line not found in source file: $delLine") + // Ignore + } + return sourceIndex1 + } + + private fun onContextLine( + patchLine: String, + sourceIndex: Int, + sourceLines: List, + result: MutableList + ): Int { + var sourceIndex1 = sourceIndex + val sourceIndexSearch = lookAheadFor(sourceIndex1, sourceLines, patchLine) + if (sourceIndexSearch > 0 && sourceIndexSearch + 1 < sourceLines.size) { + val contextChunk = sourceLines.subList(sourceIndex1, sourceIndexSearch + 1) + result.addAll(contextChunk) + sourceIndex1 = sourceIndexSearch + 1 + } else { + println("Context line not found in source file: $patchLine") + // Ignore + } + return sourceIndex1 + } + + private fun lookAheadFor( + sourceIndex: Int, + sourceLines: List, + patchLine: String + ): Int { + var sourceIndexSearch = sourceIndex + while (sourceIndexSearch < sourceLines.size) { + if (lineMatches(patchLine, sourceLines[sourceIndexSearch++])) return sourceIndexSearch - 1 + } + return -1 + } + + private fun lineMatches( + a: String, + b: String, + factor: Double = 0.1, + ): Boolean { + val threshold = (Math.max(a.trim().length, b.trim().length) * factor).toInt() + val levenshteinDistance = LevenshteinDistance(threshold + 1) + val dist = levenshteinDistance.apply(a.trim(), b.trim()) + return if (dist >= 0) { + dist <= threshold + } else { + false } + } } diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/DiffMatchPatch.kt b/webui/src/main/kotlin/com/simiacryptus/diff/DiffMatchPatch.kt index 63cca27e..20e31ef2 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/DiffMatchPatch.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/DiffMatchPatch.kt @@ -22,2400 +22,2400 @@ import kotlin.math.min * Also contains the behaviour settings. */ open class DiffMatchPatch { - // Defaults. - // Set these on your diff_match_patch instance to override the defaults. - /** - * Number of seconds to map a diff before giving up (0 for infinity). - */ - var Diff_Timeout: Float = 1.0f - - /** - * Cost of an empty edit operation in terms of edit characters. - */ - private var Diff_EditCost: Short = 4 - - /** - * At what point is no match declared (0.0 = perfection, 1.0 = very loose). - */ - private var Match_Threshold: Float = 0.5f - - /** - * How far to search for a match (0 = exact location, 1000+ = broad match). - * A match this many characters away from the expected location will add - * 1.0 to the score (0.0 is a perfect match). - */ - private var Match_Distance: Int = 1000 - - /** - * When deleting a large block of text (over ~64 characters), how close do - * the contents have to be to match the expected contents. (0.0 = perfection, - * 1.0 = very loose). Note that Match_Threshold controls how closely the - * end points of a delete need to match. - */ - private var Patch_DeleteThreshold: Float = 0.5f - - /** - * Chunk size for context length. - */ - private var Patch_Margin: Short = 4 - - /** - * The number of bits in an int. - */ - private val Match_MaxBits: Short = 32 - - /** - * Internal class for returning results from diff_linesToChars(). - * Other less paranoid languages just use a three-element array. - */ - protected class LinesToCharsResult( - var chars1: String, var chars2: String, - var lineArray: List - ) - - - // DIFF FUNCTIONS - /** - * The data structure representing a diff is a Linked list of Diff objects: - * {Diff(Operation.DELETE, "Hello"), Diff(Operation.INSERT, "Goodbye"), - * Diff(Operation.EQUAL, " world.")} - * which means: delete "Hello", add "Goodbye" and keep " world." - */ - enum class Operation { - DELETE, INSERT, EQUAL + // Defaults. + // Set these on your diff_match_patch instance to override the defaults. + /** + * Number of seconds to map a diff before giving up (0 for infinity). + */ + var Diff_Timeout: Float = 1.0f + + /** + * Cost of an empty edit operation in terms of edit characters. + */ + private var Diff_EditCost: Short = 4 + + /** + * At what point is no match declared (0.0 = perfection, 1.0 = very loose). + */ + private var Match_Threshold: Float = 0.5f + + /** + * How far to search for a match (0 = exact location, 1000+ = broad match). + * A match this many characters away from the expected location will add + * 1.0 to the score (0.0 is a perfect match). + */ + private var Match_Distance: Int = 1000 + + /** + * When deleting a large block of text (over ~64 characters), how close do + * the contents have to be to match the expected contents. (0.0 = perfection, + * 1.0 = very loose). Note that Match_Threshold controls how closely the + * end points of a delete need to match. + */ + private var Patch_DeleteThreshold: Float = 0.5f + + /** + * Chunk size for context length. + */ + private var Patch_Margin: Short = 4 + + /** + * The number of bits in an int. + */ + private val Match_MaxBits: Short = 32 + + /** + * Internal class for returning results from diff_linesToChars(). + * Other less paranoid languages just use a three-element array. + */ + protected class LinesToCharsResult( + var chars1: String, var chars2: String, + var lineArray: List + ) + + + // DIFF FUNCTIONS + /** + * The data structure representing a diff is a Linked list of Diff objects: + * {Diff(Operation.DELETE, "Hello"), Diff(Operation.INSERT, "Goodbye"), + * Diff(Operation.EQUAL, " world.")} + * which means: delete "Hello", add "Goodbye" and keep " world." + */ + enum class Operation { + DELETE, INSERT, EQUAL + } + + /** + * Find the differences between two texts. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param checklines Speedup flag. If false, then don't run a + * line-level diff first to identify the changed areas. + * If true, then run a faster slightly less optimal diff. + * @return Linked List of Diff objects. + */ + /** + * Find the differences between two texts. + * Run a faster, slightly less optimal diff. + * This method allows the 'checklines' of diff_main() to be optional. + * Most of the time checklines is wanted, so default to true. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @return Linked List of Diff objects. + */ + @JvmOverloads + fun diff_main(text1: String?, text2: String?, checklines: Boolean = true): LinkedList { + // Set a deadline by which time the diff must be complete. + val deadline: Long + if (Diff_Timeout <= 0) { + deadline = Long.MAX_VALUE + } else { + deadline = System.currentTimeMillis() + (Diff_Timeout * 1000).toLong() } - - /** - * Find the differences between two texts. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param checklines Speedup flag. If false, then don't run a - * line-level diff first to identify the changed areas. - * If true, then run a faster slightly less optimal diff. - * @return Linked List of Diff objects. - */ - /** - * Find the differences between two texts. - * Run a faster, slightly less optimal diff. - * This method allows the 'checklines' of diff_main() to be optional. - * Most of the time checklines is wanted, so default to true. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @return Linked List of Diff objects. - */ - @JvmOverloads - fun diff_main(text1: String?, text2: String?, checklines: Boolean = true): LinkedList { - // Set a deadline by which time the diff must be complete. - val deadline: Long - if (Diff_Timeout <= 0) { - deadline = Long.MAX_VALUE - } else { - deadline = System.currentTimeMillis() + (Diff_Timeout * 1000).toLong() - } - return diff_main(text1, text2, checklines, deadline) + return diff_main(text1, text2, checklines, deadline) + } + + /** + * Find the differences between two texts. Simplifies the problem by + * stripping any common prefix or suffix off the texts before diffing. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param checklines Speedup flag. If false, then don't run a + * line-level diff first to identify the changed areas. + * If true, then run a faster slightly less optimal diff. + * @param deadline Time when the diff should be complete by. Used + * internally for recursive calls. Users should set DiffTimeout instead. + * @return Linked List of Diff objects. + */ + fun diff_main(text1: String?, text2: String?, checklines: Boolean, deadline: Long): LinkedList { + // Check for null inputs. + var text1 = text1 + var text2 = text2 + if (text1 == null || text2 == null) { + throw IllegalArgumentException("Null inputs. (diff_main)") } - /** - * Find the differences between two texts. Simplifies the problem by - * stripping any common prefix or suffix off the texts before diffing. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param checklines Speedup flag. If false, then don't run a - * line-level diff first to identify the changed areas. - * If true, then run a faster slightly less optimal diff. - * @param deadline Time when the diff should be complete by. Used - * internally for recursive calls. Users should set DiffTimeout instead. - * @return Linked List of Diff objects. - */ - fun diff_main(text1: String?, text2: String?, checklines: Boolean, deadline: Long): LinkedList { - // Check for null inputs. - var text1 = text1 - var text2 = text2 - if (text1 == null || text2 == null) { - throw IllegalArgumentException("Null inputs. (diff_main)") - } - - // Check for equality (speedup). - val diffs: LinkedList - if (text1 == text2) { - diffs = LinkedList() - if (text1.length != 0) { - diffs.add(Diff(Operation.EQUAL, text1)) - } - return diffs - } - - // Trim off common prefix (speedup). - var commonlength = diff_commonPrefix(text1, text2) - val commonprefix = text1.substring(0, commonlength) - text1 = text1.substring(commonlength) - text2 = text2.substring(commonlength) - - // Trim off common suffix (speedup). - commonlength = diff_commonSuffix(text1, text2) - val commonsuffix = text1.substring(text1.length - commonlength) - text1 = text1.substring(0, text1.length - commonlength) - text2 = text2.substring(0, text2.length - commonlength) - - // Compute the diff on the middle block. - diffs = diff_compute(text1, text2, checklines, deadline) - - // Restore the prefix and suffix. - if (commonprefix.length != 0) { - diffs.addFirst(Diff(Operation.EQUAL, commonprefix)) - } - if (commonsuffix.length != 0) { - diffs.addLast(Diff(Operation.EQUAL, commonsuffix)) - } - - diff_cleanupMerge(diffs) - return diffs + // Check for equality (speedup). + val diffs: LinkedList + if (text1 == text2) { + diffs = LinkedList() + if (text1.length != 0) { + diffs.add(Diff(Operation.EQUAL, text1)) + } + return diffs } - /** - * Find the differences between two texts. Assumes that the texts do not - * have any common prefix or suffix. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param checklines Speedup flag. If false, then don't run a - * line-level diff first to identify the changed areas. - * If true, then run a faster slightly less optimal diff. - * @param deadline Time when the diff should be complete by. - * @return Linked List of Diff objects. - */ - private fun diff_compute(text1: String, text2: String, checklines: Boolean, deadline: Long): LinkedList { - var diffs = LinkedList() - - if (text1.length == 0) { - // Just add some text (speedup). - diffs.add(Diff(Operation.INSERT, text2)) - return diffs - } + // Trim off common prefix (speedup). + var commonlength = diff_commonPrefix(text1, text2) + val commonprefix = text1.substring(0, commonlength) + text1 = text1.substring(commonlength) + text2 = text2.substring(commonlength) - if (text2.length == 0) { - // Just delete some text (speedup). - diffs.add(Diff(Operation.DELETE, text1)) - return diffs - } - - val longtext = if (text1.length > text2.length) text1 else text2 - val shorttext = if (text1.length > text2.length) text2 else text1 - val i = longtext.indexOf(shorttext) - if (i != -1) { - // Shorter text is inside the longer text (speedup). - val op = if ((text1.length > text2.length)) Operation.DELETE else Operation.INSERT - diffs.add(Diff(op, longtext.substring(0, i))) - diffs.add(Diff(Operation.EQUAL, shorttext)) - diffs.add(Diff(op, longtext.substring(i + shorttext.length))) - return diffs - } - - if (shorttext.length == 1) { - // Single character string. - // After the previous speedup, the character can't be an equality. - diffs.add(Diff(Operation.DELETE, text1)) - diffs.add(Diff(Operation.INSERT, text2)) - return diffs - } - - // Check to see if the problem can be split in two. - val hm = diff_halfMatch(text1, text2) - if (hm != null) { - // A half-match was found, sort out the return data. - val text1_a = hm[0] - val text1_b = hm[1] - val text2_a = hm[2] - val text2_b = hm[3] - val mid_common = hm[4] - // Send both pairs off for separate processing. - val diffs_a = diff_main( - text1_a, text2_a, - checklines, deadline - ) - val diffs_b = diff_main( - text1_b, text2_b, - checklines, deadline - ) - // Merge the results. - diffs = diffs_a - diffs.add(Diff(Operation.EQUAL, mid_common)) - diffs.addAll(diffs_b) - return diffs - } + // Trim off common suffix (speedup). + commonlength = diff_commonSuffix(text1, text2) + val commonsuffix = text1.substring(text1.length - commonlength) + text1 = text1.substring(0, text1.length - commonlength) + text2 = text2.substring(0, text2.length - commonlength) - if ((checklines && text1.length > 100) && text2.length > 100) { - return diff_lineMode(text1, text2, deadline) - } + // Compute the diff on the middle block. + diffs = diff_compute(text1, text2, checklines, deadline) - return diff_bisect(text1, text2, deadline) + // Restore the prefix and suffix. + if (commonprefix.length != 0) { + diffs.addFirst(Diff(Operation.EQUAL, commonprefix)) } - - /** - * Do a quick line-level diff on both strings, then rediff the parts for - * greater accuracy. - * This speedup can produce non-minimal diffs. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param deadline Time when the diff should be complete by. - * @return Linked List of Diff objects. - */ - private fun diff_lineMode( - text1: String, text2: String, - deadline: Long - ): LinkedList { - // Scan the text on a line-by-line basis first. - var text1 = text1 - var text2 = text2 - val a = diff_linesToChars(text1, text2) - text1 = a.chars1 - text2 = a.chars2 - val linearray = a.lineArray - - val diffs = diff_main(text1, text2, false, deadline) - - // Convert the diff back to original text. - diff_charsToLines(diffs, linearray) - // Eliminate freak matches (e.g. blank lines) - diff_cleanupSemantic(diffs) - - // Rediff any replacement blocks, this time character-by-character. - // Add a dummy entry at the end. - diffs.add(Diff(Operation.EQUAL, "")) - var count_delete = 0 - var count_insert = 0 - var text_delete: String = "" - var text_insert: String = "" - val pointer = diffs.listIterator() - var thisDiff: Diff? = pointer.next() - while (thisDiff != null) { - when (thisDiff.operation) { - Operation.INSERT -> { - count_insert++ - text_insert += thisDiff.text - } - - Operation.DELETE -> { - count_delete++ - text_delete += thisDiff.text - } - - Operation.EQUAL -> { - // Upon reaching an equality, check for prior redundancies. - if (count_delete >= 1 && count_insert >= 1) { - // Delete the offending records and add the merged ones. - pointer.previous() - var j = 0 - while (j < count_delete + count_insert) { - pointer.previous() - pointer.remove() - j++ - } - for (subDiff: Diff in diff_main( - text_delete, text_insert, false, - deadline - )) { - pointer.add(subDiff) - } - } - count_insert = 0 - count_delete = 0 - text_delete = "" - text_insert = "" - } - - null -> TODO() - } - thisDiff = if (pointer.hasNext()) pointer.next() else null - } - diffs.removeLast() // Remove the dummy entry at the end. - - return diffs + if (commonsuffix.length != 0) { + diffs.addLast(Diff(Operation.EQUAL, commonsuffix)) } - /** - * Find the 'middle snake' of a diff, split the problem in two - * and return the recursively constructed diff. - * See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param deadline Time at which to bail if not yet complete. - * @return LinkedList of Diff objects. - */ - private fun diff_bisect( - text1: String, text2: String, - deadline: Long - ): LinkedList { - // Cache the text lengths to prevent multiple calls. - val text1_length = text1.length - val text2_length = text2.length - val max_d = (text1_length + text2_length + 1) / 2 - val v_offset = max_d - val v_length = 2 * max_d - val v1 = IntArray(v_length) - val v2 = IntArray(v_length) - for (x in 0 until v_length) { - v1[x] = -1 - v2[x] = -1 - } - v1[v_offset + 1] = 0 - v2[v_offset + 1] = 0 - val delta = text1_length - text2_length - // If the total number of characters is odd, then the front path will - // collide with the reverse path. - val front = (delta % 2 != 0) - // Offsets for start and end of k loop. - // Prevents mapping of space beyond the grid. - var k1start = 0 - var k1end = 0 - var k2start = 0 - var k2end = 0 - for (d in 0 until max_d) { - // Bail out if deadline is reached. - if (System.currentTimeMillis() > deadline) { - break - } - - // Walk the front path one step. - var k1 = -d + k1start - while (k1 <= d - k1end) { - val k1_offset = v_offset + k1 - var x1: Int - if (k1 == -d || (k1 != d && v1[k1_offset - 1] < v1[k1_offset + 1])) { - x1 = v1[k1_offset + 1] - } else { - x1 = v1[k1_offset - 1] + 1 - } - var y1 = x1 - k1 - while ((x1 < text1_length) && y1 < text2_length && text1[x1] == text2[y1]) { - x1++ - y1++ - } - v1[k1_offset] = x1 - if (x1 > text1_length) { - // Ran off the right of the graph. - k1end += 2 - } else if (y1 > text2_length) { - // Ran off the bottom of the graph. - k1start += 2 - } else if (front) { - val k2_offset = v_offset + delta - k1 - if ((k2_offset >= 0 && k2_offset < v_length) && v2[k2_offset] != -1) { - // Mirror x2 onto top-left coordinate system. - val x2 = text1_length - v2[k2_offset] - if (x1 >= x2) { - // Overlap detected. - return diff_bisectSplit(text1, text2, x1, y1, deadline) - } - } - } - k1 += 2 - } - - // Walk the reverse path one step. - var k2 = -d + k2start - while (k2 <= d - k2end) { - val k2_offset = v_offset + k2 - var x2: Int - if (k2 == -d || (k2 != d && v2[k2_offset - 1] < v2[k2_offset + 1])) { - x2 = v2[k2_offset + 1] - } else { - x2 = v2[k2_offset - 1] + 1 - } - var y2 = x2 - k2 - while ((x2 < text1_length) && y2 < text2_length && (text1[text1_length - x2 - 1] - == text2[text2_length - y2 - 1]) - ) { - x2++ - y2++ - } - v2[k2_offset] = x2 - if (x2 > text1_length) { - // Ran off the left of the graph. - k2end += 2 - } else if (y2 > text2_length) { - // Ran off the top of the graph. - k2start += 2 - } else if (!front) { - val k1_offset = v_offset + delta - k2 - if (((k1_offset >= 0) && k1_offset < v_length) && v1[k1_offset] != -1) { - val x1 = v1[k1_offset] - val y1 = v_offset + x1 - k1_offset - // Mirror x2 onto top-left coordinate system. - x2 = text1_length - x2 - if (x1 >= x2) { - // Overlap detected. - return diff_bisectSplit(text1, text2, x1, y1, deadline) - } - } - } - k2 += 2 - } - } - // Diff took too long and hit the deadline or - // number of diffs equals number of characters, no commonality at all. - val diffs = LinkedList() - diffs.add(Diff(Operation.DELETE, text1)) - diffs.add(Diff(Operation.INSERT, text2)) - return diffs + diff_cleanupMerge(diffs) + return diffs + } + + /** + * Find the differences between two texts. Assumes that the texts do not + * have any common prefix or suffix. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param checklines Speedup flag. If false, then don't run a + * line-level diff first to identify the changed areas. + * If true, then run a faster slightly less optimal diff. + * @param deadline Time when the diff should be complete by. + * @return Linked List of Diff objects. + */ + private fun diff_compute(text1: String, text2: String, checklines: Boolean, deadline: Long): LinkedList { + var diffs = LinkedList() + + if (text1.length == 0) { + // Just add some text (speedup). + diffs.add(Diff(Operation.INSERT, text2)) + return diffs } - /** - * Given the location of the 'middle snake', split the diff in two parts - * and recurse. - * @param text1 Old string to be diffed. - * @param text2 New string to be diffed. - * @param x Index of split point in text1. - * @param y Index of split point in text2. - * @param deadline Time at which to bail if not yet complete. - * @return LinkedList of Diff objects. - */ - private fun diff_bisectSplit( - text1: String, text2: String, - x: Int, y: Int, deadline: Long - ): LinkedList { - val text1a = text1.substring(0, x) - val text2a = text2.substring(0, y) - val text1b = text1.substring(x) - val text2b = text2.substring(y) - - // Compute both diffs serially. - val diffs = diff_main(text1a, text2a, false, deadline) - val diffsb = diff_main(text1b, text2b, false, deadline) - - diffs.addAll(diffsb) - return diffs + if (text2.length == 0) { + // Just delete some text (speedup). + diffs.add(Diff(Operation.DELETE, text1)) + return diffs } - /** - * Split two texts into a list of strings. Reduce the texts to a string of - * hashes where each Unicode character represents one line. - * @param text1 First string. - * @param text2 Second string. - * @return An object containing the encoded text1, the encoded text2 and - * the List of unique strings. The zeroth element of the List of - * unique strings is intentionally blank. - */ - private fun diff_linesToChars(text1: String, text2: String): LinesToCharsResult { - val lineArray: MutableList = ArrayList() - val lineHash: MutableMap = HashMap() + val longtext = if (text1.length > text2.length) text1 else text2 + val shorttext = if (text1.length > text2.length) text2 else text1 + val i = longtext.indexOf(shorttext) + if (i != -1) { + // Shorter text is inside the longer text (speedup). + val op = if ((text1.length > text2.length)) Operation.DELETE else Operation.INSERT + diffs.add(Diff(op, longtext.substring(0, i))) + diffs.add(Diff(Operation.EQUAL, shorttext)) + diffs.add(Diff(op, longtext.substring(i + shorttext.length))) + return diffs + } - // e.g. linearray[4] == "Hello\n" - // e.g. linehash.get("Hello\n") == 4 + if (shorttext.length == 1) { + // Single character string. + // After the previous speedup, the character can't be an equality. + diffs.add(Diff(Operation.DELETE, text1)) + diffs.add(Diff(Operation.INSERT, text2)) + return diffs + } - // "\x00" is a valid character, but various debuggers don't like it. - // So we'll insert a junk entry to avoid generating a null character. - lineArray.add("") + // Check to see if the problem can be split in two. + val hm = diff_halfMatch(text1, text2) + if (hm != null) { + // A half-match was found, sort out the return data. + val text1_a = hm[0] + val text1_b = hm[1] + val text2_a = hm[2] + val text2_b = hm[3] + val mid_common = hm[4] + // Send both pairs off for separate processing. + val diffs_a = diff_main( + text1_a, text2_a, + checklines, deadline + ) + val diffs_b = diff_main( + text1_b, text2_b, + checklines, deadline + ) + // Merge the results. + diffs = diffs_a + diffs.add(Diff(Operation.EQUAL, mid_common)) + diffs.addAll(diffs_b) + return diffs + } - // Allocate 2/3rds of the space for text1, the rest for text2. - val chars1 = diff_linesToCharsMunge(text1, lineArray, lineHash, 40000) - val chars2 = diff_linesToCharsMunge(text2, lineArray, lineHash, 65535) - return LinesToCharsResult(chars1, chars2, lineArray) + if ((checklines && text1.length > 100) && text2.length > 100) { + return diff_lineMode(text1, text2, deadline) } - /** - * Split a text into a list of strings. Reduce the texts to a string of - * hashes where each Unicode character represents one line. - * @param text String to encode. - * @param lineArray List of unique strings. - * @param lineHash Map of strings to indices. - * @param maxLines Maximum length of lineArray. - * @return Encoded string. - */ - private fun diff_linesToCharsMunge( - text: String, lineArray: MutableList, - lineHash: MutableMap, maxLines: Int - ): String { - var lineStart = 0 - var lineEnd = -1 - var line: String - val chars = StringBuilder() - // Walk the text, pulling out a substring for each line. - // text.split('\n') would would temporarily double our memory footprint. - // Modifying text would create many large strings to garbage collect. - while (lineEnd < text.length - 1) { - lineEnd = text.indexOf('\n', lineStart) - if (lineEnd == -1) { - lineEnd = text.length - 1 + return diff_bisect(text1, text2, deadline) + } + + /** + * Do a quick line-level diff on both strings, then rediff the parts for + * greater accuracy. + * This speedup can produce non-minimal diffs. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param deadline Time when the diff should be complete by. + * @return Linked List of Diff objects. + */ + private fun diff_lineMode( + text1: String, text2: String, + deadline: Long + ): LinkedList { + // Scan the text on a line-by-line basis first. + var text1 = text1 + var text2 = text2 + val a = diff_linesToChars(text1, text2) + text1 = a.chars1 + text2 = a.chars2 + val linearray = a.lineArray + + val diffs = diff_main(text1, text2, false, deadline) + + // Convert the diff back to original text. + diff_charsToLines(diffs, linearray) + // Eliminate freak matches (e.g. blank lines) + diff_cleanupSemantic(diffs) + + // Rediff any replacement blocks, this time character-by-character. + // Add a dummy entry at the end. + diffs.add(Diff(Operation.EQUAL, "")) + var count_delete = 0 + var count_insert = 0 + var text_delete: String = "" + var text_insert: String = "" + val pointer = diffs.listIterator() + var thisDiff: Diff? = pointer.next() + while (thisDiff != null) { + when (thisDiff.operation) { + Operation.INSERT -> { + count_insert++ + text_insert += thisDiff.text + } + + Operation.DELETE -> { + count_delete++ + text_delete += thisDiff.text + } + + Operation.EQUAL -> { + // Upon reaching an equality, check for prior redundancies. + if (count_delete >= 1 && count_insert >= 1) { + // Delete the offending records and add the merged ones. + pointer.previous() + var j = 0 + while (j < count_delete + count_insert) { + pointer.previous() + pointer.remove() + j++ } - line = text.substring(lineStart, lineEnd + 1) - - if (lineHash.containsKey(line)) { - chars.append((lineHash[line] as Int).toChar().toString()) - } else { - if (lineArray.size == maxLines) { - // Bail out at 65535 because - // String.valueOf((char) 65536).equals(String.valueOf(((char) 0))) - line = text.substring(lineStart) - lineEnd = text.length - } - lineArray.add(line) - lineHash[line] = lineArray.size - 1 - chars.append((lineArray.size - 1).toChar().toString()) + for (subDiff: Diff in diff_main( + text_delete, text_insert, false, + deadline + )) { + pointer.add(subDiff) } - lineStart = lineEnd + 1 + } + count_insert = 0 + count_delete = 0 + text_delete = "" + text_insert = "" } - return chars.toString() - } - /** - * Rehydrate the text in a diff from a string of line hashes to real lines of - * text. - * @param diffs List of Diff objects. - * @param lineArray List of unique strings. - */ - private fun diff_charsToLines( - diffs: List, - lineArray: List - ) { - var text: StringBuilder - for (diff: Diff in diffs) { - text = StringBuilder() - for (j in 0 until diff.text!!.length) { - text.append(lineArray[diff.text!![j].code]) - } - diff.text = text.toString() - } + null -> TODO() + } + thisDiff = if (pointer.hasNext()) pointer.next() else null } - - /** - * Determine the common prefix of two strings - * @param text1 First string. - * @param text2 Second string. - * @return The number of characters common to the start of each string. - */ - fun diff_commonPrefix(text1: String?, text2: String?): Int { - // Performance analysis: https://neil.fraser.name/news/2007/10/09/ - val n = min(text1!!.length.toDouble(), text2!!.length.toDouble()).toInt() - for (i in 0 until n) { - if (text1[i] != text2[i]) { - return i + diffs.removeLast() // Remove the dummy entry at the end. + + return diffs + } + + /** + * Find the 'middle snake' of a diff, split the problem in two + * and return the recursively constructed diff. + * See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param deadline Time at which to bail if not yet complete. + * @return LinkedList of Diff objects. + */ + private fun diff_bisect( + text1: String, text2: String, + deadline: Long + ): LinkedList { + // Cache the text lengths to prevent multiple calls. + val text1_length = text1.length + val text2_length = text2.length + val max_d = (text1_length + text2_length + 1) / 2 + val v_offset = max_d + val v_length = 2 * max_d + val v1 = IntArray(v_length) + val v2 = IntArray(v_length) + for (x in 0 until v_length) { + v1[x] = -1 + v2[x] = -1 + } + v1[v_offset + 1] = 0 + v2[v_offset + 1] = 0 + val delta = text1_length - text2_length + // If the total number of characters is odd, then the front path will + // collide with the reverse path. + val front = (delta % 2 != 0) + // Offsets for start and end of k loop. + // Prevents mapping of space beyond the grid. + var k1start = 0 + var k1end = 0 + var k2start = 0 + var k2end = 0 + for (d in 0 until max_d) { + // Bail out if deadline is reached. + if (System.currentTimeMillis() > deadline) { + break + } + + // Walk the front path one step. + var k1 = -d + k1start + while (k1 <= d - k1end) { + val k1_offset = v_offset + k1 + var x1: Int + if (k1 == -d || (k1 != d && v1[k1_offset - 1] < v1[k1_offset + 1])) { + x1 = v1[k1_offset + 1] + } else { + x1 = v1[k1_offset - 1] + 1 + } + var y1 = x1 - k1 + while ((x1 < text1_length) && y1 < text2_length && text1[x1] == text2[y1]) { + x1++ + y1++ + } + v1[k1_offset] = x1 + if (x1 > text1_length) { + // Ran off the right of the graph. + k1end += 2 + } else if (y1 > text2_length) { + // Ran off the bottom of the graph. + k1start += 2 + } else if (front) { + val k2_offset = v_offset + delta - k1 + if ((k2_offset >= 0 && k2_offset < v_length) && v2[k2_offset] != -1) { + // Mirror x2 onto top-left coordinate system. + val x2 = text1_length - v2[k2_offset] + if (x1 >= x2) { + // Overlap detected. + return diff_bisectSplit(text1, text2, x1, y1, deadline) } + } + } + k1 += 2 + } + + // Walk the reverse path one step. + var k2 = -d + k2start + while (k2 <= d - k2end) { + val k2_offset = v_offset + k2 + var x2: Int + if (k2 == -d || (k2 != d && v2[k2_offset - 1] < v2[k2_offset + 1])) { + x2 = v2[k2_offset + 1] + } else { + x2 = v2[k2_offset - 1] + 1 } - return n - } - - /** - * Determine the common suffix of two strings - * @param text1 First string. - * @param text2 Second string. - * @return The number of characters common to the end of each string. - */ - fun diff_commonSuffix(text1: String?, text2: String?): Int { - // Performance analysis: https://neil.fraser.name/news/2007/10/09/ - val text1_length = text1!!.length - val text2_length = text2!!.length - val n = min(text1_length.toDouble(), text2_length.toDouble()).toInt() - for (i in 1..n) { - if (text1[text1_length - i] != text2[text2_length - i]) { - return i - 1 + var y2 = x2 - k2 + while ((x2 < text1_length) && y2 < text2_length && (text1[text1_length - x2 - 1] + == text2[text2_length - y2 - 1]) + ) { + x2++ + y2++ + } + v2[k2_offset] = x2 + if (x2 > text1_length) { + // Ran off the left of the graph. + k2end += 2 + } else if (y2 > text2_length) { + // Ran off the top of the graph. + k2start += 2 + } else if (!front) { + val k1_offset = v_offset + delta - k2 + if (((k1_offset >= 0) && k1_offset < v_length) && v1[k1_offset] != -1) { + val x1 = v1[k1_offset] + val y1 = v_offset + x1 - k1_offset + // Mirror x2 onto top-left coordinate system. + x2 = text1_length - x2 + if (x1 >= x2) { + // Overlap detected. + return diff_bisectSplit(text1, text2, x1, y1, deadline) } + } } - return n + k2 += 2 + } + } + // Diff took too long and hit the deadline or + // number of diffs equals number of characters, no commonality at all. + val diffs = LinkedList() + diffs.add(Diff(Operation.DELETE, text1)) + diffs.add(Diff(Operation.INSERT, text2)) + return diffs + } + + /** + * Given the location of the 'middle snake', split the diff in two parts + * and recurse. + * @param text1 Old string to be diffed. + * @param text2 New string to be diffed. + * @param x Index of split point in text1. + * @param y Index of split point in text2. + * @param deadline Time at which to bail if not yet complete. + * @return LinkedList of Diff objects. + */ + private fun diff_bisectSplit( + text1: String, text2: String, + x: Int, y: Int, deadline: Long + ): LinkedList { + val text1a = text1.substring(0, x) + val text2a = text2.substring(0, y) + val text1b = text1.substring(x) + val text2b = text2.substring(y) + + // Compute both diffs serially. + val diffs = diff_main(text1a, text2a, false, deadline) + val diffsb = diff_main(text1b, text2b, false, deadline) + + diffs.addAll(diffsb) + return diffs + } + + /** + * Split two texts into a list of strings. Reduce the texts to a string of + * hashes where each Unicode character represents one line. + * @param text1 First string. + * @param text2 Second string. + * @return An object containing the encoded text1, the encoded text2 and + * the List of unique strings. The zeroth element of the List of + * unique strings is intentionally blank. + */ + private fun diff_linesToChars(text1: String, text2: String): LinesToCharsResult { + val lineArray: MutableList = ArrayList() + val lineHash: MutableMap = HashMap() + + // e.g. linearray[4] == "Hello\n" + // e.g. linehash.get("Hello\n") == 4 + + // "\x00" is a valid character, but various debuggers don't like it. + // So we'll insert a junk entry to avoid generating a null character. + lineArray.add("") + + // Allocate 2/3rds of the space for text1, the rest for text2. + val chars1 = diff_linesToCharsMunge(text1, lineArray, lineHash, 40000) + val chars2 = diff_linesToCharsMunge(text2, lineArray, lineHash, 65535) + return LinesToCharsResult(chars1, chars2, lineArray) + } + + /** + * Split a text into a list of strings. Reduce the texts to a string of + * hashes where each Unicode character represents one line. + * @param text String to encode. + * @param lineArray List of unique strings. + * @param lineHash Map of strings to indices. + * @param maxLines Maximum length of lineArray. + * @return Encoded string. + */ + private fun diff_linesToCharsMunge( + text: String, lineArray: MutableList, + lineHash: MutableMap, maxLines: Int + ): String { + var lineStart = 0 + var lineEnd = -1 + var line: String + val chars = StringBuilder() + // Walk the text, pulling out a substring for each line. + // text.split('\n') would would temporarily double our memory footprint. + // Modifying text would create many large strings to garbage collect. + while (lineEnd < text.length - 1) { + lineEnd = text.indexOf('\n', lineStart) + if (lineEnd == -1) { + lineEnd = text.length - 1 + } + line = text.substring(lineStart, lineEnd + 1) + + if (lineHash.containsKey(line)) { + chars.append((lineHash[line] as Int).toChar().toString()) + } else { + if (lineArray.size == maxLines) { + // Bail out at 65535 because + // String.valueOf((char) 65536).equals(String.valueOf(((char) 0))) + line = text.substring(lineStart) + lineEnd = text.length + } + lineArray.add(line) + lineHash[line] = lineArray.size - 1 + chars.append((lineArray.size - 1).toChar().toString()) + } + lineStart = lineEnd + 1 + } + return chars.toString() + } + + /** + * Rehydrate the text in a diff from a string of line hashes to real lines of + * text. + * @param diffs List of Diff objects. + * @param lineArray List of unique strings. + */ + private fun diff_charsToLines( + diffs: List, + lineArray: List + ) { + var text: StringBuilder + for (diff: Diff in diffs) { + text = StringBuilder() + for (j in 0 until diff.text!!.length) { + text.append(lineArray[diff.text!![j].code]) + } + diff.text = text.toString() + } + } + + /** + * Determine the common prefix of two strings + * @param text1 First string. + * @param text2 Second string. + * @return The number of characters common to the start of each string. + */ + fun diff_commonPrefix(text1: String?, text2: String?): Int { + // Performance analysis: https://neil.fraser.name/news/2007/10/09/ + val n = min(text1!!.length.toDouble(), text2!!.length.toDouble()).toInt() + for (i in 0 until n) { + if (text1[i] != text2[i]) { + return i + } + } + return n + } + + /** + * Determine the common suffix of two strings + * @param text1 First string. + * @param text2 Second string. + * @return The number of characters common to the end of each string. + */ + fun diff_commonSuffix(text1: String?, text2: String?): Int { + // Performance analysis: https://neil.fraser.name/news/2007/10/09/ + val text1_length = text1!!.length + val text2_length = text2!!.length + val n = min(text1_length.toDouble(), text2_length.toDouble()).toInt() + for (i in 1..n) { + if (text1[text1_length - i] != text2[text2_length - i]) { + return i - 1 + } + } + return n + } + + /** + * Determine if the suffix of one string is the prefix of another. + * @param text1 First string. + * @param text2 Second string. + * @return The number of characters common to the end of the first + * string and the start of the second string. + */ + private fun diff_commonOverlap(text1: String?, text2: String?): Int { + // Cache the text lengths to prevent multiple calls. + var text1 = text1 + var text2 = text2 + val text1_length = text1!!.length + val text2_length = text2!!.length + // Eliminate the null case. + if (text1_length == 0 || text2_length == 0) { + return 0 + } + // Truncate the longer string. + if (text1_length > text2_length) { + text1 = text1.substring(text1_length - text2_length) + } else if (text1_length < text2_length) { + text2 = text2.substring(0, text1_length) + } + val text_length = min(text1_length.toDouble(), text2_length.toDouble()).toInt() + // Quick check for the worst case. + if (text1 == text2) { + return text_length } - /** - * Determine if the suffix of one string is the prefix of another. - * @param text1 First string. - * @param text2 Second string. - * @return The number of characters common to the end of the first - * string and the start of the second string. - */ - private fun diff_commonOverlap(text1: String?, text2: String?): Int { - // Cache the text lengths to prevent multiple calls. - var text1 = text1 - var text2 = text2 - val text1_length = text1!!.length - val text2_length = text2!!.length - // Eliminate the null case. - if (text1_length == 0 || text2_length == 0) { - return 0 - } - // Truncate the longer string. - if (text1_length > text2_length) { - text1 = text1.substring(text1_length - text2_length) - } else if (text1_length < text2_length) { - text2 = text2.substring(0, text1_length) - } - val text_length = min(text1_length.toDouble(), text2_length.toDouble()).toInt() - // Quick check for the worst case. - if (text1 == text2) { - return text_length - } - - // Start by looking for a single character match - // and increase length until no match is found. - // Performance analysis: https://neil.fraser.name/news/2010/11/04/ - var best = 0 - var length = 1 - while (true) { - val pattern = text1.substring(text_length - length) - val found = text2.indexOf(pattern) - if (found == -1) { - return best - } - length += found - if (found == 0 || text1.substring(text_length - length) == text2.substring(0, length)) { - best = length - length++ - } - } + // Start by looking for a single character match + // and increase length until no match is found. + // Performance analysis: https://neil.fraser.name/news/2010/11/04/ + var best = 0 + var length = 1 + while (true) { + val pattern = text1.substring(text_length - length) + val found = text2.indexOf(pattern) + if (found == -1) { + return best + } + length += found + if (found == 0 || text1.substring(text_length - length) == text2.substring(0, length)) { + best = length + length++ + } + } + } + + /** + * Do the two texts share a substring which is at least half the length of + * the longer text? + * This speedup can produce non-minimal diffs. + * @param text1 First string. + * @param text2 Second string. + * @return Five element String array, containing the prefix of text1, the + * suffix of text1, the prefix of text2, the suffix of text2 and the + * common middle. Or null if there was no match. + */ + private fun diff_halfMatch(text1: String, text2: String): Array? { + if (Diff_Timeout <= 0) { + // Don't risk returning a non-optimal diff if we have unlimited time. + return null + } + val longtext = if (text1.length > text2.length) text1 else text2 + val shorttext = if (text1.length > text2.length) text2 else text1 + if (longtext.length < 4 || shorttext.length * 2 < longtext.length) { + return null // Pointless. } - /** - * Do the two texts share a substring which is at least half the length of - * the longer text? - * This speedup can produce non-minimal diffs. - * @param text1 First string. - * @param text2 Second string. - * @return Five element String array, containing the prefix of text1, the - * suffix of text1, the prefix of text2, the suffix of text2 and the - * common middle. Or null if there was no match. - */ - private fun diff_halfMatch(text1: String, text2: String): Array? { - if (Diff_Timeout <= 0) { - // Don't risk returning a non-optimal diff if we have unlimited time. - return null - } - val longtext = if (text1.length > text2.length) text1 else text2 - val shorttext = if (text1.length > text2.length) text2 else text1 - if (longtext.length < 4 || shorttext.length * 2 < longtext.length) { - return null // Pointless. - } + // First check if the second quarter is the seed for a half-match. + val hm1 = diff_halfMatchI( + longtext, shorttext, + (longtext.length + 3) / 4 + ) + // Check again based on the third quarter. + val hm2 = diff_halfMatchI( + longtext, shorttext, + (longtext.length + 1) / 2 + ) + val hm: Array? + if (hm1 == null && hm2 == null) { + return null + } else if (hm2 == null) { + hm = hm1 + } else if (hm1 == null) { + hm = hm2 + } else { + // Both matched. Select the longest. + hm = if (hm1[4].length > hm2[4].length) hm1 else hm2 + } - // First check if the second quarter is the seed for a half-match. - val hm1 = diff_halfMatchI( - longtext, shorttext, - (longtext.length + 3) / 4 - ) - // Check again based on the third quarter. - val hm2 = diff_halfMatchI( - longtext, shorttext, - (longtext.length + 1) / 2 - ) - val hm: Array? - if (hm1 == null && hm2 == null) { - return null - } else if (hm2 == null) { - hm = hm1 - } else if (hm1 == null) { - hm = hm2 + // A half-match was found, sort out the return data. + if (text1.length > text2.length) { + return hm + //return new String[]{hm[0], hm[1], hm[2], hm[3], hm[4]}; + } else { + return arrayOf(hm!![2], hm[3], hm[0], hm[1], hm[4]) + } + } + + /** + * Does a substring of shorttext exist within longtext such that the + * substring is at least half the length of longtext? + * @param longtext Longer string. + * @param shorttext Shorter string. + * @param i Start index of quarter length substring within longtext. + * @return Five element String array, containing the prefix of longtext, the + * suffix of longtext, the prefix of shorttext, the suffix of shorttext + * and the common middle. Or null if there was no match. + */ + private fun diff_halfMatchI(longtext: String, shorttext: String, i: Int): Array? { + // Start with a 1/4 length substring at position i as a seed. + val seed = longtext.substring(i, i + longtext.length / 4) + var j = -1 + var best_common = "" + var best_longtext_a = "" + var best_longtext_b = "" + var best_shorttext_a = "" + var best_shorttext_b = "" + while ((shorttext.indexOf(seed, j + 1).also { j = it }) != -1) { + val prefixLength = diff_commonPrefix( + longtext.substring(i), + shorttext.substring(j) + ) + val suffixLength = diff_commonSuffix( + longtext.substring(0, i), + shorttext.substring(0, j) + ) + if (best_common.length < suffixLength + prefixLength) { + best_common = (shorttext.substring(j - suffixLength, j) + + shorttext.substring(j, j + prefixLength)) + best_longtext_a = longtext.substring(0, i - suffixLength) + best_longtext_b = longtext.substring(i + prefixLength) + best_shorttext_a = shorttext.substring(0, j - suffixLength) + best_shorttext_b = shorttext.substring(j + prefixLength) + } + } + if (best_common.length * 2 >= longtext.length) { + return arrayOf( + best_longtext_a, best_longtext_b, + best_shorttext_a, best_shorttext_b, best_common + ) + } else { + return null + } + } + + /** + * Reduce the number of edits by eliminating semantically trivial equalities. + * @param diffs LinkedList of Diff objects. + */ + fun diff_cleanupSemantic(diffs: LinkedList) { + if (diffs.isEmpty()) { + return + } + var changes = false + val equalities = ArrayDeque() // Double-ended queue of qualities. + var lastEquality: String? = null // Always equal to equalities.peek().text + var pointer = diffs.listIterator() + // Number of characters that changed prior to the equality. + var length_insertions1 = 0 + var length_deletions1 = 0 + // Number of characters that changed after the equality. + var length_insertions2 = 0 + var length_deletions2 = 0 + var thisDiff: Diff? = pointer.next() + while (thisDiff != null) { + if (thisDiff.operation == Operation.EQUAL) { + // Equality found. + equalities.add(thisDiff) + length_insertions1 = length_insertions2 + length_deletions1 = length_deletions2 + length_insertions2 = 0 + length_deletions2 = 0 + lastEquality = thisDiff.text + } else { + // An insertion or deletion. + if (thisDiff.operation == Operation.INSERT) { + length_insertions2 += thisDiff.text!!.length } else { - // Both matched. Select the longest. - hm = if (hm1[4].length > hm2[4].length) hm1 else hm2 - } + length_deletions2 += thisDiff.text!!.length + } + // Eliminate an equality that is smaller or equal to the edits on both + // sides of it. + if (lastEquality != null && (lastEquality.length + <= max(length_insertions1.toDouble(), length_deletions1.toDouble())) + && (lastEquality.length + <= max(length_insertions2.toDouble(), length_deletions2.toDouble())) + ) { + //System.out.println("Splitting: '" + lastEquality + "'"); + // Walk back to offending equality. + while (thisDiff !== equalities.peek()) { + thisDiff = pointer.previous() + } + pointer.next() + + // Replace equality with a delete. + pointer.set(Diff(Operation.DELETE, lastEquality)) + // Insert a corresponding an insert. + pointer.add(Diff(Operation.INSERT, lastEquality)) + + equalities.pop() // Throw away the equality we just deleted. + if (!equalities.isEmpty()) { + // Throw away the previous equality (it needs to be reevaluated). + equalities.pop() + } + if (equalities.isEmpty()) { + // There are no previous equalities, walk back to the start. + while (pointer.hasPrevious()) { + pointer.previous() + } + } else { + // There is a safe equality we can fall back to. + thisDiff = equalities.peek() + while (thisDiff !== pointer.previous()) { + // Intentionally empty loop. + } + } - // A half-match was found, sort out the return data. - if (text1.length > text2.length) { - return hm - //return new String[]{hm[0], hm[1], hm[2], hm[3], hm[4]}; - } else { - return arrayOf(hm!![2], hm[3], hm[0], hm[1], hm[4]) + length_insertions1 = 0 // Reset the counters. + length_insertions2 = 0 + length_deletions1 = 0 + length_deletions2 = 0 + lastEquality = null + changes = true } + } + thisDiff = if (pointer.hasNext()) pointer.next() else null } - /** - * Does a substring of shorttext exist within longtext such that the - * substring is at least half the length of longtext? - * @param longtext Longer string. - * @param shorttext Shorter string. - * @param i Start index of quarter length substring within longtext. - * @return Five element String array, containing the prefix of longtext, the - * suffix of longtext, the prefix of shorttext, the suffix of shorttext - * and the common middle. Or null if there was no match. - */ - private fun diff_halfMatchI(longtext: String, shorttext: String, i: Int): Array? { - // Start with a 1/4 length substring at position i as a seed. - val seed = longtext.substring(i, i + longtext.length / 4) - var j = -1 - var best_common = "" - var best_longtext_a = "" - var best_longtext_b = "" - var best_shorttext_a = "" - var best_shorttext_b = "" - while ((shorttext.indexOf(seed, j + 1).also { j = it }) != -1) { - val prefixLength = diff_commonPrefix( - longtext.substring(i), - shorttext.substring(j) - ) - val suffixLength = diff_commonSuffix( - longtext.substring(0, i), - shorttext.substring(0, j) - ) - if (best_common.length < suffixLength + prefixLength) { - best_common = (shorttext.substring(j - suffixLength, j) - + shorttext.substring(j, j + prefixLength)) - best_longtext_a = longtext.substring(0, i - suffixLength) - best_longtext_b = longtext.substring(i + prefixLength) - best_shorttext_a = shorttext.substring(0, j - suffixLength) - best_shorttext_b = shorttext.substring(j + prefixLength) - } - } - if (best_common.length * 2 >= longtext.length) { - return arrayOf( - best_longtext_a, best_longtext_b, - best_shorttext_a, best_shorttext_b, best_common + // Normalize the diff. + if (changes) { + diff_cleanupMerge(diffs) + } + diff_cleanupSemanticLossless(diffs) + + // Find any overlaps between deletions and insertions. + // e.g: abcxxxxxxdef + // -> abcxxxdef + // e.g: xxxabcdefxxx + // -> defxxxabc + // Only extract an overlap if it is as big as the edit ahead or behind it. + pointer = diffs.listIterator() + var prevDiff: Diff? = null + thisDiff = null + if (pointer.hasNext()) { + prevDiff = pointer.next() + if (pointer.hasNext()) { + thisDiff = pointer.next() + } + } + while (thisDiff != null) { + if (prevDiff!!.operation == Operation.DELETE && + thisDiff.operation == Operation.INSERT + ) { + val deletion = prevDiff.text + val insertion = thisDiff.text + val overlap_length1 = this.diff_commonOverlap(deletion, insertion) + val overlap_length2 = this.diff_commonOverlap(insertion, deletion) + if (overlap_length1 >= overlap_length2) { + if (overlap_length1 >= deletion!!.length / 2.0 || + overlap_length1 >= insertion!!.length / 2.0 + ) { + // Overlap found. Insert an equality and trim the surrounding edits. + pointer.previous() + pointer.add( + Diff( + Operation.EQUAL, + insertion!!.substring(0, overlap_length1) + ) ) + prevDiff.text = + deletion.substring(0, deletion.length - overlap_length1) + thisDiff.text = insertion.substring(overlap_length1) + // pointer.add inserts the element before the cursor, so there is + // no need to step past the new element. + } } else { - return null + if (overlap_length2 >= deletion!!.length / 2.0 || + overlap_length2 >= insertion!!.length / 2.0 + ) { + // Reverse overlap found. + // Insert an equality and swap and trim the surrounding edits. + pointer.previous() + pointer.add( + Diff( + Operation.EQUAL, + deletion.substring(0, overlap_length2) + ) + ) + prevDiff.operation = Operation.INSERT + prevDiff.text = + insertion!!.substring(0, insertion.length - overlap_length2) + thisDiff.operation = Operation.DELETE + thisDiff.text = deletion.substring(overlap_length2) + // pointer.add inserts the element before the cursor, so there is + // no need to step past the new element. + } } + thisDiff = if (pointer.hasNext()) pointer.next() else null + } + prevDiff = thisDiff + thisDiff = if (pointer.hasNext()) pointer.next() else null } - - /** - * Reduce the number of edits by eliminating semantically trivial equalities. - * @param diffs LinkedList of Diff objects. - */ - fun diff_cleanupSemantic(diffs: LinkedList) { - if (diffs.isEmpty()) { - return - } - var changes = false - val equalities = ArrayDeque() // Double-ended queue of qualities. - var lastEquality: String? = null // Always equal to equalities.peek().text - var pointer = diffs.listIterator() - // Number of characters that changed prior to the equality. - var length_insertions1 = 0 - var length_deletions1 = 0 - // Number of characters that changed after the equality. - var length_insertions2 = 0 - var length_deletions2 = 0 - var thisDiff: Diff? = pointer.next() - while (thisDiff != null) { - if (thisDiff.operation == Operation.EQUAL) { - // Equality found. - equalities.add(thisDiff) - length_insertions1 = length_insertions2 - length_deletions1 = length_deletions2 - length_insertions2 = 0 - length_deletions2 = 0 - lastEquality = thisDiff.text - } else { - // An insertion or deletion. - if (thisDiff.operation == Operation.INSERT) { - length_insertions2 += thisDiff.text!!.length - } else { - length_deletions2 += thisDiff.text!!.length - } - // Eliminate an equality that is smaller or equal to the edits on both - // sides of it. - if (lastEquality != null && (lastEquality.length - <= max(length_insertions1.toDouble(), length_deletions1.toDouble())) - && (lastEquality.length - <= max(length_insertions2.toDouble(), length_deletions2.toDouble())) - ) { - //System.out.println("Splitting: '" + lastEquality + "'"); - // Walk back to offending equality. - while (thisDiff !== equalities.peek()) { - thisDiff = pointer.previous() - } - pointer.next() - - // Replace equality with a delete. - pointer.set(Diff(Operation.DELETE, lastEquality)) - // Insert a corresponding an insert. - pointer.add(Diff(Operation.INSERT, lastEquality)) - - equalities.pop() // Throw away the equality we just deleted. - if (!equalities.isEmpty()) { - // Throw away the previous equality (it needs to be reevaluated). - equalities.pop() - } - if (equalities.isEmpty()) { - // There are no previous equalities, walk back to the start. - while (pointer.hasPrevious()) { - pointer.previous() - } - } else { - // There is a safe equality we can fall back to. - thisDiff = equalities.peek() - while (thisDiff !== pointer.previous()) { - // Intentionally empty loop. - } - } - - length_insertions1 = 0 // Reset the counters. - length_insertions2 = 0 - length_deletions1 = 0 - length_deletions2 = 0 - lastEquality = null - changes = true - } - } - thisDiff = if (pointer.hasNext()) pointer.next() else null - } - - // Normalize the diff. - if (changes) { - diff_cleanupMerge(diffs) - } - diff_cleanupSemanticLossless(diffs) - - // Find any overlaps between deletions and insertions. - // e.g: abcxxxxxxdef - // -> abcxxxdef - // e.g: xxxabcdefxxx - // -> defxxxabc - // Only extract an overlap if it is as big as the edit ahead or behind it. - pointer = diffs.listIterator() - var prevDiff: Diff? = null - thisDiff = null - if (pointer.hasNext()) { - prevDiff = pointer.next() - if (pointer.hasNext()) { - thisDiff = pointer.next() - } - } - while (thisDiff != null) { - if (prevDiff!!.operation == Operation.DELETE && - thisDiff.operation == Operation.INSERT - ) { - val deletion = prevDiff.text - val insertion = thisDiff.text - val overlap_length1 = this.diff_commonOverlap(deletion, insertion) - val overlap_length2 = this.diff_commonOverlap(insertion, deletion) - if (overlap_length1 >= overlap_length2) { - if (overlap_length1 >= deletion!!.length / 2.0 || - overlap_length1 >= insertion!!.length / 2.0 - ) { - // Overlap found. Insert an equality and trim the surrounding edits. - pointer.previous() - pointer.add( - Diff( - Operation.EQUAL, - insertion!!.substring(0, overlap_length1) - ) - ) - prevDiff.text = - deletion.substring(0, deletion.length - overlap_length1) - thisDiff.text = insertion.substring(overlap_length1) - // pointer.add inserts the element before the cursor, so there is - // no need to step past the new element. - } - } else { - if (overlap_length2 >= deletion!!.length / 2.0 || - overlap_length2 >= insertion!!.length / 2.0 - ) { - // Reverse overlap found. - // Insert an equality and swap and trim the surrounding edits. - pointer.previous() - pointer.add( - Diff( - Operation.EQUAL, - deletion.substring(0, overlap_length2) - ) - ) - prevDiff.operation = Operation.INSERT - prevDiff.text = - insertion!!.substring(0, insertion.length - overlap_length2) - thisDiff.operation = Operation.DELETE - thisDiff.text = deletion.substring(overlap_length2) - // pointer.add inserts the element before the cursor, so there is - // no need to step past the new element. - } - } - thisDiff = if (pointer.hasNext()) pointer.next() else null - } - prevDiff = thisDiff - thisDiff = if (pointer.hasNext()) pointer.next() else null - } + } + + /** + * Look for single edits surrounded on both sides by equalities + * which can be shifted sideways to align the edit to a word boundary. + * e.g: The cat came. -> The cat came. + * @param diffs LinkedList of Diff objects. + */ + private fun diff_cleanupSemanticLossless(diffs: LinkedList) { + var equality1: String + var edit: String + var equality2: String + var commonString: String + var commonOffset: Int + var score: Int + var bestScore: Int + var bestEquality1: String? + var bestEdit: String? + var bestEquality2: String? + // Create a new iterator at the start. + val pointer = diffs.listIterator() + var prevDiff = if (pointer.hasNext()) pointer.next() else null + var thisDiff = if (pointer.hasNext()) pointer.next() else null + var nextDiff = if (pointer.hasNext()) pointer.next() else null + // Intentionally ignore the first and last element (don't need checking). + while (nextDiff != null) { + if (prevDiff!!.operation == Operation.EQUAL && + nextDiff.operation == Operation.EQUAL + ) { + // This is a single edit surrounded by equalities. + equality1 = prevDiff.text!! + edit = thisDiff!!.text!! + equality2 = nextDiff.text!! + + // First, shift the edit as far left as possible. + commonOffset = diff_commonSuffix(equality1, edit) + if (commonOffset != 0) { + commonString = edit.substring(edit.length - commonOffset) + equality1 = equality1.substring(0, equality1.length - commonOffset) + edit = commonString + edit.substring(0, edit.length - commonOffset) + equality2 = commonString + equality2 + } + + // Second, step character by character right, looking for the best fit. + bestEquality1 = equality1 + bestEdit = edit + bestEquality2 = equality2 + bestScore = (diff_cleanupSemanticScore(equality1, edit) + + diff_cleanupSemanticScore(edit, equality2)) + while (((edit.length != 0) && equality2.length != 0) && edit[0] == equality2[0]) { + equality1 += edit[0] + edit = edit.substring(1) + equality2[0] + equality2 = equality2.substring(1) + score = (diff_cleanupSemanticScore(equality1, edit) + + diff_cleanupSemanticScore(edit, equality2)) + // The >= encourages trailing rather than leading whitespace on edits. + if (score >= bestScore) { + bestScore = score + bestEquality1 = equality1 + bestEdit = edit + bestEquality2 = equality2 + } + } + + if (prevDiff.text != bestEquality1) { + // We have an improvement, save it back to the diff. + if (bestEquality1!!.length != 0) { + prevDiff.text = bestEquality1 + } else { + pointer.previous() // Walk past nextDiff. + pointer.previous() // Walk past thisDiff. + pointer.previous() // Walk past prevDiff. + pointer.remove() // Delete prevDiff. + pointer.next() // Walk past thisDiff. + pointer.next() // Walk past nextDiff. + } + thisDiff.text = bestEdit + if (bestEquality2!!.length != 0) { + nextDiff.text = bestEquality2 + } else { + pointer.remove() // Delete nextDiff. + nextDiff = thisDiff + thisDiff = prevDiff + } + } + } + prevDiff = thisDiff + thisDiff = nextDiff + nextDiff = if (pointer.hasNext()) pointer.next() else null } - - /** - * Look for single edits surrounded on both sides by equalities - * which can be shifted sideways to align the edit to a word boundary. - * e.g: The cat came. -> The cat came. - * @param diffs LinkedList of Diff objects. - */ - private fun diff_cleanupSemanticLossless(diffs: LinkedList) { - var equality1: String - var edit: String - var equality2: String - var commonString: String - var commonOffset: Int - var score: Int - var bestScore: Int - var bestEquality1: String? - var bestEdit: String? - var bestEquality2: String? - // Create a new iterator at the start. - val pointer = diffs.listIterator() - var prevDiff = if (pointer.hasNext()) pointer.next() else null - var thisDiff = if (pointer.hasNext()) pointer.next() else null - var nextDiff = if (pointer.hasNext()) pointer.next() else null - // Intentionally ignore the first and last element (don't need checking). - while (nextDiff != null) { - if (prevDiff!!.operation == Operation.EQUAL && - nextDiff.operation == Operation.EQUAL - ) { - // This is a single edit surrounded by equalities. - equality1 = prevDiff.text!! - edit = thisDiff!!.text!! - equality2 = nextDiff.text!! - - // First, shift the edit as far left as possible. - commonOffset = diff_commonSuffix(equality1, edit) - if (commonOffset != 0) { - commonString = edit.substring(edit.length - commonOffset) - equality1 = equality1.substring(0, equality1.length - commonOffset) - edit = commonString + edit.substring(0, edit.length - commonOffset) - equality2 = commonString + equality2 - } - - // Second, step character by character right, looking for the best fit. - bestEquality1 = equality1 - bestEdit = edit - bestEquality2 = equality2 - bestScore = (diff_cleanupSemanticScore(equality1, edit) - + diff_cleanupSemanticScore(edit, equality2)) - while (((edit.length != 0) && equality2.length != 0) && edit[0] == equality2[0]) { - equality1 += edit[0] - edit = edit.substring(1) + equality2[0] - equality2 = equality2.substring(1) - score = (diff_cleanupSemanticScore(equality1, edit) - + diff_cleanupSemanticScore(edit, equality2)) - // The >= encourages trailing rather than leading whitespace on edits. - if (score >= bestScore) { - bestScore = score - bestEquality1 = equality1 - bestEdit = edit - bestEquality2 = equality2 - } - } - - if (prevDiff.text != bestEquality1) { - // We have an improvement, save it back to the diff. - if (bestEquality1!!.length != 0) { - prevDiff.text = bestEquality1 - } else { - pointer.previous() // Walk past nextDiff. - pointer.previous() // Walk past thisDiff. - pointer.previous() // Walk past prevDiff. - pointer.remove() // Delete prevDiff. - pointer.next() // Walk past thisDiff. - pointer.next() // Walk past nextDiff. - } - thisDiff.text = bestEdit - if (bestEquality2!!.length != 0) { - nextDiff.text = bestEquality2 - } else { - pointer.remove() // Delete nextDiff. - nextDiff = thisDiff - thisDiff = prevDiff - } - } - } - prevDiff = thisDiff - thisDiff = nextDiff - nextDiff = if (pointer.hasNext()) pointer.next() else null - } + } + + /** + * Given two strings, compute a score representing whether the internal + * boundary falls on logical boundaries. + * Scores range from 6 (best) to 0 (worst). + * @param one First string. + * @param two Second string. + * @return The score. + */ + private fun diff_cleanupSemanticScore(one: String?, two: String?): Int { + if (one!!.length == 0 || two!!.length == 0) { + // Edges are the best. + return 6 } - /** - * Given two strings, compute a score representing whether the internal - * boundary falls on logical boundaries. - * Scores range from 6 (best) to 0 (worst). - * @param one First string. - * @param two Second string. - * @return The score. - */ - private fun diff_cleanupSemanticScore(one: String?, two: String?): Int { - if (one!!.length == 0 || two!!.length == 0) { - // Edges are the best. - return 6 - } - - // Each port of this function behaves slightly differently due to - // subtle differences in each language's definition of things like - // 'whitespace'. Since this function's purpose is largely cosmetic, - // the choice has been made to use each language's native features - // rather than force total conformity. - val char1 = one[one.length - 1] - val char2 = two[0] - val nonAlphaNumeric1 = !Character.isLetterOrDigit(char1) - val nonAlphaNumeric2 = !Character.isLetterOrDigit(char2) - val whitespace1 = nonAlphaNumeric1 && Character.isWhitespace(char1) - val whitespace2 = nonAlphaNumeric2 && Character.isWhitespace(char2) - val lineBreak1 = (whitespace1 - && Character.getType(char1) == Character.CONTROL.toInt()) - val lineBreak2 = (whitespace2 - && Character.getType(char2) == Character.CONTROL.toInt()) - val blankLine1 = lineBreak1 && BLANKLINEEND.matcher(one).find() - val blankLine2 = lineBreak2 && BLANKLINESTART.matcher(two).find() - - if (blankLine1 || blankLine2) { - // Five points for blank lines. - return 5 - } else if (lineBreak1 || lineBreak2) { - // Four points for line breaks. - return 4 - } else if (nonAlphaNumeric1 && !whitespace1 && whitespace2) { - // Three points for end of sentences. - return 3 - } else if (whitespace1 || whitespace2) { - // Two points for whitespace. - return 2 - } else if (nonAlphaNumeric1 || nonAlphaNumeric2) { - // One point for non-alphanumeric. - return 1 - } - return 0 + // Each port of this function behaves slightly differently due to + // subtle differences in each language's definition of things like + // 'whitespace'. Since this function's purpose is largely cosmetic, + // the choice has been made to use each language's native features + // rather than force total conformity. + val char1 = one[one.length - 1] + val char2 = two[0] + val nonAlphaNumeric1 = !Character.isLetterOrDigit(char1) + val nonAlphaNumeric2 = !Character.isLetterOrDigit(char2) + val whitespace1 = nonAlphaNumeric1 && Character.isWhitespace(char1) + val whitespace2 = nonAlphaNumeric2 && Character.isWhitespace(char2) + val lineBreak1 = (whitespace1 + && Character.getType(char1) == Character.CONTROL.toInt()) + val lineBreak2 = (whitespace2 + && Character.getType(char2) == Character.CONTROL.toInt()) + val blankLine1 = lineBreak1 && BLANKLINEEND.matcher(one).find() + val blankLine2 = lineBreak2 && BLANKLINESTART.matcher(two).find() + + if (blankLine1 || blankLine2) { + // Five points for blank lines. + return 5 + } else if (lineBreak1 || lineBreak2) { + // Four points for line breaks. + return 4 + } else if (nonAlphaNumeric1 && !whitespace1 && whitespace2) { + // Three points for end of sentences. + return 3 + } else if (whitespace1 || whitespace2) { + // Two points for whitespace. + return 2 + } else if (nonAlphaNumeric1 || nonAlphaNumeric2) { + // One point for non-alphanumeric. + return 1 } - - // Define some regex patterns for matching boundaries. - private val BLANKLINEEND - : Pattern = Pattern.compile("\\n\\r?\\n\\Z", Pattern.DOTALL) - private val BLANKLINESTART - : Pattern = Pattern.compile("\\A\\r?\\n\\r?\\n", Pattern.DOTALL) - - /** - * Reduce the number of edits by eliminating operationally trivial equalities. - * @param diffs LinkedList of Diff objects. - */ - fun diff_cleanupEfficiency(diffs: LinkedList) { - if (diffs.isEmpty()) { - return - } - var changes = false - val equalities = ArrayDeque() // Double-ended queue of equalities. - var lastEquality: String? = null // Always equal to equalities.peek().text - val pointer = diffs.listIterator() - // Is there an insertion operation before the last equality. - var pre_ins = false - // Is there a deletion operation before the last equality. - var pre_del = false - // Is there an insertion operation after the last equality. - var post_ins = false - // Is there a deletion operation after the last equality. - var post_del = false - var thisDiff: Diff? = pointer.next() - var safeDiff = thisDiff // The last Diff that is known to be unsplittable. - while (thisDiff != null) { - if (thisDiff.operation == Operation.EQUAL) { - // Equality found. - if (thisDiff.text!!.length < Diff_EditCost && (post_ins || post_del)) { - // Candidate found. - equalities.push(thisDiff) - pre_ins = post_ins - pre_del = post_del - lastEquality = thisDiff.text - } else { - // Not a candidate, and can never become one. - equalities.clear() - lastEquality = null - safeDiff = thisDiff - } - post_del = false - post_ins = post_del - } else { - // An insertion or deletion. - if (thisDiff.operation == Operation.DELETE) { - post_del = true - } else { - post_ins = true - } - /* - * Five types to be split: - * ABXYCD - * AXCD - * ABXC - * AXCD - * ABXC - */ - if (lastEquality != null - && ((pre_ins && pre_del && post_ins && post_del) - || ((lastEquality.length < Diff_EditCost / 2) - && ((if (pre_ins) 1 else 0) + (if (pre_del) 1 else 0) - + (if (post_ins) 1 else 0) + (if (post_del) 1 else 0)) == 3)) - ) { - //System.out.println("Splitting: '" + lastEquality + "'"); - // Walk back to offending equality. - while (thisDiff !== equalities.peek()) { - thisDiff = pointer.previous() - } - pointer.next() - - // Replace equality with a delete. - pointer.set(Diff(Operation.DELETE, lastEquality)) - // Insert a corresponding an insert. - pointer.add(Diff(Operation.INSERT, lastEquality).also { - thisDiff = it - }) - - equalities.pop() // Throw away the equality we just deleted. - lastEquality = null - if (pre_ins && pre_del) { - // No changes made which could affect previous entry, keep going. - post_del = true - post_ins = post_del - equalities.clear() - safeDiff = thisDiff - } else { - if (!equalities.isEmpty()) { - // Throw away the previous equality (it needs to be reevaluated). - equalities.pop() - } - if (equalities.isEmpty()) { - // There are no previous questionable equalities, - // walk back to the last known safe diff. - thisDiff = safeDiff - } else { - // There is an equality we can fall back to. - thisDiff = equalities.peek() - } - while (thisDiff !== pointer.previous()) { - // Intentionally empty loop. - } - post_del = false - post_ins = post_del - } - - changes = true - } - } - thisDiff = if (pointer.hasNext()) pointer.next() else null - } - - if (changes) { - diff_cleanupMerge(diffs) - } + return 0 + } + + // Define some regex patterns for matching boundaries. + private val BLANKLINEEND + : Pattern = Pattern.compile("\\n\\r?\\n\\Z", Pattern.DOTALL) + private val BLANKLINESTART + : Pattern = Pattern.compile("\\A\\r?\\n\\r?\\n", Pattern.DOTALL) + + /** + * Reduce the number of edits by eliminating operationally trivial equalities. + * @param diffs LinkedList of Diff objects. + */ + fun diff_cleanupEfficiency(diffs: LinkedList) { + if (diffs.isEmpty()) { + return } - - /** - * Reorder and merge like edit sections. Merge equalities. - * Any edit section can move as long as it doesn't cross an equality. - * @param diffs LinkedList of Diff objects. - */ - private fun diff_cleanupMerge(diffs: LinkedList) { - diffs.add(Diff(Operation.EQUAL, "")) // Add a dummy entry at the end. - var pointer = diffs.listIterator() - var count_delete = 0 - var count_insert = 0 - var text_delete: String? = "" - var text_insert: String? = "" - var thisDiff: Diff? = pointer.next() - var prevEqual: Diff? = null - var commonlength: Int - while (thisDiff != null) { - when (thisDiff.operation) { - Operation.INSERT -> { - count_insert++ - text_insert += thisDiff.text - prevEqual = null - } - - Operation.DELETE -> { - count_delete++ - text_delete += thisDiff.text - prevEqual = null - } - - Operation.EQUAL -> { - if (count_delete + count_insert > 1) { - val both_types = count_delete != 0 && count_insert != 0 - // Delete the offending records. - pointer.previous() // Reverse direction. - while (count_delete-- > 0) { - pointer.previous() - pointer.remove() - } - while (count_insert-- > 0) { - pointer.previous() - pointer.remove() - } - if (both_types) { - // Factor out any common prefixies. - commonlength = diff_commonPrefix(text_insert, text_delete) - if (commonlength != 0) { - if (pointer.hasPrevious()) { - thisDiff = pointer.previous() - assert( - thisDiff.operation == Operation.EQUAL - ) { "Previous diff should have been an equality." } - thisDiff.text += text_insert!!.substring(0, commonlength) - pointer.next() - } else { - pointer.add( - Diff( - Operation.EQUAL, - text_insert!!.substring(0, commonlength) - ) - ) - } - text_insert = text_insert.substring(commonlength) - text_delete = text_delete!!.substring(commonlength) - } - // Factor out any common suffixies. - commonlength = diff_commonSuffix(text_insert, text_delete) - if (commonlength != 0) { - thisDiff = pointer.next() - thisDiff.text = text_insert!!.substring( - text_insert.length - - commonlength - ) + thisDiff.text - text_insert = text_insert.substring( - 0, text_insert.length - - commonlength - ) - text_delete = text_delete!!.substring( - 0, text_delete.length - - commonlength - ) - pointer.previous() - } - } - // Insert the merged records. - if (text_delete!!.length != 0) { - pointer.add(Diff(Operation.DELETE, text_delete)) - } - if (text_insert!!.length != 0) { - pointer.add(Diff(Operation.INSERT, text_insert)) - } - // Step forward to the equality. - thisDiff = if (pointer.hasNext()) pointer.next() else null - } else if (prevEqual != null) { - // Merge this equality with the previous one. - prevEqual.text += thisDiff.text - pointer.remove() - thisDiff = pointer.previous() - pointer.next() // Forward direction - } - count_insert = 0 - count_delete = 0 - text_delete = "" - text_insert = "" - prevEqual = thisDiff - } - - null -> TODO() - } - thisDiff = if (pointer.hasNext()) pointer.next() else null - } - if (diffs.last.text!!.length == 0) { - diffs.removeLast() // Remove the dummy entry at the end. + var changes = false + val equalities = ArrayDeque() // Double-ended queue of equalities. + var lastEquality: String? = null // Always equal to equalities.peek().text + val pointer = diffs.listIterator() + // Is there an insertion operation before the last equality. + var pre_ins = false + // Is there a deletion operation before the last equality. + var pre_del = false + // Is there an insertion operation after the last equality. + var post_ins = false + // Is there a deletion operation after the last equality. + var post_del = false + var thisDiff: Diff? = pointer.next() + var safeDiff = thisDiff // The last Diff that is known to be unsplittable. + while (thisDiff != null) { + if (thisDiff.operation == Operation.EQUAL) { + // Equality found. + if (thisDiff.text!!.length < Diff_EditCost && (post_ins || post_del)) { + // Candidate found. + equalities.push(thisDiff) + pre_ins = post_ins + pre_del = post_del + lastEquality = thisDiff.text + } else { + // Not a candidate, and can never become one. + equalities.clear() + lastEquality = null + safeDiff = thisDiff + } + post_del = false + post_ins = post_del + } else { + // An insertion or deletion. + if (thisDiff.operation == Operation.DELETE) { + post_del = true + } else { + post_ins = true } - /* - * Second pass: look for single edits surrounded on both sides by equalities - * which can be shifted sideways to eliminate an equality. - * e.g: ABAC -> ABAC - */ - var changes = false - // Create a new iterator at the start. - // (As opposed to walking the current one back.) - pointer = diffs.listIterator() - var prevDiff = if (pointer.hasNext()) pointer.next() else null - thisDiff = if (pointer.hasNext()) pointer.next() else null - var nextDiff = if (pointer.hasNext()) pointer.next() else null - // Intentionally ignore the first and last element (don't need checking). - while (nextDiff != null) { - if (prevDiff!!.operation == Operation.EQUAL && - nextDiff.operation == Operation.EQUAL - ) { - // This is a single edit surrounded by equalities. - if (thisDiff!!.text!!.endsWith(prevDiff.text!!)) { - // Shift the edit over the previous equality. - thisDiff.text = (prevDiff.text - + thisDiff.text!!.substring( - 0, thisDiff.text!!.length - - prevDiff.text!!.length - )) - nextDiff.text = prevDiff.text + nextDiff.text - pointer.previous() // Walk past nextDiff. - pointer.previous() // Walk past thisDiff. - pointer.previous() // Walk past prevDiff. - pointer.remove() // Delete prevDiff. - pointer.next() // Walk past thisDiff. - thisDiff = pointer.next() // Walk past nextDiff. - nextDiff = if (pointer.hasNext()) pointer.next() else null - changes = true - } else if (thisDiff.text!!.startsWith(nextDiff.text!!)) { - // Shift the edit over the next equality. - prevDiff.text += nextDiff.text - thisDiff.text = (thisDiff.text!!.substring(nextDiff.text!!.length) - + nextDiff.text) - pointer.remove() // Delete nextDiff. - nextDiff = if (pointer.hasNext()) pointer.next() else null - changes = true - } - } - prevDiff = thisDiff - thisDiff = nextDiff - nextDiff = if (pointer.hasNext()) pointer.next() else null - } - // If shifts were made, the diff needs reordering and another shift sweep. - if (changes) { - diff_cleanupMerge(diffs) - } - } - - /** - * loc is a location in text1, compute and return the equivalent location in - * text2. - * e.g. "The cat" vs "The big cat", 1->1, 5->8 - * @param diffs List of Diff objects. - * @param loc Location within text1. - * @return Location within text2. - */ - private fun diff_xIndex(diffs: List, loc: Int): Int { - var chars1 = 0 - var chars2 = 0 - var last_chars1 = 0 - var last_chars2 = 0 - var lastDiff: Diff? = null - for (aDiff: Diff in diffs) { - if (aDiff.operation != Operation.INSERT) { - // Equality or deletion. - chars1 += aDiff.text!!.length + * Five types to be split: + * ABXYCD + * AXCD + * ABXC + * AXCD + * ABXC + */ + if (lastEquality != null + && ((pre_ins && pre_del && post_ins && post_del) + || ((lastEquality.length < Diff_EditCost / 2) + && ((if (pre_ins) 1 else 0) + (if (pre_del) 1 else 0) + + (if (post_ins) 1 else 0) + (if (post_del) 1 else 0)) == 3)) + ) { + //System.out.println("Splitting: '" + lastEquality + "'"); + // Walk back to offending equality. + while (thisDiff !== equalities.peek()) { + thisDiff = pointer.previous() + } + pointer.next() + + // Replace equality with a delete. + pointer.set(Diff(Operation.DELETE, lastEquality)) + // Insert a corresponding an insert. + pointer.add(Diff(Operation.INSERT, lastEquality).also { + thisDiff = it + }) + + equalities.pop() // Throw away the equality we just deleted. + lastEquality = null + if (pre_ins && pre_del) { + // No changes made which could affect previous entry, keep going. + post_del = true + post_ins = post_del + equalities.clear() + safeDiff = thisDiff + } else { + if (!equalities.isEmpty()) { + // Throw away the previous equality (it needs to be reevaluated). + equalities.pop() } - if (aDiff.operation != Operation.DELETE) { - // Equality or insertion. - chars2 += aDiff.text!!.length + if (equalities.isEmpty()) { + // There are no previous questionable equalities, + // walk back to the last known safe diff. + thisDiff = safeDiff + } else { + // There is an equality we can fall back to. + thisDiff = equalities.peek() } - if (chars1 > loc) { - // Overshot the location. - lastDiff = aDiff - break + while (thisDiff !== pointer.previous()) { + // Intentionally empty loop. } - last_chars1 = chars1 - last_chars2 = chars2 - } - if (lastDiff != null && lastDiff.operation == Operation.DELETE) { - // The location was deleted. - return last_chars2 - } - // Add the remaining character length. - return last_chars2 + (loc - last_chars1) - } + post_del = false + post_ins = post_del + } - /** - * Compute and return the source text (all equalities and deletions). - * @param diffs List of Diff objects. - * @return Source text. - */ - private fun diff_text1(diffs: List): String { - val text = StringBuilder() - for (aDiff: Diff in diffs) { - if (aDiff.operation != Operation.INSERT) { - text.append(aDiff.text) - } + changes = true } - return text.toString() + } + thisDiff = if (pointer.hasNext()) pointer.next() else null } - /** - * Compute and return the destination text (all equalities and insertions). - * @param diffs List of Diff objects. - * @return Destination text. - */ - private fun diff_text2(diffs: List): String { - val text = StringBuilder() - for (aDiff: Diff in diffs) { - if (aDiff.operation != Operation.DELETE) { - text.append(aDiff.text) - } - } - return text.toString() + if (changes) { + diff_cleanupMerge(diffs) } - - /** - * Compute the Levenshtein distance; the number of inserted, deleted or - * substituted characters. - * @param diffs List of Diff objects. - * @return Number of changes. - */ - private fun diff_levenshtein(diffs: List): Int { - var levenshtein = 0 - var insertions = 0 - var deletions = 0 - for (aDiff: Diff in diffs) { - when (aDiff.operation) { - Operation.INSERT -> insertions += aDiff.text!!.length - Operation.DELETE -> deletions += aDiff.text!!.length - Operation.EQUAL -> { - // A deletion and an insertion is one substitution. - levenshtein += (max(insertions.toDouble(), deletions.toDouble())).toInt() - insertions = 0 - deletions = 0 - } - - null -> TODO() + } + + /** + * Reorder and merge like edit sections. Merge equalities. + * Any edit section can move as long as it doesn't cross an equality. + * @param diffs LinkedList of Diff objects. + */ + private fun diff_cleanupMerge(diffs: LinkedList) { + diffs.add(Diff(Operation.EQUAL, "")) // Add a dummy entry at the end. + var pointer = diffs.listIterator() + var count_delete = 0 + var count_insert = 0 + var text_delete: String? = "" + var text_insert: String? = "" + var thisDiff: Diff? = pointer.next() + var prevEqual: Diff? = null + var commonlength: Int + while (thisDiff != null) { + when (thisDiff.operation) { + Operation.INSERT -> { + count_insert++ + text_insert += thisDiff.text + prevEqual = null + } + + Operation.DELETE -> { + count_delete++ + text_delete += thisDiff.text + prevEqual = null + } + + Operation.EQUAL -> { + if (count_delete + count_insert > 1) { + val both_types = count_delete != 0 && count_insert != 0 + // Delete the offending records. + pointer.previous() // Reverse direction. + while (count_delete-- > 0) { + pointer.previous() + pointer.remove() } - } - levenshtein += (max(insertions.toDouble(), deletions.toDouble())).toInt() - return levenshtein - } - - - // MATCH FUNCTIONS - /** - * Locate the best instance of 'pattern' in 'text' near 'loc'. - * Returns -1 if no match found. - * @param text The text to search. - * @param pattern The pattern to search for. - * @param loc The location to search around. - * @return Best match index or -1. - */ - private fun match_main(text: String?, pattern: String?, loc: Int): Int { - // Check for null inputs. - var loc = loc - if (text == null || pattern == null) { - throw IllegalArgumentException("Null inputs. (match_main)") - } - - loc = max(0.0, min(loc.toDouble(), text.length.toDouble())).toInt() - if ((text == pattern)) { - // Shortcut (potentially not guaranteed by the algorithm) - return 0 - } else if (text.length == 0) { - // Nothing to match. - return -1 - } else if ((loc + pattern.length <= text.length - && (text.substring(loc, loc + pattern.length) == pattern)) - ) { - // Perfect match at the perfect spot! (Includes case of null pattern) - return loc - } else { - // Do a fuzzy compare. - return match_bitap(text, pattern, loc) - } - } - - /** - * Locate the best instance of 'pattern' in 'text' near 'loc' using the - * Bitap algorithm. Returns -1 if no match found. - * @param text The text to search. - * @param pattern The pattern to search for. - * @param loc The location to search around. - * @return Best match index or -1. - */ - private fun match_bitap(text: String, pattern: String, loc: Int): Int { - assert(Match_MaxBits.toInt() == 0 || pattern.length <= Match_MaxBits) { "Pattern too long for this application." } - // Initialise the alphabet. - val s = match_alphabet(pattern) - - // Highest score beyond which we give up. - var score_threshold = Match_Threshold.toDouble() - // Is there a nearby exact match? (speedup) - var best_loc = text.indexOf(pattern, loc) - if (best_loc != -1) { - score_threshold = min( - match_bitapScore(0, best_loc, loc, pattern), - score_threshold - ) - // What about in the other direction? (speedup) - best_loc = text.lastIndexOf(pattern, loc + pattern.length) - if (best_loc != -1) { - score_threshold = min( - match_bitapScore(0, best_loc, loc, pattern), - score_threshold - ) + while (count_insert-- > 0) { + pointer.previous() + pointer.remove() } - } - - // Initialise the bit arrays. - val matchmask = 1 shl (pattern.length - 1) - best_loc = -1 - - var bin_min: Int - var bin_mid: Int - var bin_max = pattern.length + text.length - // Empty initialization added to appease Java compiler. - var last_rd = IntArray(0) - for (d in 0 until pattern.length) { - // Scan for the best match; each iteration allows for one more error. - // Run a binary search to determine how far from 'loc' we can stray at - // this error level. - bin_min = 0 - bin_mid = bin_max - while (bin_min < bin_mid) { - if ((match_bitapScore(d, loc + bin_mid, loc, pattern) - <= score_threshold) - ) { - bin_min = bin_mid + if (both_types) { + // Factor out any common prefixies. + commonlength = diff_commonPrefix(text_insert, text_delete) + if (commonlength != 0) { + if (pointer.hasPrevious()) { + thisDiff = pointer.previous() + assert( + thisDiff.operation == Operation.EQUAL + ) { "Previous diff should have been an equality." } + thisDiff.text += text_insert!!.substring(0, commonlength) + pointer.next() } else { - bin_max = bin_mid + pointer.add( + Diff( + Operation.EQUAL, + text_insert!!.substring(0, commonlength) + ) + ) } - bin_mid = (bin_max - bin_min) / 2 + bin_min + text_insert = text_insert.substring(commonlength) + text_delete = text_delete!!.substring(commonlength) + } + // Factor out any common suffixies. + commonlength = diff_commonSuffix(text_insert, text_delete) + if (commonlength != 0) { + thisDiff = pointer.next() + thisDiff.text = text_insert!!.substring( + text_insert.length + - commonlength + ) + thisDiff.text + text_insert = text_insert.substring( + 0, text_insert.length + - commonlength + ) + text_delete = text_delete!!.substring( + 0, text_delete.length + - commonlength + ) + pointer.previous() + } } - // Use the result from this iteration as the maximum for the next. - bin_max = bin_mid - var start = max(1.0, (loc - bin_mid + 1).toDouble()).toInt() - val finish = (min((loc + bin_mid).toDouble(), text.length.toDouble()) + pattern.length).toInt() - - val rd = IntArray(finish + 2) - rd[finish + 1] = (1 shl d) - 1 - var j = finish - while (j >= start) { - var charMatch: Int - if (text.length <= j - 1 || !s.containsKey(text[j - 1])) { - // Out of range. - charMatch = 0 - } else { - charMatch = (s[text[j - 1]])!! - } - if (d == 0) { - // First pass: exact match. - rd[j] = ((rd[j + 1] shl 1) or 1) and charMatch - } else { - // Subsequent passes: fuzzy match. - rd[j] = ((((rd[j + 1] shl 1) or 1) and charMatch) - or (((last_rd[j + 1] or last_rd[j]) shl 1) or 1) or last_rd[j + 1]) - } - if ((rd[j] and matchmask) != 0) { - val score = match_bitapScore(d, j - 1, loc, pattern) - // This match will almost certainly be better than any existing - // match. But check anyway. - if (score <= score_threshold) { - // Told you so. - score_threshold = score - best_loc = j - 1 - if (best_loc > loc) { - // When passing loc, don't exceed our current distance from loc. - start = max(1.0, (2 * loc - best_loc).toDouble()).toInt() - } else { - // Already passed loc, downhill from here on in. - break - } - } - } - j-- + // Insert the merged records. + if (text_delete!!.length != 0) { + pointer.add(Diff(Operation.DELETE, text_delete)) } - if (match_bitapScore(d + 1, loc, loc, pattern) > score_threshold) { - // No hope for a (better) match at greater error levels. - break + if (text_insert!!.length != 0) { + pointer.add(Diff(Operation.INSERT, text_insert)) } - last_rd = rd - } - return best_loc + // Step forward to the equality. + thisDiff = if (pointer.hasNext()) pointer.next() else null + } else if (prevEqual != null) { + // Merge this equality with the previous one. + prevEqual.text += thisDiff.text + pointer.remove() + thisDiff = pointer.previous() + pointer.next() // Forward direction + } + count_insert = 0 + count_delete = 0 + text_delete = "" + text_insert = "" + prevEqual = thisDiff + } + + null -> TODO() + } + thisDiff = if (pointer.hasNext()) pointer.next() else null } - - /** - * Compute and return the score for a match with e errors and x location. - * @param e Number of errors in match. - * @param x Location of match. - * @param loc Expected location of match. - * @param pattern Pattern being sought. - * @return Overall score for match (0.0 = good, 1.0 = bad). - */ - private fun match_bitapScore(e: Int, x: Int, loc: Int, pattern: String): Double { - val accuracy = e.toFloat() / pattern.length - val proximity = abs((loc - x).toDouble()).toInt() - if (Match_Distance == 0) { - // Dodge divide by zero error. - return if (proximity == 0) accuracy.toDouble() else 1.0 - } - return (accuracy + (proximity / Match_Distance.toFloat())).toDouble() + if (diffs.last.text!!.length == 0) { + diffs.removeLast() // Remove the dummy entry at the end. } - /** - * Initialise the alphabet for the Bitap algorithm. - * @param pattern The text to encode. - * @return Hash of character locations. - */ - private fun match_alphabet(pattern: String): Map { - val s: MutableMap = HashMap() - val char_pattern = pattern.toCharArray() - for (c: Char in char_pattern) { - s[c] = 0 - } - var i = 0 - for (c: Char in char_pattern) { - s[c] = s.get(c)!! or (1 shl (pattern.length - i - 1)) - i++ - } - return s + /* + * Second pass: look for single edits surrounded on both sides by equalities + * which can be shifted sideways to eliminate an equality. + * e.g: ABAC -> ABAC + */ + var changes = false + // Create a new iterator at the start. + // (As opposed to walking the current one back.) + pointer = diffs.listIterator() + var prevDiff = if (pointer.hasNext()) pointer.next() else null + thisDiff = if (pointer.hasNext()) pointer.next() else null + var nextDiff = if (pointer.hasNext()) pointer.next() else null + // Intentionally ignore the first and last element (don't need checking). + while (nextDiff != null) { + if (prevDiff!!.operation == Operation.EQUAL && + nextDiff.operation == Operation.EQUAL + ) { + // This is a single edit surrounded by equalities. + if (thisDiff!!.text!!.endsWith(prevDiff.text!!)) { + // Shift the edit over the previous equality. + thisDiff.text = (prevDiff.text + + thisDiff.text!!.substring( + 0, thisDiff.text!!.length + - prevDiff.text!!.length + )) + nextDiff.text = prevDiff.text + nextDiff.text + pointer.previous() // Walk past nextDiff. + pointer.previous() // Walk past thisDiff. + pointer.previous() // Walk past prevDiff. + pointer.remove() // Delete prevDiff. + pointer.next() // Walk past thisDiff. + thisDiff = pointer.next() // Walk past nextDiff. + nextDiff = if (pointer.hasNext()) pointer.next() else null + changes = true + } else if (thisDiff.text!!.startsWith(nextDiff.text!!)) { + // Shift the edit over the next equality. + prevDiff.text += nextDiff.text + thisDiff.text = (thisDiff.text!!.substring(nextDiff.text!!.length) + + nextDiff.text) + pointer.remove() // Delete nextDiff. + nextDiff = if (pointer.hasNext()) pointer.next() else null + changes = true + } + } + prevDiff = thisDiff + thisDiff = nextDiff + nextDiff = if (pointer.hasNext()) pointer.next() else null + } + // If shifts were made, the diff needs reordering and another shift sweep. + if (changes) { + diff_cleanupMerge(diffs) + } + } + + /** + * loc is a location in text1, compute and return the equivalent location in + * text2. + * e.g. "The cat" vs "The big cat", 1->1, 5->8 + * @param diffs List of Diff objects. + * @param loc Location within text1. + * @return Location within text2. + */ + private fun diff_xIndex(diffs: List, loc: Int): Int { + var chars1 = 0 + var chars2 = 0 + var last_chars1 = 0 + var last_chars2 = 0 + var lastDiff: Diff? = null + for (aDiff: Diff in diffs) { + if (aDiff.operation != Operation.INSERT) { + // Equality or deletion. + chars1 += aDiff.text!!.length + } + if (aDiff.operation != Operation.DELETE) { + // Equality or insertion. + chars2 += aDiff.text!!.length + } + if (chars1 > loc) { + // Overshot the location. + lastDiff = aDiff + break + } + last_chars1 = chars1 + last_chars2 = chars2 + } + if (lastDiff != null && lastDiff.operation == Operation.DELETE) { + // The location was deleted. + return last_chars2 + } + // Add the remaining character length. + return last_chars2 + (loc - last_chars1) + } + + /** + * Compute and return the source text (all equalities and deletions). + * @param diffs List of Diff objects. + * @return Source text. + */ + private fun diff_text1(diffs: List): String { + val text = StringBuilder() + for (aDiff: Diff in diffs) { + if (aDiff.operation != Operation.INSERT) { + text.append(aDiff.text) + } + } + return text.toString() + } + + /** + * Compute and return the destination text (all equalities and insertions). + * @param diffs List of Diff objects. + * @return Destination text. + */ + private fun diff_text2(diffs: List): String { + val text = StringBuilder() + for (aDiff: Diff in diffs) { + if (aDiff.operation != Operation.DELETE) { + text.append(aDiff.text) + } + } + return text.toString() + } + + /** + * Compute the Levenshtein distance; the number of inserted, deleted or + * substituted characters. + * @param diffs List of Diff objects. + * @return Number of changes. + */ + private fun diff_levenshtein(diffs: List): Int { + var levenshtein = 0 + var insertions = 0 + var deletions = 0 + for (aDiff: Diff in diffs) { + when (aDiff.operation) { + Operation.INSERT -> insertions += aDiff.text!!.length + Operation.DELETE -> deletions += aDiff.text!!.length + Operation.EQUAL -> { + // A deletion and an insertion is one substitution. + levenshtein += (max(insertions.toDouble(), deletions.toDouble())).toInt() + insertions = 0 + deletions = 0 + } + + null -> TODO() + } + } + levenshtein += (max(insertions.toDouble(), deletions.toDouble())).toInt() + return levenshtein + } + + + // MATCH FUNCTIONS + /** + * Locate the best instance of 'pattern' in 'text' near 'loc'. + * Returns -1 if no match found. + * @param text The text to search. + * @param pattern The pattern to search for. + * @param loc The location to search around. + * @return Best match index or -1. + */ + private fun match_main(text: String?, pattern: String?, loc: Int): Int { + // Check for null inputs. + var loc = loc + if (text == null || pattern == null) { + throw IllegalArgumentException("Null inputs. (match_main)") } - - // PATCH FUNCTIONS - /** - * Increase the context until it is unique, - * but don't let the pattern expand beyond Match_MaxBits. - * @param patch The patch to grow. - * @param text Source text. - */ - private fun patch_addContext(patch: Patch, text: String) { - if (text.length == 0) { - return - } - var pattern = text.substring(patch.start2, patch.start2 + patch.length1) - var padding = 0 - - // Look for the first and last matches of pattern in text. If two different - // matches are found, increase the pattern length. - while ((text.indexOf(pattern) != text.lastIndexOf(pattern) - && pattern.length < Match_MaxBits - Patch_Margin - Patch_Margin) - ) { - padding += Patch_Margin.toInt() - pattern = text.substring( - max(0.0, (patch.start2 - padding).toDouble()).toInt(), - min(text.length.toDouble(), (patch.start2 + patch.length1 + padding).toDouble()).toInt() - ) - } - // Add one chunk for good luck. - padding += Patch_Margin.toInt() - - // Add the prefix. - val prefix = text.substring( - max(0.0, (patch.start2 - padding).toDouble()).toInt(), - patch.start2 - ) - if (prefix.length != 0) { - patch.diffs.addFirst(Diff(Operation.EQUAL, prefix)) - } - // Add the suffix. - val suffix = text.substring( - patch.start2 + patch.length1, - min(text.length.toDouble(), (patch.start2 + patch.length1 + padding).toDouble()).toInt() + loc = max(0.0, min(loc.toDouble(), text.length.toDouble())).toInt() + if ((text == pattern)) { + // Shortcut (potentially not guaranteed by the algorithm) + return 0 + } else if (text.length == 0) { + // Nothing to match. + return -1 + } else if ((loc + pattern.length <= text.length + && (text.substring(loc, loc + pattern.length) == pattern)) + ) { + // Perfect match at the perfect spot! (Includes case of null pattern) + return loc + } else { + // Do a fuzzy compare. + return match_bitap(text, pattern, loc) + } + } + + /** + * Locate the best instance of 'pattern' in 'text' near 'loc' using the + * Bitap algorithm. Returns -1 if no match found. + * @param text The text to search. + * @param pattern The pattern to search for. + * @param loc The location to search around. + * @return Best match index or -1. + */ + private fun match_bitap(text: String, pattern: String, loc: Int): Int { + assert(Match_MaxBits.toInt() == 0 || pattern.length <= Match_MaxBits) { "Pattern too long for this application." } + // Initialise the alphabet. + val s = match_alphabet(pattern) + + // Highest score beyond which we give up. + var score_threshold = Match_Threshold.toDouble() + // Is there a nearby exact match? (speedup) + var best_loc = text.indexOf(pattern, loc) + if (best_loc != -1) { + score_threshold = min( + match_bitapScore(0, best_loc, loc, pattern), + score_threshold + ) + // What about in the other direction? (speedup) + best_loc = text.lastIndexOf(pattern, loc + pattern.length) + if (best_loc != -1) { + score_threshold = min( + match_bitapScore(0, best_loc, loc, pattern), + score_threshold ) - if (suffix.length != 0) { - patch.diffs.addLast(Diff(Operation.EQUAL, suffix)) - } - - // Roll back the start points. - patch.start1 -= prefix.length - patch.start2 -= prefix.length - // Extend the lengths. - patch.length1 += prefix.length + suffix.length - patch.length2 += prefix.length + suffix.length + } } - /** - * Compute a list of patches to turn text1 into text2. - * A set of diffs will be computed. - * @param text1 Old text. - * @param text2 New text. - * @return LinkedList of Patch objects. - */ - fun patch_make(text1: String?, text2: String?): LinkedList { - if (text1 == null || text2 == null) { - throw IllegalArgumentException("Null inputs. (patch_make)") - } - // No diffs provided, compute our own. - val diffs = diff_main(text1, text2, true) - if (diffs.size > 2) { - diff_cleanupSemantic(diffs) - diff_cleanupEfficiency(diffs) + // Initialise the bit arrays. + val matchmask = 1 shl (pattern.length - 1) + best_loc = -1 + + var bin_min: Int + var bin_mid: Int + var bin_max = pattern.length + text.length + // Empty initialization added to appease Java compiler. + var last_rd = IntArray(0) + for (d in 0 until pattern.length) { + // Scan for the best match; each iteration allows for one more error. + // Run a binary search to determine how far from 'loc' we can stray at + // this error level. + bin_min = 0 + bin_mid = bin_max + while (bin_min < bin_mid) { + if ((match_bitapScore(d, loc + bin_mid, loc, pattern) + <= score_threshold) + ) { + bin_min = bin_mid + } else { + bin_max = bin_mid + } + bin_mid = (bin_max - bin_min) / 2 + bin_min + } + // Use the result from this iteration as the maximum for the next. + bin_max = bin_mid + var start = max(1.0, (loc - bin_mid + 1).toDouble()).toInt() + val finish = (min((loc + bin_mid).toDouble(), text.length.toDouble()) + pattern.length).toInt() + + val rd = IntArray(finish + 2) + rd[finish + 1] = (1 shl d) - 1 + var j = finish + while (j >= start) { + var charMatch: Int + if (text.length <= j - 1 || !s.containsKey(text[j - 1])) { + // Out of range. + charMatch = 0 + } else { + charMatch = (s[text[j - 1]])!! } - return patch_make(text1, diffs) + if (d == 0) { + // First pass: exact match. + rd[j] = ((rd[j + 1] shl 1) or 1) and charMatch + } else { + // Subsequent passes: fuzzy match. + rd[j] = ((((rd[j + 1] shl 1) or 1) and charMatch) + or (((last_rd[j + 1] or last_rd[j]) shl 1) or 1) or last_rd[j + 1]) + } + if ((rd[j] and matchmask) != 0) { + val score = match_bitapScore(d, j - 1, loc, pattern) + // This match will almost certainly be better than any existing + // match. But check anyway. + if (score <= score_threshold) { + // Told you so. + score_threshold = score + best_loc = j - 1 + if (best_loc > loc) { + // When passing loc, don't exceed our current distance from loc. + start = max(1.0, (2 * loc - best_loc).toDouble()).toInt() + } else { + // Already passed loc, downhill from here on in. + break + } + } + } + j-- + } + if (match_bitapScore(d + 1, loc, loc, pattern) > score_threshold) { + // No hope for a (better) match at greater error levels. + break + } + last_rd = rd + } + return best_loc + } + + /** + * Compute and return the score for a match with e errors and x location. + * @param e Number of errors in match. + * @param x Location of match. + * @param loc Expected location of match. + * @param pattern Pattern being sought. + * @return Overall score for match (0.0 = good, 1.0 = bad). + */ + private fun match_bitapScore(e: Int, x: Int, loc: Int, pattern: String): Double { + val accuracy = e.toFloat() / pattern.length + val proximity = abs((loc - x).toDouble()).toInt() + if (Match_Distance == 0) { + // Dodge divide by zero error. + return if (proximity == 0) accuracy.toDouble() else 1.0 } + return (accuracy + (proximity / Match_Distance.toFloat())).toDouble() + } + + /** + * Initialise the alphabet for the Bitap algorithm. + * @param pattern The text to encode. + * @return Hash of character locations. + */ + private fun match_alphabet(pattern: String): Map { + val s: MutableMap = HashMap() + val char_pattern = pattern.toCharArray() + for (c: Char in char_pattern) { + s[c] = 0 + } + var i = 0 + for (c: Char in char_pattern) { + s[c] = s.get(c)!! or (1 shl (pattern.length - i - 1)) + i++ + } + return s + } + + + // PATCH FUNCTIONS + /** + * Increase the context until it is unique, + * but don't let the pattern expand beyond Match_MaxBits. + * @param patch The patch to grow. + * @param text Source text. + */ + private fun patch_addContext(patch: Patch, text: String) { + if (text.length == 0) { + return + } + var pattern = text.substring(patch.start2, patch.start2 + patch.length1) + var padding = 0 - /** - * Compute a list of patches to turn text1 into text2. - * text1 will be derived from the provided diffs. - * @param diffs Array of Diff objects for text1 to text2. - * @return LinkedList of Patch objects. - */ - fun patch_make(diffs: LinkedList?): LinkedList { - if (diffs == null) { - throw IllegalArgumentException("Null inputs. (patch_make)") - } - // No origin string provided, compute our own. - val text1 = diff_text1(diffs) - return patch_make(text1, diffs) + // Look for the first and last matches of pattern in text. If two different + // matches are found, increase the pattern length. + while ((text.indexOf(pattern) != text.lastIndexOf(pattern) + && pattern.length < Match_MaxBits - Patch_Margin - Patch_Margin) + ) { + padding += Patch_Margin.toInt() + pattern = text.substring( + max(0.0, (patch.start2 - padding).toDouble()).toInt(), + min(text.length.toDouble(), (patch.start2 + patch.length1 + padding).toDouble()).toInt() + ) } + // Add one chunk for good luck. + padding += Patch_Margin.toInt() - /** - * Compute a list of patches to turn text1 into text2. - * text2 is ignored, diffs are the delta between text1 and text2. - * @param text1 Old text - * @param text2 Ignored. - * @param diffs Array of Diff objects for text1 to text2. - * @return LinkedList of Patch objects. - */ - @Deprecated("Prefer patch_make(String text1, LinkedList diffs).") - fun patch_make( - text1: String?, text2: String?, - diffs: LinkedList? - ): LinkedList { - return patch_make(text1, diffs) + // Add the prefix. + val prefix = text.substring( + max(0.0, (patch.start2 - padding).toDouble()).toInt(), + patch.start2 + ) + if (prefix.length != 0) { + patch.diffs.addFirst(Diff(Operation.EQUAL, prefix)) + } + // Add the suffix. + val suffix = text.substring( + patch.start2 + patch.length1, + min(text.length.toDouble(), (patch.start2 + patch.length1 + padding).toDouble()).toInt() + ) + if (suffix.length != 0) { + patch.diffs.addLast(Diff(Operation.EQUAL, suffix)) } - /** - * Compute a list of patches to turn text1 into text2. - * text2 is not provided, diffs are the delta between text1 and text2. - * @param text1 Old text. - * @param diffs Array of Diff objects for text1 to text2. - * @return LinkedList of Patch objects. - */ - fun patch_make(text1: String?, diffs: LinkedList?): LinkedList { - if (text1 == null || diffs == null) { - throw IllegalArgumentException("Null inputs. (patch_make)") - } + // Roll back the start points. + patch.start1 -= prefix.length + patch.start2 -= prefix.length + // Extend the lengths. + patch.length1 += prefix.length + suffix.length + patch.length2 += prefix.length + suffix.length + } + + /** + * Compute a list of patches to turn text1 into text2. + * A set of diffs will be computed. + * @param text1 Old text. + * @param text2 New text. + * @return LinkedList of Patch objects. + */ + fun patch_make(text1: String?, text2: String?): LinkedList { + if (text1 == null || text2 == null) { + throw IllegalArgumentException("Null inputs. (patch_make)") + } + // No diffs provided, compute our own. + val diffs = diff_main(text1, text2, true) + if (diffs.size > 2) { + diff_cleanupSemantic(diffs) + diff_cleanupEfficiency(diffs) + } + return patch_make(text1, diffs) + } + + /** + * Compute a list of patches to turn text1 into text2. + * text1 will be derived from the provided diffs. + * @param diffs Array of Diff objects for text1 to text2. + * @return LinkedList of Patch objects. + */ + fun patch_make(diffs: LinkedList?): LinkedList { + if (diffs == null) { + throw IllegalArgumentException("Null inputs. (patch_make)") + } + // No origin string provided, compute our own. + val text1 = diff_text1(diffs) + return patch_make(text1, diffs) + } + + /** + * Compute a list of patches to turn text1 into text2. + * text2 is ignored, diffs are the delta between text1 and text2. + * @param text1 Old text + * @param text2 Ignored. + * @param diffs Array of Diff objects for text1 to text2. + * @return LinkedList of Patch objects. + */ + @Deprecated("Prefer patch_make(String text1, LinkedList diffs).") + fun patch_make( + text1: String?, text2: String?, + diffs: LinkedList? + ): LinkedList { + return patch_make(text1, diffs) + } + + /** + * Compute a list of patches to turn text1 into text2. + * text2 is not provided, diffs are the delta between text1 and text2. + * @param text1 Old text. + * @param diffs Array of Diff objects for text1 to text2. + * @return LinkedList of Patch objects. + */ + fun patch_make(text1: String?, diffs: LinkedList?): LinkedList { + if (text1 == null || diffs == null) { + throw IllegalArgumentException("Null inputs. (patch_make)") + } - val patches = LinkedList() - if (diffs.isEmpty()) { - return patches // Get rid of the null case. - } - var patch = Patch() - var char_count1 = 0 // Number of characters into the text1 string. - var char_count2 = 0 // Number of characters into the text2 string. - // Start with text1 (prepatch_text) and apply the diffs until we arrive at - // text2 (postpatch_text). We recreate the patches one by one to determine - // context info. - var prepatch_text: String = text1 - var postpatch_text: String = text1 - for (aDiff: Diff in diffs) { - if (patch.diffs.isEmpty() && aDiff.operation != Operation.EQUAL) { - // A new patch starts here. - patch.start1 = char_count1 - patch.start2 = char_count2 + val patches = LinkedList() + if (diffs.isEmpty()) { + return patches // Get rid of the null case. + } + var patch = Patch() + var char_count1 = 0 // Number of characters into the text1 string. + var char_count2 = 0 // Number of characters into the text2 string. + // Start with text1 (prepatch_text) and apply the diffs until we arrive at + // text2 (postpatch_text). We recreate the patches one by one to determine + // context info. + var prepatch_text: String = text1 + var postpatch_text: String = text1 + for (aDiff: Diff in diffs) { + if (patch.diffs.isEmpty() && aDiff.operation != Operation.EQUAL) { + // A new patch starts here. + patch.start1 = char_count1 + patch.start2 = char_count2 + } + + when (aDiff.operation) { + Operation.INSERT -> { + patch.diffs.add(aDiff) + patch.length2 += aDiff.text!!.length + postpatch_text = (postpatch_text.substring(0, char_count2) + + aDiff.text + postpatch_text.substring(char_count2)) + } + + Operation.DELETE -> { + patch.length1 += aDiff.text!!.length + patch.diffs.add(aDiff) + postpatch_text = (postpatch_text.substring(0, char_count2) + + postpatch_text.substring(char_count2 + aDiff.text!!.length)) + } + + Operation.EQUAL -> { + if ((aDiff.text!!.length <= 2 * Patch_Margin + ) && !patch.diffs.isEmpty() && (aDiff !== diffs.last) + ) { + // Small equality inside a patch. + patch.diffs.add(aDiff) + patch.length1 += aDiff.text!!.length + patch.length2 += aDiff.text!!.length + } + + if (aDiff.text!!.length >= 2 * Patch_Margin && !patch.diffs.isEmpty()) { + // Time for a new patch. + if (!patch.diffs.isEmpty()) { + patch_addContext(patch, prepatch_text) + patches.add(patch) + patch = Patch() + // Unlike Unidiff, our patch lists have a rolling context. + // https://github.com/google/diff-match-patch/wiki/Unidiff + // Update prepatch text & pos to reflect the application of the + // just completed patch. + prepatch_text = postpatch_text + char_count1 = char_count2 } + } + } + + null -> TODO() + } + // Update the current character count. + if (aDiff.operation != Operation.INSERT) { + char_count1 += aDiff.text!!.length + } + if (aDiff.operation != Operation.DELETE) { + char_count2 += aDiff.text!!.length + } + } + // Pick up the leftover patch if not empty. + if (!patch.diffs.isEmpty()) { + patch_addContext(patch, prepatch_text) + patches.add(patch) + } - when (aDiff.operation) { - Operation.INSERT -> { - patch.diffs.add(aDiff) - patch.length2 += aDiff.text!!.length - postpatch_text = (postpatch_text.substring(0, char_count2) - + aDiff.text + postpatch_text.substring(char_count2)) - } - - Operation.DELETE -> { - patch.length1 += aDiff.text!!.length - patch.diffs.add(aDiff) - postpatch_text = (postpatch_text.substring(0, char_count2) - + postpatch_text.substring(char_count2 + aDiff.text!!.length)) - } + return patches + } + + /** + * Given an array of patches, return another array that is identical. + * @param patches Array of Patch objects. + * @return Array of Patch objects. + */ + private fun patch_deepCopy(patches: LinkedList): LinkedList { + val patchesCopy = LinkedList() + for (aPatch: Patch in patches) { + val patchCopy = Patch() + for (aDiff: Diff in aPatch.diffs) { + val diffCopy = Diff(aDiff.operation, aDiff.text) + patchCopy.diffs.add(diffCopy) + } + patchCopy.start1 = aPatch.start1 + patchCopy.start2 = aPatch.start2 + patchCopy.length1 = aPatch.length1 + patchCopy.length2 = aPatch.length2 + patchesCopy.add(patchCopy) + } + return patchesCopy + } + + /** + * Merge a set of patches onto the text. Return a patched text, as well + * as an array of true/false values indicating which patches were applied. + * @param patches Array of Patch objects + * @param text Old text. + * @return Two element Object array, containing the new text and an array of + * boolean values. + */ + fun patch_apply(patches: LinkedList, text: String): Array { + var patches = patches + var text = text + if (patches.isEmpty()) { + return arrayOf(text, BooleanArray(0)) + } - Operation.EQUAL -> { - if ((aDiff.text!!.length <= 2 * Patch_Margin - ) && !patch.diffs.isEmpty() && (aDiff !== diffs.last) - ) { - // Small equality inside a patch. - patch.diffs.add(aDiff) - patch.length1 += aDiff.text!!.length - patch.length2 += aDiff.text!!.length - } - - if (aDiff.text!!.length >= 2 * Patch_Margin && !patch.diffs.isEmpty()) { - // Time for a new patch. - if (!patch.diffs.isEmpty()) { - patch_addContext(patch, prepatch_text) - patches.add(patch) - patch = Patch() - // Unlike Unidiff, our patch lists have a rolling context. - // https://github.com/google/diff-match-patch/wiki/Unidiff - // Update prepatch text & pos to reflect the application of the - // just completed patch. - prepatch_text = postpatch_text - char_count1 = char_count2 - } - } + // Deep copy the patches so that no changes are made to originals. + patches = patch_deepCopy(patches) + + val nullPadding = patch_addPadding(patches) + text = nullPadding + text + nullPadding + patch_splitMax(patches) + + var x = 0 + // delta keeps track of the offset between the expected and actual location + // of the previous patch. If there are patches expected at positions 10 and + // 20, but the first patch was found at 12, delta is 2 and the second patch + // has an effective expected position of 22. + var delta = 0 + val results = BooleanArray(patches.size) + for (aPatch: Patch in patches) { + val expected_loc = aPatch.start2 + delta + val text1 = diff_text1(aPatch.diffs) + var start_loc: Int + var end_loc = -1 + if (text1.length > this.Match_MaxBits) { + // patch_splitMax will only provide an oversized pattern in the case of + // a monster delete. + start_loc = match_main( + text, + text1.substring(0, Match_MaxBits.toInt()), expected_loc + ) + if (start_loc != -1) { + end_loc = match_main( + text, + text1.substring(text1.length - this.Match_MaxBits), + expected_loc + text1.length - this.Match_MaxBits + ) + if (end_loc == -1 || start_loc >= end_loc) { + // Can't find valid trailing context. Drop this patch. + start_loc = -1 + } + } + } else { + start_loc = match_main(text, text1, expected_loc) + } + if (start_loc == -1) { + // No match found. :( + results[x] = false + // Subtract the delta for this failed patch from subsequent patches. + delta -= aPatch.length2 - aPatch.length1 + } else { + // Found a match. :) + results[x] = true + delta = start_loc - expected_loc + var text2: String + if (end_loc == -1) { + text2 = text.substring( + start_loc, + min((start_loc + text1.length).toDouble(), text.length.toDouble()).toInt() + ) + } else { + text2 = text.substring( + start_loc, + min((end_loc + this.Match_MaxBits).toDouble(), text.length.toDouble()).toInt() + ) + } + if ((text1 == text2)) { + // Perfect match, just shove the replacement text in. + text = (text.substring(0, start_loc) + diff_text2(aPatch.diffs) + + text.substring(start_loc + text1.length)) + } else { + // Imperfect match. Run a diff to get a framework of equivalent + // indices. + val diffs = diff_main(text1, text2, false) + if ((text1.length > this.Match_MaxBits + && diff_levenshtein(diffs) / text1.length.toFloat() + > this.Patch_DeleteThreshold) + ) { + // The end points match, but the content is unacceptably bad. + results[x] = false + } else { + diff_cleanupSemanticLossless(diffs) + var index1 = 0 + for (aDiff: Diff in aPatch.diffs) { + if (aDiff.operation != Operation.EQUAL) { + val index2 = diff_xIndex(diffs, index1) + if (aDiff.operation == Operation.INSERT) { + // Insertion + text = (text.substring(0, start_loc + index2) + aDiff.text + + text.substring(start_loc + index2)) + } else if (aDiff.operation == Operation.DELETE) { + // Deletion + text = (text.substring(0, start_loc + index2) + + text.substring( + start_loc + diff_xIndex( + diffs, + index1 + aDiff.text!!.length + ) + )) } - - null -> TODO() - } - // Update the current character count. - if (aDiff.operation != Operation.INSERT) { - char_count1 += aDiff.text!!.length - } - if (aDiff.operation != Operation.DELETE) { - char_count2 += aDiff.text!!.length + } + if (aDiff.operation != Operation.DELETE) { + index1 += aDiff.text!!.length + } } + } } - // Pick up the leftover patch if not empty. - if (!patch.diffs.isEmpty()) { - patch_addContext(patch, prepatch_text) - patches.add(patch) - } + } + x++ + } + // Strip the padding off. + text = text.substring( + nullPadding.length, (text.length + - nullPadding.length) + ) + return arrayOf(text, results) + } + + /** + * Add some padding on text start and end so that edges can match something. + * Intended to be called only from within patch_apply. + * @param patches Array of Patch objects. + * @return The padding string added to each side. + */ + private fun patch_addPadding(patches: LinkedList): String { + val paddingLength = this.Patch_Margin + var nullPadding = "" + for (x in 1..paddingLength) { + nullPadding += (Char(x.toUShort())).toString() + } - return patches + // Bump all the patches forward. + for (aPatch: Patch in patches) { + aPatch.start1 += paddingLength.toInt() + aPatch.start2 += paddingLength.toInt() } - /** - * Given an array of patches, return another array that is identical. - * @param patches Array of Patch objects. - * @return Array of Patch objects. - */ - private fun patch_deepCopy(patches: LinkedList): LinkedList { - val patchesCopy = LinkedList() - for (aPatch: Patch in patches) { - val patchCopy = Patch() - for (aDiff: Diff in aPatch.diffs) { - val diffCopy = Diff(aDiff.operation, aDiff.text) - patchCopy.diffs.add(diffCopy) - } - patchCopy.start1 = aPatch.start1 - patchCopy.start2 = aPatch.start2 - patchCopy.length1 = aPatch.length1 - patchCopy.length2 = aPatch.length2 - patchesCopy.add(patchCopy) - } - return patchesCopy + // Add some padding on start of first diff. + var patch = patches.first + var diffs = patch.diffs + if (diffs.isEmpty() || diffs.first.operation != Operation.EQUAL) { + // Add nullPadding equality. + diffs.addFirst(Diff(Operation.EQUAL, nullPadding)) + patch.start1 -= paddingLength.toInt() // Should be 0. + patch.start2 -= paddingLength.toInt() // Should be 0. + patch.length1 += paddingLength.toInt() + patch.length2 += paddingLength.toInt() + } else if (paddingLength > diffs.first.text!!.length) { + // Grow first equality. + val firstDiff = diffs.first + val extraLength = paddingLength - firstDiff.text!!.length + firstDiff.text = (nullPadding.substring(firstDiff.text!!.length) + + firstDiff.text) + patch.start1 -= extraLength + patch.start2 -= extraLength + patch.length1 += extraLength + patch.length2 += extraLength } - /** - * Merge a set of patches onto the text. Return a patched text, as well - * as an array of true/false values indicating which patches were applied. - * @param patches Array of Patch objects - * @param text Old text. - * @return Two element Object array, containing the new text and an array of - * boolean values. - */ - fun patch_apply(patches: LinkedList, text: String): Array { - var patches = patches - var text = text - if (patches.isEmpty()) { - return arrayOf(text, BooleanArray(0)) - } + // Add some padding on end of last diff. + patch = patches.last + diffs = patch.diffs + if (diffs.isEmpty() || diffs.last.operation != Operation.EQUAL) { + // Add nullPadding equality. + diffs.addLast(Diff(Operation.EQUAL, nullPadding)) + patch.length1 += paddingLength.toInt() + patch.length2 += paddingLength.toInt() + } else if (paddingLength > diffs.last.text!!.length) { + // Grow last equality. + val lastDiff = diffs.last + val extraLength = paddingLength - lastDiff.text!!.length + lastDiff.text += nullPadding.substring(0, extraLength) + patch.length1 += extraLength + patch.length2 += extraLength + } - // Deep copy the patches so that no changes are made to originals. - patches = patch_deepCopy(patches) - - val nullPadding = patch_addPadding(patches) - text = nullPadding + text + nullPadding - patch_splitMax(patches) - - var x = 0 - // delta keeps track of the offset between the expected and actual location - // of the previous patch. If there are patches expected at positions 10 and - // 20, but the first patch was found at 12, delta is 2 and the second patch - // has an effective expected position of 22. - var delta = 0 - val results = BooleanArray(patches.size) - for (aPatch: Patch in patches) { - val expected_loc = aPatch.start2 + delta - val text1 = diff_text1(aPatch.diffs) - var start_loc: Int - var end_loc = -1 - if (text1.length > this.Match_MaxBits) { - // patch_splitMax will only provide an oversized pattern in the case of - // a monster delete. - start_loc = match_main( - text, - text1.substring(0, Match_MaxBits.toInt()), expected_loc - ) - if (start_loc != -1) { - end_loc = match_main( - text, - text1.substring(text1.length - this.Match_MaxBits), - expected_loc + text1.length - this.Match_MaxBits - ) - if (end_loc == -1 || start_loc >= end_loc) { - // Can't find valid trailing context. Drop this patch. - start_loc = -1 - } - } + return nullPadding + } + + /** + * Look through the patches and break up any which are longer than the + * maximum limit of the match algorithm. + * Intended to be called only from within patch_apply. + * @param patches LinkedList of Patch objects. + */ + private fun patch_splitMax(patches: LinkedList) { + val patch_size = Match_MaxBits + var precontext: String + var postcontext: String + var patch: Patch + var start1: Int + var start2: Int + var empty: Boolean + var diff_type: Operation + var diff_text: String + val pointer = patches.listIterator() + var bigpatch = if (pointer.hasNext()) pointer.next() else null + while (bigpatch != null) { + if (bigpatch.length1 <= Match_MaxBits) { + bigpatch = if (pointer.hasNext()) pointer.next() else null + continue + } + // Remove the big old patch. + pointer.remove() + start1 = bigpatch.start1 + start2 = bigpatch.start2 + precontext = "" + while (!bigpatch.diffs.isEmpty()) { + // Create one of several smaller patches. + patch = Patch() + empty = true + patch.start1 = start1 - precontext.length + patch.start2 = start2 - precontext.length + if (precontext.length != 0) { + patch.length2 = precontext.length + patch.length1 = patch.length2 + patch.diffs.add(Diff(Operation.EQUAL, precontext)) + } + while ((!bigpatch.diffs.isEmpty() + && patch.length1 < patch_size - Patch_Margin) + ) { + diff_type = bigpatch.diffs.first.operation!! + diff_text = bigpatch.diffs.first.text!! + if (diff_type == Operation.INSERT) { + // Insertions are harmless. + patch.length2 += diff_text.length + start2 += diff_text.length + patch.diffs.addLast(bigpatch.diffs.removeFirst()) + empty = false + } else if ((diff_type == Operation.DELETE) && (patch.diffs.size == 1 + ) && (patch.diffs.first.operation == Operation.EQUAL + ) && (diff_text.length > 2 * patch_size) + ) { + // This is a large deletion. Let it pass in one chunk. + patch.length1 += diff_text.length + start1 += diff_text.length + empty = false + patch.diffs.add(Diff(diff_type, diff_text)) + bigpatch.diffs.removeFirst() + } else { + // Deletion or equality. Only take as much as we can stomach. + diff_text = diff_text.substring( + 0, min( + diff_text.length.toDouble(), + (patch_size - patch.length1 - Patch_Margin).toDouble() + ).toInt() + ) + patch.length1 += diff_text.length + start1 += diff_text.length + if (diff_type == Operation.EQUAL) { + patch.length2 += diff_text.length + start2 += diff_text.length } else { - start_loc = match_main(text, text1, expected_loc) + empty = false } - if (start_loc == -1) { - // No match found. :( - results[x] = false - // Subtract the delta for this failed patch from subsequent patches. - delta -= aPatch.length2 - aPatch.length1 + patch.diffs.add(Diff(diff_type, diff_text)) + if ((diff_text == bigpatch.diffs.first.text)) { + bigpatch.diffs.removeFirst() } else { - // Found a match. :) - results[x] = true - delta = start_loc - expected_loc - var text2: String - if (end_loc == -1) { - text2 = text.substring( - start_loc, - min((start_loc + text1.length).toDouble(), text.length.toDouble()).toInt() - ) - } else { - text2 = text.substring( - start_loc, - min((end_loc + this.Match_MaxBits).toDouble(), text.length.toDouble()).toInt() - ) - } - if ((text1 == text2)) { - // Perfect match, just shove the replacement text in. - text = (text.substring(0, start_loc) + diff_text2(aPatch.diffs) - + text.substring(start_loc + text1.length)) - } else { - // Imperfect match. Run a diff to get a framework of equivalent - // indices. - val diffs = diff_main(text1, text2, false) - if ((text1.length > this.Match_MaxBits - && diff_levenshtein(diffs) / text1.length.toFloat() - > this.Patch_DeleteThreshold) - ) { - // The end points match, but the content is unacceptably bad. - results[x] = false - } else { - diff_cleanupSemanticLossless(diffs) - var index1 = 0 - for (aDiff: Diff in aPatch.diffs) { - if (aDiff.operation != Operation.EQUAL) { - val index2 = diff_xIndex(diffs, index1) - if (aDiff.operation == Operation.INSERT) { - // Insertion - text = (text.substring(0, start_loc + index2) + aDiff.text - + text.substring(start_loc + index2)) - } else if (aDiff.operation == Operation.DELETE) { - // Deletion - text = (text.substring(0, start_loc + index2) - + text.substring( - start_loc + diff_xIndex( - diffs, - index1 + aDiff.text!!.length - ) - )) - } - } - if (aDiff.operation != Operation.DELETE) { - index1 += aDiff.text!!.length - } - } - } - } + bigpatch.diffs.first.text = bigpatch.diffs.first.text!! + .substring(diff_text.length) } - x++ - } - // Strip the padding off. - text = text.substring( - nullPadding.length, (text.length - - nullPadding.length) + } + } + // Compute the head context for the next patch. + precontext = diff_text2(patch.diffs) + precontext = precontext.substring( + max( + 0.0, (precontext.length + - Patch_Margin).toDouble() + ).toInt() ) - return arrayOf(text, results) + // Append the end context for this patch. + if (diff_text1(bigpatch.diffs).length > Patch_Margin) { + postcontext = diff_text1(bigpatch.diffs).substring(0, Patch_Margin.toInt()) + } else { + postcontext = diff_text1(bigpatch.diffs) + } + if (postcontext.length != 0) { + patch.length1 += postcontext.length + patch.length2 += postcontext.length + if ((!patch.diffs.isEmpty() + && patch.diffs.last.operation == Operation.EQUAL) + ) { + patch.diffs.last.text += postcontext + } else { + patch.diffs.add(Diff(Operation.EQUAL, postcontext)) + } + } + if (!empty) { + pointer.add(patch) + } + } + bigpatch = if (pointer.hasNext()) pointer.next() else null } - - /** - * Add some padding on text start and end so that edges can match something. - * Intended to be called only from within patch_apply. - * @param patches Array of Patch objects. - * @return The padding string added to each side. - */ - private fun patch_addPadding(patches: LinkedList): String { - val paddingLength = this.Patch_Margin - var nullPadding = "" - for (x in 1..paddingLength) { - nullPadding += (Char(x.toUShort())).toString() - } - - // Bump all the patches forward. - for (aPatch: Patch in patches) { - aPatch.start1 += paddingLength.toInt() - aPatch.start2 += paddingLength.toInt() - } - - // Add some padding on start of first diff. - var patch = patches.first - var diffs = patch.diffs - if (diffs.isEmpty() || diffs.first.operation != Operation.EQUAL) { - // Add nullPadding equality. - diffs.addFirst(Diff(Operation.EQUAL, nullPadding)) - patch.start1 -= paddingLength.toInt() // Should be 0. - patch.start2 -= paddingLength.toInt() // Should be 0. - patch.length1 += paddingLength.toInt() - patch.length2 += paddingLength.toInt() - } else if (paddingLength > diffs.first.text!!.length) { - // Grow first equality. - val firstDiff = diffs.first - val extraLength = paddingLength - firstDiff.text!!.length - firstDiff.text = (nullPadding.substring(firstDiff.text!!.length) - + firstDiff.text) - patch.start1 -= extraLength - patch.start2 -= extraLength - patch.length1 += extraLength - patch.length2 += extraLength - } - - // Add some padding on end of last diff. - patch = patches.last - diffs = patch.diffs - if (diffs.isEmpty() || diffs.last.operation != Operation.EQUAL) { - // Add nullPadding equality. - diffs.addLast(Diff(Operation.EQUAL, nullPadding)) - patch.length1 += paddingLength.toInt() - patch.length2 += paddingLength.toInt() - } else if (paddingLength > diffs.last.text!!.length) { - // Grow last equality. - val lastDiff = diffs.last - val extraLength = paddingLength - lastDiff.text!!.length - lastDiff.text += nullPadding.substring(0, extraLength) - patch.length1 += extraLength - patch.length2 += extraLength + } + + /** + * Take a list of patches and return a textual representation. + * @param patches List of Patch objects. + * @return Text representation of patches. + */ + fun patch_toText(patches: List): String { + val text = StringBuilder() + for (aPatch: Patch? in patches) { + text.append(aPatch) + } + return text.toString() + } + + /** + * Parse a textual representation of patches and return a List of Patch + * objects. + * @param textline Text representation of patches. + * @return List of Patch objects. + * @throws IllegalArgumentException If invalid input. + */ + @Throws(IllegalArgumentException::class) + fun patch_fromText(textline: String): List { + val patches: MutableList = LinkedList() + if (textline.length == 0) { + return patches + } + val textList = Arrays.asList(*textline.split("\n".toRegex()).dropLastWhile { it.isEmpty() } + .toTypedArray()) + val text = LinkedList(textList) + var patch: Patch + val patchHeader = Pattern.compile("^@@ -(\\d+),?(\\d*) \\+(\\d+),?(\\d*) @@$") + var m: Matcher + var sign: Char + var line: String + while (!text.isEmpty()) { + m = patchHeader.matcher(text.first) + if (!m.matches()) { + throw IllegalArgumentException( + "Invalid patch string: " + text.first + ) + } + patch = Patch() + patches.add(patch) + patch.start1 = m.group(1).toInt() + if (m.group(2).length == 0) { + patch.start1-- + patch.length1 = 1 + } else if ((m.group(2) == "0")) { + patch.length1 = 0 + } else { + patch.start1-- + patch.length1 = m.group(2).toInt() + } + + patch.start2 = m.group(3).toInt() + if (m.group(4).length == 0) { + patch.start2-- + patch.length2 = 1 + } else if ((m.group(4) == "0")) { + patch.length2 = 0 + } else { + patch.start2-- + patch.length2 = m.group(4).toInt() + } + text.removeFirst() + + while (!text.isEmpty()) { + try { + sign = text.first[0] + } catch (e: IndexOutOfBoundsException) { + // Blank line? Whatever. + text.removeFirst() + continue + } + line = text.first.substring(1) + line = line.replace("+", "%2B") // decode would change all "+" to " " + try { + line = URLDecoder.decode(line, "UTF-8") + } catch (e: UnsupportedEncodingException) { + // Not likely on modern system. + throw Error("This system does not support UTF-8.", e) + } catch (e: IllegalArgumentException) { + // Malformed URI sequence. + throw IllegalArgumentException( + "Illegal escape in patch_fromText: $line", e + ) + } + if (sign == '-') { + // Deletion. + patch.diffs.add(Diff(Operation.DELETE, line)) + } else if (sign == '+') { + // Insertion. + patch.diffs.add(Diff(Operation.INSERT, line)) + } else if (sign == ' ') { + // Minor equality. + patch.diffs.add(Diff(Operation.EQUAL, line)) + } else if (sign == '@') { + // Start of next patch. + break + } else { + // WTF? + throw IllegalArgumentException( + "Invalid patch mode '$sign' in: $line" + ) } - - return nullPadding + text.removeFirst() + } } - + return patches + } + + + /** + * Class representing one diff operation. + */ + class Diff// Construct a diff with the specified operation and text. + /** + * Constructor. Initializes the diff with the provided values. + * @param operation One of INSERT, DELETE or EQUAL. + * @param text The text being applied. + */( /** - * Look through the patches and break up any which are longer than the - * maximum limit of the match algorithm. - * Intended to be called only from within patch_apply. - * @param patches LinkedList of Patch objects. + * One of: INSERT, DELETE or EQUAL. */ - private fun patch_splitMax(patches: LinkedList) { - val patch_size = Match_MaxBits - var precontext: String - var postcontext: String - var patch: Patch - var start1: Int - var start2: Int - var empty: Boolean - var diff_type: Operation - var diff_text: String - val pointer = patches.listIterator() - var bigpatch = if (pointer.hasNext()) pointer.next() else null - while (bigpatch != null) { - if (bigpatch.length1 <= Match_MaxBits) { - bigpatch = if (pointer.hasNext()) pointer.next() else null - continue - } - // Remove the big old patch. - pointer.remove() - start1 = bigpatch.start1 - start2 = bigpatch.start2 - precontext = "" - while (!bigpatch.diffs.isEmpty()) { - // Create one of several smaller patches. - patch = Patch() - empty = true - patch.start1 = start1 - precontext.length - patch.start2 = start2 - precontext.length - if (precontext.length != 0) { - patch.length2 = precontext.length - patch.length1 = patch.length2 - patch.diffs.add(Diff(Operation.EQUAL, precontext)) - } - while ((!bigpatch.diffs.isEmpty() - && patch.length1 < patch_size - Patch_Margin) - ) { - diff_type = bigpatch.diffs.first.operation!! - diff_text = bigpatch.diffs.first.text!! - if (diff_type == Operation.INSERT) { - // Insertions are harmless. - patch.length2 += diff_text.length - start2 += diff_text.length - patch.diffs.addLast(bigpatch.diffs.removeFirst()) - empty = false - } else if ((diff_type == Operation.DELETE) && (patch.diffs.size == 1 - ) && (patch.diffs.first.operation == Operation.EQUAL - ) && (diff_text.length > 2 * patch_size) - ) { - // This is a large deletion. Let it pass in one chunk. - patch.length1 += diff_text.length - start1 += diff_text.length - empty = false - patch.diffs.add(Diff(diff_type, diff_text)) - bigpatch.diffs.removeFirst() - } else { - // Deletion or equality. Only take as much as we can stomach. - diff_text = diff_text.substring( - 0, min( - diff_text.length.toDouble(), - (patch_size - patch.length1 - Patch_Margin).toDouble() - ).toInt() - ) - patch.length1 += diff_text.length - start1 += diff_text.length - if (diff_type == Operation.EQUAL) { - patch.length2 += diff_text.length - start2 += diff_text.length - } else { - empty = false - } - patch.diffs.add(Diff(diff_type, diff_text)) - if ((diff_text == bigpatch.diffs.first.text)) { - bigpatch.diffs.removeFirst() - } else { - bigpatch.diffs.first.text = bigpatch.diffs.first.text!! - .substring(diff_text.length) - } - } - } - // Compute the head context for the next patch. - precontext = diff_text2(patch.diffs) - precontext = precontext.substring( - max( - 0.0, (precontext.length - - Patch_Margin).toDouble() - ).toInt() - ) - // Append the end context for this patch. - if (diff_text1(bigpatch.diffs).length > Patch_Margin) { - postcontext = diff_text1(bigpatch.diffs).substring(0, Patch_Margin.toInt()) - } else { - postcontext = diff_text1(bigpatch.diffs) - } - if (postcontext.length != 0) { - patch.length1 += postcontext.length - patch.length2 += postcontext.length - if ((!patch.diffs.isEmpty() - && patch.diffs.last.operation == Operation.EQUAL) - ) { - patch.diffs.last.text += postcontext - } else { - patch.diffs.add(Diff(Operation.EQUAL, postcontext)) - } - } - if (!empty) { - pointer.add(patch) - } - } - bigpatch = if (pointer.hasNext()) pointer.next() else null - } + var operation: Operation?, + /** + * The text associated with this diff operation. + */ + var text: String? + ) { + /** + * Display a human-readable version of this Diff. + * @return text version. + */ + override fun toString(): String { + val prettyText = text!!.replace('\n', '\u00b6') + return "Diff(" + this.operation + ",\"" + prettyText + "\")" } /** - * Take a list of patches and return a textual representation. - * @param patches List of Patch objects. - * @return Text representation of patches. + * Create a numeric hash value for a Diff. + * This function is not used by DMP. + * @return Hash value. */ - fun patch_toText(patches: List): String { - val text = StringBuilder() - for (aPatch: Patch? in patches) { - text.append(aPatch) - } - return text.toString() + override fun hashCode(): Int { + val prime = 31 + var result = if ((operation == null)) 0 else operation.hashCode() + result += prime * (if ((text == null)) 0 else text.hashCode()) + return result } /** - * Parse a textual representation of patches and return a List of Patch - * objects. - * @param textline Text representation of patches. - * @return List of Patch objects. - * @throws IllegalArgumentException If invalid input. + * Is this Diff equivalent to another Diff? + * @param obj Another Diff to compare against. + * @return true or false. */ - @Throws(IllegalArgumentException::class) - fun patch_fromText(textline: String): List { - val patches: MutableList = LinkedList() - if (textline.length == 0) { - return patches - } - val textList = Arrays.asList(*textline.split("\n".toRegex()).dropLastWhile { it.isEmpty() } - .toTypedArray()) - val text = LinkedList(textList) - var patch: Patch - val patchHeader = Pattern.compile("^@@ -(\\d+),?(\\d*) \\+(\\d+),?(\\d*) @@$") - var m: Matcher - var sign: Char - var line: String - while (!text.isEmpty()) { - m = patchHeader.matcher(text.first) - if (!m.matches()) { - throw IllegalArgumentException( - "Invalid patch string: " + text.first - ) - } - patch = Patch() - patches.add(patch) - patch.start1 = m.group(1).toInt() - if (m.group(2).length == 0) { - patch.start1-- - patch.length1 = 1 - } else if ((m.group(2) == "0")) { - patch.length1 = 0 - } else { - patch.start1-- - patch.length1 = m.group(2).toInt() - } - - patch.start2 = m.group(3).toInt() - if (m.group(4).length == 0) { - patch.start2-- - patch.length2 = 1 - } else if ((m.group(4) == "0")) { - patch.length2 = 0 - } else { - patch.start2-- - patch.length2 = m.group(4).toInt() - } - text.removeFirst() - - while (!text.isEmpty()) { - try { - sign = text.first[0] - } catch (e: IndexOutOfBoundsException) { - // Blank line? Whatever. - text.removeFirst() - continue - } - line = text.first.substring(1) - line = line.replace("+", "%2B") // decode would change all "+" to " " - try { - line = URLDecoder.decode(line, "UTF-8") - } catch (e: UnsupportedEncodingException) { - // Not likely on modern system. - throw Error("This system does not support UTF-8.", e) - } catch (e: IllegalArgumentException) { - // Malformed URI sequence. - throw IllegalArgumentException( - "Illegal escape in patch_fromText: $line", e - ) - } - if (sign == '-') { - // Deletion. - patch.diffs.add(Diff(Operation.DELETE, line)) - } else if (sign == '+') { - // Insertion. - patch.diffs.add(Diff(Operation.INSERT, line)) - } else if (sign == ' ') { - // Minor equality. - patch.diffs.add(Diff(Operation.EQUAL, line)) - } else if (sign == '@') { - // Start of next patch. - break - } else { - // WTF? - throw IllegalArgumentException( - "Invalid patch mode '$sign' in: $line" - ) - } - text.removeFirst() - } - } - return patches + override fun equals(obj: Any?): Boolean { + if (this === obj) { + return true + } + if (obj == null) { + return false + } + if (javaClass != obj.javaClass) { + return false + } + val other = obj as Diff + if (operation != other.operation) { + return false + } + if (text == null) { + if (other.text != null) { + return false + } + } else if (text != other.text) { + return false + } + return true } + } + /** + * Class representing one patch operation. + */ + class Patch { + var diffs: LinkedList + var start1: Int = 0 + var start2: Int = 0 + var length1: Int = 0 + var length2: Int = 0 + /** - * Class representing one diff operation. + * Constructor. Initializes with an empty list of diffs. */ - class Diff// Construct a diff with the specified operation and text. - /** - * Constructor. Initializes the diff with the provided values. - * @param operation One of INSERT, DELETE or EQUAL. - * @param text The text being applied. - */( - /** - * One of: INSERT, DELETE or EQUAL. - */ - var operation: Operation?, - /** - * The text associated with this diff operation. - */ - var text: String? - ) { - /** - * Display a human-readable version of this Diff. - * @return text version. - */ - override fun toString(): String { - val prettyText = text!!.replace('\n', '\u00b6') - return "Diff(" + this.operation + ",\"" + prettyText + "\")" - } - - /** - * Create a numeric hash value for a Diff. - * This function is not used by DMP. - * @return Hash value. - */ - override fun hashCode(): Int { - val prime = 31 - var result = if ((operation == null)) 0 else operation.hashCode() - result += prime * (if ((text == null)) 0 else text.hashCode()) - return result - } - - /** - * Is this Diff equivalent to another Diff? - * @param obj Another Diff to compare against. - * @return true or false. - */ - override fun equals(obj: Any?): Boolean { - if (this === obj) { - return true - } - if (obj == null) { - return false - } - if (javaClass != obj.javaClass) { - return false - } - val other = obj as Diff - if (operation != other.operation) { - return false - } - if (text == null) { - if (other.text != null) { - return false - } - } else if (text != other.text) { - return false - } - return true - } + init { + this.diffs = LinkedList() } - /** - * Class representing one patch operation. + * Emulate GNU diff's format. + * Header: @@ -382,8 +481,9 @@ + * Indices are printed as 1-based, not 0-based. + * @return The GNU diff string. */ - class Patch { - var diffs: LinkedList - var start1: Int = 0 - var start2: Int = 0 - var length1: Int = 0 - var length2: Int = 0 - - /** - * Constructor. Initializes with an empty list of diffs. - */ - init { - this.diffs = LinkedList() - } - - /** - * Emulate GNU diff's format. - * Header: @@ -382,8 +481,9 @@ - * Indices are printed as 1-based, not 0-based. - * @return The GNU diff string. - */ - override fun toString(): String { - val coords1: String - val coords2: String - if (this.length1 == 0) { - coords1 = start1.toString() + ",0" - } else if (this.length1 == 1) { - coords1 = (this.start1 + 1).toString() - } else { - coords1 = (this.start1 + 1).toString() + "," + this.length1 - } - if (this.length2 == 0) { - coords2 = start2.toString() + ",0" - } else if (this.length2 == 1) { - coords2 = (this.start2 + 1).toString() - } else { - coords2 = (this.start2 + 1).toString() + "," + this.length2 - } - val text = StringBuilder() - text.append("@@ -").append(coords1).append(" +").append(coords2) - .append(" @@\n") - // Escape the body of the patch with %xx notation. - for (aDiff: Diff in this.diffs) { - when (aDiff.operation) { - Operation.INSERT -> text.append('+') - Operation.DELETE -> text.append('-') - Operation.EQUAL -> text.append(' ') - null -> TODO() - } - try { - text.append(URLEncoder.encode(aDiff.text, "UTF-8").replace('+', ' ')) - .append("\n") - } catch (e: UnsupportedEncodingException) { - // Not likely on modern system. - throw Error("This system does not support UTF-8.", e) - } - } - return unescapeForEncodeUriCompatability(text.toString()) - } + override fun toString(): String { + val coords1: String + val coords2: String + if (this.length1 == 0) { + coords1 = start1.toString() + ",0" + } else if (this.length1 == 1) { + coords1 = (this.start1 + 1).toString() + } else { + coords1 = (this.start1 + 1).toString() + "," + this.length1 + } + if (this.length2 == 0) { + coords2 = start2.toString() + ",0" + } else if (this.length2 == 1) { + coords2 = (this.start2 + 1).toString() + } else { + coords2 = (this.start2 + 1).toString() + "," + this.length2 + } + val text = StringBuilder() + text.append("@@ -").append(coords1).append(" +").append(coords2) + .append(" @@\n") + // Escape the body of the patch with %xx notation. + for (aDiff: Diff in this.diffs) { + when (aDiff.operation) { + Operation.INSERT -> text.append('+') + Operation.DELETE -> text.append('-') + Operation.EQUAL -> text.append(' ') + null -> TODO() + } + try { + text.append(URLEncoder.encode(aDiff.text, "UTF-8").replace('+', ' ')) + .append("\n") + } catch (e: UnsupportedEncodingException) { + // Not likely on modern system. + throw Error("This system does not support UTF-8.", e) + } + } + return unescapeForEncodeUriCompatability(text.toString()) } + } - companion object : DiffMatchPatch() { - /** - * Unescape selected chars for compatability with JavaScript's encodeURI. - * In speed critical applications this could be dropped since the - * receiving application will certainly decode these fine. - * Note that this function is case-sensitive. Thus "%3f" would not be - * unescaped. But this is ok because it is only called with the output of - * URLEncoder.encode which returns uppercase hex. - * - * Example: "%3F" -> "?", "%24" -> "$", etc. - * - * @param str The string to escape. - * @return The escaped string. - */ - private fun unescapeForEncodeUriCompatability(str: String): String { - return str.replace("%21", "!").replace("%7E", "~") - .replace("%27", "'").replace("%28", "(").replace("%29", ")") - .replace("%3B", ";").replace("%2F", "/").replace("%3F", "?") - .replace("%3A", ":").replace("%40", "@").replace("%26", "&") - .replace("%3D", "=").replace("%2B", "+").replace("%24", "$") - .replace("%2C", ",").replace("%23", "#") - } + companion object : DiffMatchPatch() { + /** + * Unescape selected chars for compatability with JavaScript's encodeURI. + * In speed critical applications this could be dropped since the + * receiving application will certainly decode these fine. + * Note that this function is case-sensitive. Thus "%3f" would not be + * unescaped. But this is ok because it is only called with the output of + * URLEncoder.encode which returns uppercase hex. + * + * Example: "%3F" -> "?", "%24" -> "$", etc. + * + * @param str The string to escape. + * @return The escaped string. + */ + private fun unescapeForEncodeUriCompatability(str: String): String { + return str.replace("%21", "!").replace("%7E", "~") + .replace("%27", "'").replace("%28", "(").replace("%29", ")") + .replace("%3B", ";").replace("%2F", "/").replace("%3F", "?") + .replace("%3A", ":").replace("%40", "@").replace("%26", "&") + .replace("%3D", "=").replace("%2B", "+").replace("%24", "$") + .replace("%2C", ",").replace("%23", "#") } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/DiffUtil.kt b/webui/src/main/kotlin/com/simiacryptus/diff/DiffUtil.kt index bcd86141..3f4fb3ad 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/DiffUtil.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/DiffUtil.kt @@ -4,161 +4,161 @@ import com.simiacryptus.diff.PatchLineType.* import org.slf4j.LoggerFactory enum class PatchLineType { - Added, Deleted, Unchanged + Added, Deleted, Unchanged } data class PatchLine( - val type: PatchLineType, - val lineNumber: Int, - val line: String, - val compareText: String = line.trim(), + val type: PatchLineType, + val lineNumber: Int, + val line: String, + val compareText: String = line.trim(), ) { - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (javaClass != other?.javaClass) return false + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false - other as PatchLine + other as PatchLine - return compareText == other.compareText - } + return compareText == other.compareText + } - override fun hashCode(): Int { - return compareText.hashCode() - } + override fun hashCode(): Int { + return compareText.hashCode() + } } object DiffUtil { - // ... (previous code remains unchanged) - - private val log = LoggerFactory.getLogger(DiffUtil::class.java) - - /** - * Generates a list of DiffResult representing the differences between two lists of strings. - * This function compares the original and modified texts line by line and categorizes each line as added, deleted, or unchanged. - * - * @param original The original list of strings. - * @param modified The modified list of strings. - * @return A list of DiffResult indicating the differences. - */ - fun generateDiff(original: List, modified: List): List { - val originalLines = original.mapIndexed { i, v -> PatchLine(Unchanged, i, v.trim()) } - val modifiedLines = modified.mapIndexed { i, v -> PatchLine(Unchanged, i, v.trim()) } - val patchLines = mutableListOf() - var i = 0 - var j = 0 - - log.debug("Starting diff generation. Original size: ${original.size}, Modified size: ${modified.size}") - - while (i < originalLines.size && j < modifiedLines.size) { - val originalLine = originalLines[i] - val modifiedLine = modifiedLines[j] - - log.trace("Comparing lines - Original: $originalLine, Modified: $modifiedLine") - if (originalLine == modifiedLine) { - patchLines.add(PatchLine(Unchanged, originalLine.lineNumber, original[i])) - i++ - j++ - } else { - val originalIndex = originalLines.subList(i, originalLines.size).indexOf(modifiedLine).let { if (it == -1) null else it + i } - val modifiedIndex = modifiedLines.subList(j, modifiedLines.size).indexOf(originalLine).let { if (it == -1) null else it + j } - log.debug("Mismatch found. Original index: $originalIndex, Modified index: $modifiedIndex") - - if (originalIndex != null && modifiedIndex != null) { - log.debug("Both indices found. Choosing shorter path.") - if (originalIndex - i < modifiedIndex - j) { - while (i < originalIndex) { - patchLines.add(PatchLine(Deleted, originalLines[i].lineNumber, original[i])) - i++ - } - } else { - while (j < modifiedIndex) { - patchLines.add(PatchLine(Added, modifiedLines[j].lineNumber, modified[j])) - j++ - } - } - } else if (originalIndex != null) { - log.debug("Original index found. Deleting lines until match.") - while (i < originalIndex) { - patchLines.add(PatchLine(Deleted, originalLines[i].lineNumber, original[i])) - i++ - } - } else if (modifiedIndex != null) { - log.debug("Modified index found. Adding lines until match.") - while (j < modifiedIndex) { - patchLines.add(PatchLine(Added, modifiedLines[j].lineNumber, modified[j])) - j++ - } - } else { - patchLines.add(PatchLine(Deleted, originalLines[i].lineNumber, original[i])) - patchLines.add(PatchLine(Added, modifiedLines[j].lineNumber, modified[j])) - i++ - j++ - } + // ... (previous code remains unchanged) + + private val log = LoggerFactory.getLogger(DiffUtil::class.java) + + /** + * Generates a list of DiffResult representing the differences between two lists of strings. + * This function compares the original and modified texts line by line and categorizes each line as added, deleted, or unchanged. + * + * @param original The original list of strings. + * @param modified The modified list of strings. + * @return A list of DiffResult indicating the differences. + */ + fun generateDiff(original: List, modified: List): List { + val originalLines = original.mapIndexed { i, v -> PatchLine(Unchanged, i, v.trim()) } + val modifiedLines = modified.mapIndexed { i, v -> PatchLine(Unchanged, i, v.trim()) } + val patchLines = mutableListOf() + var i = 0 + var j = 0 + + log.debug("Starting diff generation. Original size: ${original.size}, Modified size: ${modified.size}") + + while (i < originalLines.size && j < modifiedLines.size) { + val originalLine = originalLines[i] + val modifiedLine = modifiedLines[j] + + log.trace("Comparing lines - Original: $originalLine, Modified: $modifiedLine") + if (originalLine == modifiedLine) { + patchLines.add(PatchLine(Unchanged, originalLine.lineNumber, original[i])) + i++ + j++ + } else { + val originalIndex = originalLines.subList(i, originalLines.size).indexOf(modifiedLine).let { if (it == -1) null else it + i } + val modifiedIndex = modifiedLines.subList(j, modifiedLines.size).indexOf(originalLine).let { if (it == -1) null else it + j } + log.debug("Mismatch found. Original index: $originalIndex, Modified index: $modifiedIndex") + + if (originalIndex != null && modifiedIndex != null) { + log.debug("Both indices found. Choosing shorter path.") + if (originalIndex - i < modifiedIndex - j) { + while (i < originalIndex) { + patchLines.add(PatchLine(Deleted, originalLines[i].lineNumber, original[i])) + i++ } - } - - log.debug("Processing remaining lines. Original: ${originalLines.size - i}, Modified: ${modifiedLines.size - j}") - while (i < originalLines.size) { + } else { + while (j < modifiedIndex) { + patchLines.add(PatchLine(Added, modifiedLines[j].lineNumber, modified[j])) + j++ + } + } + } else if (originalIndex != null) { + log.debug("Original index found. Deleting lines until match.") + while (i < originalIndex) { patchLines.add(PatchLine(Deleted, originalLines[i].lineNumber, original[i])) i++ - } - - while (j < modifiedLines.size) { + } + } else if (modifiedIndex != null) { + log.debug("Modified index found. Adding lines until match.") + while (j < modifiedIndex) { patchLines.add(PatchLine(Added, modifiedLines[j].lineNumber, modified[j])) j++ + } + } else { + patchLines.add(PatchLine(Deleted, originalLines[i].lineNumber, original[i])) + patchLines.add(PatchLine(Added, modifiedLines[j].lineNumber, modified[j])) + i++ + j++ } + } + } - log.info("Diff generation completed. Total patch lines: ${patchLines.size}") - - return patchLines + log.debug("Processing remaining lines. Original: ${originalLines.size - i}, Modified: ${modifiedLines.size - j}") + while (i < originalLines.size) { + patchLines.add(PatchLine(Deleted, originalLines[i].lineNumber, original[i])) + i++ } - /** - * Formats the list of DiffResult into a human-readable string representation. - * This function processes each diff result to format added, deleted, and unchanged lines appropriately, - * including context lines and markers for easier reading. - * - * @param patchLines The list of DiffResult to format. - * @param contextLines The number of context lines to include around changes. - * @return A formatted string representing the diff. - */ - fun formatDiff(patchLines: List, contextLines: Int = 3): String { - val formattedLines = mutableListOf() - var lastPrintedLine = -1 - - log.debug("Starting diff formatting. Total lines: ${patchLines.size}, Context lines: $contextLines") - - patchLines.forEachIndexed { index, lineDiff -> - if (lineDiff.type != Unchanged || - (index > 0 && patchLines[index - 1].type != Unchanged) || - (index < patchLines.size - 1 && patchLines[index + 1].type != Unchanged) - ) { - - // Print context lines before the change - val contextStart = maxOf(lastPrintedLine + 1, index - contextLines) - for (i in contextStart until index) { - if (i > lastPrintedLine) { - formattedLines.add(" ${patchLines[i].line}") - lastPrintedLine = i - } - } - - // Print the change - val prefix = when (lineDiff.type) { - Added -> "+ " - Deleted -> "- " - Unchanged -> " " - } - formattedLines.add("$prefix${lineDiff.line}") - lastPrintedLine = index + while (j < modifiedLines.size) { + patchLines.add(PatchLine(Added, modifiedLines[j].lineNumber, modified[j])) + j++ + } - } + log.info("Diff generation completed. Total patch lines: ${patchLines.size}") + + return patchLines + } + + /** + * Formats the list of DiffResult into a human-readable string representation. + * This function processes each diff result to format added, deleted, and unchanged lines appropriately, + * including context lines and markers for easier reading. + * + * @param patchLines The list of DiffResult to format. + * @param contextLines The number of context lines to include around changes. + * @return A formatted string representing the diff. + */ + fun formatDiff(patchLines: List, contextLines: Int = 3): String { + val formattedLines = mutableListOf() + var lastPrintedLine = -1 + + log.debug("Starting diff formatting. Total lines: ${patchLines.size}, Context lines: $contextLines") + + patchLines.forEachIndexed { index, lineDiff -> + if (lineDiff.type != Unchanged || + (index > 0 && patchLines[index - 1].type != Unchanged) || + (index < patchLines.size - 1 && patchLines[index + 1].type != Unchanged) + ) { + + // Print context lines before the change + val contextStart = maxOf(lastPrintedLine + 1, index - contextLines) + for (i in contextStart until index) { + if (i > lastPrintedLine) { + formattedLines.add(" ${patchLines[i].line}") + lastPrintedLine = i + } } - log.info("Diff formatting completed. Total formatted lines: ${formattedLines.size}") + // Print the change + val prefix = when (lineDiff.type) { + Added -> "+ " + Deleted -> "- " + Unchanged -> " " + } + formattedLines.add("$prefix${lineDiff.line}") + lastPrintedLine = index - val formattedDiff = formattedLines.joinToString("\n") - log.debug("Formatted diff:\n$formattedDiff") - return formattedDiff + } } + + log.info("Diff formatting completed. Total formatted lines: ${formattedLines.size}") + + val formattedDiff = formattedLines.joinToString("\n") + log.debug("Formatted diff:\n$formattedDiff") + return formattedDiff + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/FileValidationUtils.kt b/webui/src/main/kotlin/com/simiacryptus/diff/FileValidationUtils.kt index 580c6d99..c3ed3897 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/FileValidationUtils.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/FileValidationUtils.kt @@ -6,171 +6,172 @@ import java.util.* import kotlin.io.path.name class FileValidationUtils { - companion object { - fun isCurlyBalanced(code: String): Boolean { - var count = 0 - for (char in code) { - when (char) { - '{' -> count++ - '}' -> count-- - } - if (count < 0) return false - } - return count == 0 + companion object { + fun isCurlyBalanced(code: String): Boolean { + var count = 0 + for (char in code) { + when (char) { + '{' -> count++ + '}' -> count-- } + if (count < 0) return false + } + return count == 0 + } - fun isSingleQuoteBalanced(code: String): Boolean { - var count = 0 - var escaped = false - for (char in code) { - when { - char == '\\' -> escaped = !escaped - char == '\'' && !escaped -> count++ - else -> escaped = false - } - } - return count % 2 == 0 + fun isSingleQuoteBalanced(code: String): Boolean { + var count = 0 + var escaped = false + for (char in code) { + when { + char == '\\' -> escaped = !escaped + char == '\'' && !escaped -> count++ + else -> escaped = false } + } + return count % 2 == 0 + } - fun isSquareBalanced(code: String): Boolean { - var count = 0 - for (char in code) { - when (char) { - '[' -> count++ - ']' -> count-- - } - if (count < 0) return false - } - return count == 0 + fun isSquareBalanced(code: String): Boolean { + var count = 0 + for (char in code) { + when (char) { + '[' -> count++ + ']' -> count-- } + if (count < 0) return false + } + return count == 0 + } - fun isParenthesisBalanced(code: String): Boolean { - var count = 0 - for (char in code) { - when (char) { - '(' -> count++ - ')' -> count-- - } - if (count < 0) return false - } - return count == 0 + fun isParenthesisBalanced(code: String): Boolean { + var count = 0 + for (char in code) { + when (char) { + '(' -> count++ + ')' -> count-- } + if (count < 0) return false + } + return count == 0 + } - fun isQuoteBalanced(code: String): Boolean { - var count = 0 - var escaped = false - for (char in code) { - when { - char == '\\' -> escaped = !escaped - char == '"' && !escaped -> count++ - else -> escaped = false - } - } - return count % 2 == 0 + fun isQuoteBalanced(code: String): Boolean { + var count = 0 + var escaped = false + for (char in code) { + when { + char == '\\' -> escaped = !escaped + char == '"' && !escaped -> count++ + else -> escaped = false } + } + return count % 2 == 0 + } - fun filteredWalk( - file: File, - maxFilesPerDir : Int = 20, - fn: (File) -> Boolean - ): List { - val result = mutableListOf() - if (fn(file)) { - if (file.isDirectory) { - file.listFiles()?.take(maxFilesPerDir)?.forEach { child -> - result.addAll(filteredWalk(child, maxFilesPerDir, fn)) - } - } else { - result.add(file) - } - } - return result + fun filteredWalk( + file: File, + maxFilesPerDir: Int = 20, + fn: (File) -> Boolean + ): List { + val result = mutableListOf() + if (fn(file)) { + if (file.isDirectory) { + file.listFiles()?.take(maxFilesPerDir)?.forEach { child -> + result.addAll(filteredWalk(child, maxFilesPerDir, fn)) + } + } else { + result.add(file) } + } + return result + } - fun isLLMIncludableFile(file: File): Boolean { - return when { - !file.exists() -> false - file.isDirectory -> false - file.name.startsWith(".") -> false - file.name.endsWith(".data") -> true - file.length() > (256 * 1024) -> false - isGitignore(file.toPath()) -> false - file.extension.lowercase(Locale.getDefault()) in setOf( - "jar", - "zip", - "class", - "png", - "jpg", - "jpeg", - "gif", - "ico", - "stl" - ) -> false + fun isLLMIncludableFile(file: File): Boolean { + return when { + !file.exists() -> false + file.isDirectory -> false + file.name.startsWith(".") -> false + file.name.endsWith(".data") -> true + file.length() > (256 * 1024) -> false + isGitignore(file.toPath()) -> false + file.extension.lowercase(Locale.getDefault()) in setOf( + "jar", + "zip", + "class", + "png", + "jpg", + "jpeg", + "gif", + "ico", + "stl" + ) -> false - else -> true - } - } + else -> true + } + } - fun expandFileList(vararg data: File): Array { - return data.flatMap { - (when { - it.name.startsWith(".") -> arrayOf() - it.name.endsWith(".data") -> arrayOf(it) - isGitignore(it.toPath()) -> arrayOf() - it.length() > 1e6 -> arrayOf() - it.extension.lowercase(Locale.getDefault()) in - setOf("jar", "zip", "class", "png", "jpg", "jpeg", "gif", "ico") -> arrayOf() - it.isDirectory -> expandFileList(*it.listFiles() ?: arrayOf()) - else -> arrayOf(it) - }).toList() - }.toTypedArray() - } + fun expandFileList(vararg data: File): Array { + return data.flatMap { + (when { + it.name.startsWith(".") -> arrayOf() + it.name.endsWith(".data") -> arrayOf(it) + isGitignore(it.toPath()) -> arrayOf() + it.length() > 1e6 -> arrayOf() + it.extension.lowercase(Locale.getDefault()) in + setOf("jar", "zip", "class", "png", "jpg", "jpeg", "gif", "ico") -> arrayOf() - fun isGitignore(path: Path): Boolean { - when { - path.name == "node_modules" -> return true - path.name == "target" -> return true - path.name == "build" -> return true - path.name.startsWith(".") -> return true - } - var currentDir = path.toFile().parentFile - currentDir ?: return false - while (!currentDir.resolve(".git").exists()) { - currentDir.resolve(".gitignore").let { - if (it.exists()) { - val gitignore = it.readText() - if (gitignore.split("\n").any { line -> - try { - if (line.trim().isEmpty()) return@any false - if (line.startsWith("#")) return@any false - val pattern = - line.trim().trimStart('/').trimEnd('/') - .replace(".", "\\.").replace("*", ".*") - return@any path.fileName.toString().trimEnd('/').matches(Regex(pattern)) - } catch (e: Throwable) { - return@any false - } - }) return true - } - } - currentDir = currentDir.parentFile ?: return false - } - currentDir.resolve(".gitignore").let { - if (it.exists()) { - val gitignore = it.readText() - if (gitignore.split("\n").any { line -> - val pattern = line.trim().trimEnd('/').replace(".", "\\.").replace("*", ".*") - line.trim().isNotEmpty() - && !line.startsWith("#") - && path.fileName.toString().trimEnd('/').matches(Regex(pattern)) - }) { - return true - } + it.isDirectory -> expandFileList(*it.listFiles() ?: arrayOf()) + else -> arrayOf(it) + }).toList() + }.toTypedArray() + } + + fun isGitignore(path: Path): Boolean { + when { + path.name == "node_modules" -> return true + path.name == "target" -> return true + path.name == "build" -> return true + path.name.startsWith(".") -> return true + } + var currentDir = path.toFile().parentFile + currentDir ?: return false + while (!currentDir.resolve(".git").exists()) { + currentDir.resolve(".gitignore").let { + if (it.exists()) { + val gitignore = it.readText() + if (gitignore.split("\n").any { line -> + try { + if (line.trim().isEmpty()) return@any false + if (line.startsWith("#")) return@any false + val pattern = + line.trim().trimStart('/').trimEnd('/') + .replace(".", "\\.").replace("*", ".*") + return@any path.fileName.toString().trimEnd('/').matches(Regex(pattern)) + } catch (e: Throwable) { + return@any false } - } - return false + }) return true + } } - + currentDir = currentDir.parentFile ?: return false + } + currentDir.resolve(".gitignore").let { + if (it.exists()) { + val gitignore = it.readText() + if (gitignore.split("\n").any { line -> + val pattern = line.trim().trimEnd('/').replace(".", "\\.").replace("*", ".*") + line.trim().isNotEmpty() + && !line.startsWith("#") + && path.fileName.toString().trimEnd('/').matches(Regex(pattern)) + }) { + return true + } + } + } + return false } + } + } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/IterativePatchUtil.kt b/webui/src/main/kotlin/com/simiacryptus/diff/IterativePatchUtil.kt index 2cdcb4d2..224c4a43 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/IterativePatchUtil.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/IterativePatchUtil.kt @@ -10,749 +10,749 @@ import kotlin.math.max import kotlin.math.min object IterativePatchUtil { - enum class LineType { CONTEXT, ADD, DELETE } - - // Tracks the nesting depth of different bracket types - data class LineMetrics( - var parenthesesDepth: Int = 0, - var squareBracketsDepth: Int = 0, - var curlyBracesDepth: Int = 0 - ) - - // Represents a single line in the source or patch text - data class LineRecord( - val index: Int, - val line: String?, - var previousLine: LineRecord? = null, - var nextLine: LineRecord? = null, - var matchingLine: LineRecord? = null, - var type: LineType = CONTEXT, - var metrics: LineMetrics = LineMetrics() - ) { - override fun toString(): String { - val sb = StringBuilder() - sb.append("${index.toString().padStart(5, ' ')}: ") - when (type) { - CONTEXT -> sb.append(" ") - ADD -> sb.append("+") - DELETE -> sb.append("-") - } - sb.append(" ") - sb.append(line) - sb.append(" (${metrics.parenthesesDepth})[${metrics.squareBracketsDepth}]{${metrics.curlyBracesDepth}}") - return sb.toString() - } - - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (javaClass != other?.javaClass) return false - other as LineRecord - if (index != other.index) return false - if (line != other.line) return false - return true - } - - override fun hashCode(): Int { - var result = index - result = 31 * result + (line?.hashCode() ?: 0) - return result - } + enum class LineType { CONTEXT, ADD, DELETE } + + // Tracks the nesting depth of different bracket types + data class LineMetrics( + var parenthesesDepth: Int = 0, + var squareBracketsDepth: Int = 0, + var curlyBracesDepth: Int = 0 + ) + + // Represents a single line in the source or patch text + data class LineRecord( + val index: Int, + val line: String?, + var previousLine: LineRecord? = null, + var nextLine: LineRecord? = null, + var matchingLine: LineRecord? = null, + var type: LineType = CONTEXT, + var metrics: LineMetrics = LineMetrics() + ) { + override fun toString(): String { + val sb = StringBuilder() + sb.append("${index.toString().padStart(5, ' ')}: ") + when (type) { + CONTEXT -> sb.append(" ") + ADD -> sb.append("+") + DELETE -> sb.append("-") + } + sb.append(" ") + sb.append(line) + sb.append(" (${metrics.parenthesesDepth})[${metrics.squareBracketsDepth}]{${metrics.curlyBracesDepth}}") + return sb.toString() + } + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + other as LineRecord + if (index != other.index) return false + if (line != other.line) return false + return true } - fun generatePatch(oldCode: String, newCode: String): String { - log.info("Starting patch generation process") - val sourceLines = parseLines(oldCode) - val newLines = parseLines(newCode) - link(sourceLines, newLines, null) - log.debug("Parsed and linked source lines: ${sourceLines.size}, new lines: ${newLines.size}") - markMovedLines(newLines) - val longDiff = newToPatch(newLines) - val shortDiff = truncateContext(longDiff).toMutableList() - fixPatchLineOrder(shortDiff) - annihilateNoopLinePairs(shortDiff) - log.debug("Generated diff with ${shortDiff.size} lines after processing") - val patch = StringBuilder() - // Generate the patch text - shortDiff.forEach { line -> - when (line.type) { - CONTEXT -> patch.append(" ${line.line}\n") - ADD -> patch.append("+ ${line.line}\n") - DELETE -> patch.append("- ${line.line}\n") - } - } - log.info("Patch generation completed") - return patch.toString().trimEnd() + override fun hashCode(): Int { + var result = index + result = 31 * result + (line?.hashCode() ?: 0) + return result } - /** - * Applies a patch to the given source text. - * @param source The original text. - * @param patch The patch to apply. - * @return The text after the patch has been applied. - */ - fun applyPatch(source: String, patch: String): String { - log.info("Starting patch application process") - // Parse the source and patch texts into lists of line records - val sourceLines = parseLines(source) - var patchLines = parsePatchLines(patch, sourceLines) - log.debug("Parsed source lines: ${sourceLines.size}, initial patch lines: ${patchLines.size}") - link(sourceLines, patchLines, LevenshteinDistance()) - - // Filter out empty lines in the patch - patchLines = patchLines.filter { it.line?.let { normalizeLine(it).isEmpty() } == false } - log.debug("Filtered patch lines: ${patchLines.size}") - log.info("Generating patched text") - - val result = generatePatchedText(sourceLines, patchLines) - val generatePatchedTextUsingLinks = result.joinToString("\n").trim() - log.info("Patch application completed") - - return generatePatchedTextUsingLinks - } - private fun annihilateNoopLinePairs(diff: MutableList) { - log.debug("Starting annihilation of no-op line pairs") - val toRemove = mutableListOf>() - var i = 0 - while (i < diff.size - 1) { - if (diff[i].type == DELETE) { - var j = i + 1 - while (j < diff.size && diff[j].type != CONTEXT) { - if (diff[j].type == ADD && - normalizeLine(diff[i].line ?: "") == normalizeLine(diff[j].line ?: "") - ) { - toRemove.add(Pair(i, j)) - break - } - j++ - } - } - i++ + } + + fun generatePatch(oldCode: String, newCode: String): String { + log.info("Starting patch generation process") + val sourceLines = parseLines(oldCode) + val newLines = parseLines(newCode) + link(sourceLines, newLines, null) + log.debug("Parsed and linked source lines: ${sourceLines.size}, new lines: ${newLines.size}") + markMovedLines(newLines) + val longDiff = newToPatch(newLines) + val shortDiff = truncateContext(longDiff).toMutableList() + fixPatchLineOrder(shortDiff) + annihilateNoopLinePairs(shortDiff) + log.debug("Generated diff with ${shortDiff.size} lines after processing") + val patch = StringBuilder() + // Generate the patch text + shortDiff.forEach { line -> + when (line.type) { + CONTEXT -> patch.append(" ${line.line}\n") + ADD -> patch.append("+ ${line.line}\n") + DELETE -> patch.append("- ${line.line}\n") + } + } + log.info("Patch generation completed") + return patch.toString().trimEnd() + } + + /** + * Applies a patch to the given source text. + * @param source The original text. + * @param patch The patch to apply. + * @return The text after the patch has been applied. + */ + fun applyPatch(source: String, patch: String): String { + log.info("Starting patch application process") + // Parse the source and patch texts into lists of line records + val sourceLines = parseLines(source) + var patchLines = parsePatchLines(patch, sourceLines) + log.debug("Parsed source lines: ${sourceLines.size}, initial patch lines: ${patchLines.size}") + link(sourceLines, patchLines, LevenshteinDistance()) + + // Filter out empty lines in the patch + patchLines = patchLines.filter { it.line?.let { normalizeLine(it).isEmpty() } == false } + log.debug("Filtered patch lines: ${patchLines.size}") + log.info("Generating patched text") + + val result = generatePatchedText(sourceLines, patchLines) + val generatePatchedTextUsingLinks = result.joinToString("\n").trim() + log.info("Patch application completed") + + return generatePatchedTextUsingLinks + } + + private fun annihilateNoopLinePairs(diff: MutableList) { + log.debug("Starting annihilation of no-op line pairs") + val toRemove = mutableListOf>() + var i = 0 + while (i < diff.size - 1) { + if (diff[i].type == DELETE) { + var j = i + 1 + while (j < diff.size && diff[j].type != CONTEXT) { + if (diff[j].type == ADD && + normalizeLine(diff[i].line ?: "") == normalizeLine(diff[j].line ?: "") + ) { + toRemove.add(Pair(i, j)) + break + } + j++ } - // Remove the pairs in reverse order to maintain correct indices - toRemove.flatMap { listOf(it.first, it.second) }.distinct().sortedDescending().forEach { diff.removeAt(it) } - log.debug("Removed ${toRemove.size} no-op line pairs") + } + i++ } - - private fun markMovedLines(newLines: List) { - log.debug("Starting to mark moved lines") - // We start with the first line of the new (patched) code - var newLine = newLines.firstOrNull() - var iterationCount = 0 - val maxIterations = newLines.size * 2 // Arbitrary limit to prevent infinite loops - // We'll iterate through all lines of the new code - while (null != newLine) { - try { - // We only process lines that have a matching line in the source code - if (newLine.matchingLine != null) { - // Get the next line in the new code - var nextNewLine = newLine.nextLine ?: break - try { - // Skip any lines that don't have a match or are additions - // This helps us find the next "anchor" point in the new code - while (nextNewLine.matchingLine == null || nextNewLine.type == ADD) { - nextNewLine = nextNewLine.nextLine ?: break - } - if (nextNewLine.matchingLine == null || nextNewLine.type == ADD) break - // Get the corresponding line in the source code - val sourceLine = newLine.matchingLine!! - log.debug("Processing patch line ${newLine.index} with matching source line ${sourceLine.index}") - // Find the next line in the source code - var nextSourceLine = sourceLine.nextLine ?: continue - // Skip any lines in the source that don't have a match or are deletions - // This helps us find the next "anchor" point in the source code - while (nextSourceLine.matchingLine == null || nextSourceLine.type == DELETE) { - // Safeguard to prevent infinite loop - if (++iterationCount > maxIterations) { - log.error("Exceeded maximum iterations in markMovedLines") - break - } - nextSourceLine = nextSourceLine.nextLine ?: break - } - if (nextSourceLine.matchingLine == null || nextSourceLine.type == DELETE) break - // If the next matching lines in source and new don't correspond, - // it means there's a moved block of code - while (nextNewLine.matchingLine != nextSourceLine) { - if (nextSourceLine.matchingLine != null) { - // Mark the line in the new code as an addition - nextSourceLine.type = DELETE - // Mark the corresponding line in the source code as a deletion - nextSourceLine.matchingLine!!.type = ADD - log.debug("Marked moved line: Patch[${nextSourceLine.index}] as ADD, Source[${nextSourceLine.matchingLine!!.index}] as DELETE") - } - // Move to the next line in the new code - nextSourceLine = nextSourceLine.nextLine ?: break - // Skip any lines that don't have a match or are additions - while (nextSourceLine.matchingLine == null || nextSourceLine.type == DELETE) { - nextSourceLine = nextSourceLine.nextLine ?: continue - } - } - } finally { - // Safeguard to prevent infinite loop - if (++iterationCount > maxIterations) { - log.error("Exceeded maximum iterations in markMovedLines") - newLine = nextNewLine - // Move to the next line to process in the outer loop - // newLine = nextNewLine - } - } - } else { - // If the current line doesn't have a match, move to the next one - newLine = newLine.nextLine - } - } catch (e: Exception) { - log.error("Error marking moved lines", e) + // Remove the pairs in reverse order to maintain correct indices + toRemove.flatMap { listOf(it.first, it.second) }.distinct().sortedDescending().forEach { diff.removeAt(it) } + log.debug("Removed ${toRemove.size} no-op line pairs") + } + + private fun markMovedLines(newLines: List) { + log.debug("Starting to mark moved lines") + // We start with the first line of the new (patched) code + var newLine = newLines.firstOrNull() + var iterationCount = 0 + val maxIterations = newLines.size * 2 // Arbitrary limit to prevent infinite loops + // We'll iterate through all lines of the new code + while (null != newLine) { + try { + // We only process lines that have a matching line in the source code + if (newLine.matchingLine != null) { + // Get the next line in the new code + var nextNewLine = newLine.nextLine ?: break + try { + // Skip any lines that don't have a match or are additions + // This helps us find the next "anchor" point in the new code + while (nextNewLine.matchingLine == null || nextNewLine.type == ADD) { + nextNewLine = nextNewLine.nextLine ?: break } + if (nextNewLine.matchingLine == null || nextNewLine.type == ADD) break + // Get the corresponding line in the source code + val sourceLine = newLine.matchingLine!! + log.debug("Processing patch line ${newLine.index} with matching source line ${sourceLine.index}") + // Find the next line in the source code + var nextSourceLine = sourceLine.nextLine ?: continue + // Skip any lines in the source that don't have a match or are deletions + // This helps us find the next "anchor" point in the source code + while (nextSourceLine.matchingLine == null || nextSourceLine.type == DELETE) { + // Safeguard to prevent infinite loop + if (++iterationCount > maxIterations) { + log.error("Exceeded maximum iterations in markMovedLines") + break + } + nextSourceLine = nextSourceLine.nextLine ?: break + } + if (nextSourceLine.matchingLine == null || nextSourceLine.type == DELETE) break + // If the next matching lines in source and new don't correspond, + // it means there's a moved block of code + while (nextNewLine.matchingLine != nextSourceLine) { + if (nextSourceLine.matchingLine != null) { + // Mark the line in the new code as an addition + nextSourceLine.type = DELETE + // Mark the corresponding line in the source code as a deletion + nextSourceLine.matchingLine!!.type = ADD + log.debug("Marked moved line: Patch[${nextSourceLine.index}] as ADD, Source[${nextSourceLine.matchingLine!!.index}] as DELETE") + } + // Move to the next line in the new code + nextSourceLine = nextSourceLine.nextLine ?: break + // Skip any lines that don't have a match or are additions + while (nextSourceLine.matchingLine == null || nextSourceLine.type == DELETE) { + nextSourceLine = nextSourceLine.nextLine ?: continue + } + } + } finally { + // Safeguard to prevent infinite loop + if (++iterationCount > maxIterations) { + log.error("Exceeded maximum iterations in markMovedLines") + newLine = nextNewLine + // Move to the next line to process in the outer loop + // newLine = nextNewLine + } + } + } else { + // If the current line doesn't have a match, move to the next one + newLine = newLine.nextLine } - // At this point, we've marked all moved lines in both the source and new code - log.debug("Finished marking moved lines") + } catch (e: Exception) { + log.error("Error marking moved lines", e) + } } + // At this point, we've marked all moved lines in both the source and new code + log.debug("Finished marking moved lines") + } + + private fun newToPatch( + newLines: List + ): MutableList { + val diff = mutableListOf() + log.debug("Starting diff generation") + // Generate raw patch without limited context windows + var newLine = newLines.firstOrNull() + while (newLine != null) { + val sourceLine = newLine.matchingLine + when { + sourceLine == null || newLine.type == ADD -> { + diff.add(LineRecord(newLine.index, newLine.line, type = ADD)) + log.debug("Added ADD line: ${newLine.line}") + } - private fun newToPatch( - newLines: List - ): MutableList { - val diff = mutableListOf() - log.debug("Starting diff generation") - // Generate raw patch without limited context windows - var newLine = newLines.firstOrNull() - while (newLine != null) { - val sourceLine = newLine.matchingLine - when { - sourceLine == null || newLine.type == ADD -> { - diff.add(LineRecord(newLine.index, newLine.line, type = ADD)) - log.debug("Added ADD line: ${newLine.line}") - } - - else -> { - // search for prior, unlinked source lines - var priorSourceLine = sourceLine.previousLine - val lineBuffer = mutableListOf() - while (priorSourceLine != null && (priorSourceLine.matchingLine == null || priorSourceLine.type == DELETE)) { - // Note the deletion of the prior source line - lineBuffer.add(LineRecord(-1, priorSourceLine.line, type = DELETE)) - priorSourceLine = priorSourceLine.previousLine - } - diff.addAll(lineBuffer.reversed()) - diff.add(LineRecord(newLine.index, newLine.line, type = CONTEXT)) - log.debug("Added CONTEXT line: ${sourceLine.line}") - } - } - newLine = newLine.nextLine + else -> { + // search for prior, unlinked source lines + var priorSourceLine = sourceLine.previousLine + val lineBuffer = mutableListOf() + while (priorSourceLine != null && (priorSourceLine.matchingLine == null || priorSourceLine.type == DELETE)) { + // Note the deletion of the prior source line + lineBuffer.add(LineRecord(-1, priorSourceLine.line, type = DELETE)) + priorSourceLine = priorSourceLine.previousLine + } + diff.addAll(lineBuffer.reversed()) + diff.add(LineRecord(newLine.index, newLine.line, type = CONTEXT)) + log.debug("Added CONTEXT line: ${sourceLine.line}") } - log.debug("Generated diff with ${diff.size} lines") - return diff + } + newLine = newLine.nextLine } - - private fun truncateContext(diff: MutableList): MutableList { - val contextSize = 3 // Number of context lines before and after changes - log.debug("Truncating context with size $contextSize") - val truncatedDiff = mutableListOf() - val contextBuffer = mutableListOf() - for (i in diff.indices) { - val line = diff[i] - when { - line.type != CONTEXT -> { - // Start of a change, add buffered context - if (contextSize * 2 < contextBuffer.size) { - if (truncatedDiff.isNotEmpty()) { - truncatedDiff.addAll(contextBuffer.take(contextSize)) - truncatedDiff.add(LineRecord(-1, "...", type = CONTEXT)) - } - truncatedDiff.addAll(contextBuffer.takeLast(contextSize)) - } else { - truncatedDiff.addAll(contextBuffer) - } - contextBuffer.clear() - truncatedDiff.add(line) - } - - else -> { - contextBuffer.add(line) - } + log.debug("Generated diff with ${diff.size} lines") + return diff + } + + private fun truncateContext(diff: MutableList): MutableList { + val contextSize = 3 // Number of context lines before and after changes + log.debug("Truncating context with size $contextSize") + val truncatedDiff = mutableListOf() + val contextBuffer = mutableListOf() + for (i in diff.indices) { + val line = diff[i] + when { + line.type != CONTEXT -> { + // Start of a change, add buffered context + if (contextSize * 2 < contextBuffer.size) { + if (truncatedDiff.isNotEmpty()) { + truncatedDiff.addAll(contextBuffer.take(contextSize)) + truncatedDiff.add(LineRecord(-1, "...", type = CONTEXT)) } - } - if (truncatedDiff.isEmpty()) { - return truncatedDiff - } - if (contextSize < contextBuffer.size) { - truncatedDiff.addAll(contextBuffer.take(contextSize)) - } else { + truncatedDiff.addAll(contextBuffer.takeLast(contextSize)) + } else { truncatedDiff.addAll(contextBuffer) + } + contextBuffer.clear() + truncatedDiff.add(line) } - // Add trailing context after the last change - log.debug("Truncated diff size: ${truncatedDiff.size}") - return truncatedDiff - } - /** - * Normalizes a line by removing all whitespace. - * @param line The line to normalize. - * @return The normalized line. - */ - private fun normalizeLine(line: String): String { - return line.replace(whitespaceRegex, "") + else -> { + contextBuffer.add(line) + } + } } - - private val whitespaceRegex = "\\s".toRegex() - - private fun link( - sourceLines: List, - patchLines: List, - levenshteinDistance: LevenshteinDistance? - ) { - // Step 1: Link all unique lines in the source and patch that match exactly - log.info("Step 1: Linking unique matching lines") - linkUniqueMatchingLines(sourceLines, patchLines) - - // Step 2: Link all exact matches in the source and patch which are adjacent to established links - log.info("Step 2: Linking adjacent matching lines") - linkAdjacentMatchingLines(sourceLines, levenshteinDistance) - log.info("Step 3: Performing subsequence linking") - - subsequenceLinking(sourceLines, patchLines, levenshteinDistance = levenshteinDistance) + if (truncatedDiff.isEmpty()) { + return truncatedDiff } - - private fun subsequenceLinking( - sourceLines: List, - patchLines: List, - depth: Int = 0, - levenshteinDistance: LevenshteinDistance? - ) { - log.debug("Subsequence linking at depth $depth") - if (depth > 10 || sourceLines.isEmpty() || patchLines.isEmpty()) { - return // Base case: prevent excessive recursion - } - val sourceSegment = sourceLines.filter { it.matchingLine == null } - val patchSegment = patchLines.filter { it.matchingLine == null } - if (sourceSegment.isNotEmpty() && patchSegment.isNotEmpty()) { - var matchedLines = linkUniqueMatchingLines(sourceSegment, patchSegment) - matchedLines += linkAdjacentMatchingLines(sourceSegment, levenshteinDistance) - if (matchedLines == 0) { - matchedLines += matchFirstBrackets(sourceSegment, patchSegment) - } - if (matchedLines > 0) { - subsequenceLinking(sourceSegment, patchSegment, depth + 1, levenshteinDistance) - } - log.debug("Matched $matchedLines lines in subsequence linking at depth $depth") - } + if (contextSize < contextBuffer.size) { + truncatedDiff.addAll(contextBuffer.take(contextSize)) + } else { + truncatedDiff.addAll(contextBuffer) } - - private fun generatePatchedText( - sourceLines: List, - patchLines: List, - ): List { - log.debug("Starting to generate patched text") - val patchedText: MutableList = mutableListOf() - val usedPatchLines = mutableSetOf() - var sourceIndex = -1 - var lastMatchedPatchIndex = -1 - while (sourceIndex < sourceLines.size - 1) { - val codeLine = sourceLines[++sourceIndex] - when { - codeLine.matchingLine?.type == DELETE -> { - val patchLine = codeLine.matchingLine!! - log.debug("Deleting line: {}", codeLine) - // Delete the line -- do not add to patched text - usedPatchLines.add(patchLine) - checkAfterForInserts(patchLine, usedPatchLines, patchedText) - lastMatchedPatchIndex = patchLine.index - } - - codeLine.matchingLine != null -> { - val patchLine: LineRecord = codeLine.matchingLine!! - log.debug("Patching line: {} <-> {}", codeLine, patchLine) - checkBeforeForInserts(patchLine, usedPatchLines, patchedText) - usedPatchLines.add(patchLine) - // Use the source line if it matches the patch line (ignoring whitespace) - if (normalizeLine(codeLine.line ?: "") == normalizeLine(patchLine.line ?: "")) { - patchedText.add(codeLine.line ?: "") - } else { - patchedText.add(patchLine.line ?: "") - } - checkAfterForInserts(patchLine, usedPatchLines, patchedText) - lastMatchedPatchIndex = patchLine.index - } - - else -> { - log.debug("Added unmatched source line: {}", codeLine) - patchedText.add(codeLine.line ?: "") - } - - } - } - if (lastMatchedPatchIndex == -1) patchLines.filter { it.type == ADD && !usedPatchLines.contains(it) } - .forEach { line -> - log.debug("Added patch line: {}", line) - patchedText.add(line.line ?: "") - } - log.debug("Generated patched text with ${patchedText.size} lines") - return patchedText + // Add trailing context after the last change + log.debug("Truncated diff size: ${truncatedDiff.size}") + return truncatedDiff + } + + /** + * Normalizes a line by removing all whitespace. + * @param line The line to normalize. + * @return The normalized line. + */ + private fun normalizeLine(line: String): String { + return line.replace(whitespaceRegex, "") + } + + private val whitespaceRegex = "\\s".toRegex() + + private fun link( + sourceLines: List, + patchLines: List, + levenshteinDistance: LevenshteinDistance? + ) { + // Step 1: Link all unique lines in the source and patch that match exactly + log.info("Step 1: Linking unique matching lines") + linkUniqueMatchingLines(sourceLines, patchLines) + + // Step 2: Link all exact matches in the source and patch which are adjacent to established links + log.info("Step 2: Linking adjacent matching lines") + linkAdjacentMatchingLines(sourceLines, levenshteinDistance) + log.info("Step 3: Performing subsequence linking") + + subsequenceLinking(sourceLines, patchLines, levenshteinDistance = levenshteinDistance) + } + + private fun subsequenceLinking( + sourceLines: List, + patchLines: List, + depth: Int = 0, + levenshteinDistance: LevenshteinDistance? + ) { + log.debug("Subsequence linking at depth $depth") + if (depth > 10 || sourceLines.isEmpty() || patchLines.isEmpty()) { + return // Base case: prevent excessive recursion } - - private fun checkBeforeForInserts( - patchLine: LineRecord, - usedPatchLines: MutableSet, - patchedText: MutableList - ): LineRecord? { - val buffer = mutableListOf() - var prevPatchLine = patchLine.previousLine - while (null != prevPatchLine) { - if (prevPatchLine.type != ADD || usedPatchLines.contains(prevPatchLine)) { - break - } - - log.debug("Added unmatched patch line: {}", prevPatchLine) - buffer.add(prevPatchLine.line ?: "") - usedPatchLines.add(prevPatchLine) - prevPatchLine = prevPatchLine.previousLine - } - patchedText.addAll(buffer.reversed()) - return prevPatchLine + val sourceSegment = sourceLines.filter { it.matchingLine == null } + val patchSegment = patchLines.filter { it.matchingLine == null } + if (sourceSegment.isNotEmpty() && patchSegment.isNotEmpty()) { + var matchedLines = linkUniqueMatchingLines(sourceSegment, patchSegment) + matchedLines += linkAdjacentMatchingLines(sourceSegment, levenshteinDistance) + if (matchedLines == 0) { + matchedLines += matchFirstBrackets(sourceSegment, patchSegment) + } + if (matchedLines > 0) { + subsequenceLinking(sourceSegment, patchSegment, depth + 1, levenshteinDistance) + } + log.debug("Matched $matchedLines lines in subsequence linking at depth $depth") } - - private fun checkAfterForInserts( - patchLine: LineRecord, - usedPatchLines: MutableSet, - patchedText: MutableList - ): LineRecord { - var nextPatchLine = patchLine.nextLine - while (null != nextPatchLine) { - while (nextPatchLine != null && ( - normalizeLine(nextPatchLine.line ?: "").isEmpty() || - (nextPatchLine.matchingLine == null && nextPatchLine.type == CONTEXT) - ) - ) { - nextPatchLine = nextPatchLine.nextLine - } - if (nextPatchLine == null) break - if (nextPatchLine.type != ADD) break - if (usedPatchLines.contains(nextPatchLine)) break - log.debug("Added unmatched patch line: {}", nextPatchLine) - patchedText.add(nextPatchLine.line ?: "") - usedPatchLines.add(nextPatchLine) - nextPatchLine = nextPatchLine.nextLine + } + + private fun generatePatchedText( + sourceLines: List, + patchLines: List, + ): List { + log.debug("Starting to generate patched text") + val patchedText: MutableList = mutableListOf() + val usedPatchLines = mutableSetOf() + var sourceIndex = -1 + var lastMatchedPatchIndex = -1 + while (sourceIndex < sourceLines.size - 1) { + val codeLine = sourceLines[++sourceIndex] + when { + codeLine.matchingLine?.type == DELETE -> { + val patchLine = codeLine.matchingLine!! + log.debug("Deleting line: {}", codeLine) + // Delete the line -- do not add to patched text + usedPatchLines.add(patchLine) + checkAfterForInserts(patchLine, usedPatchLines, patchedText) + lastMatchedPatchIndex = patchLine.index } - return nextPatchLine ?: patchLine - } - private fun matchFirstBrackets(sourceLines: List, patchLines: List): Int { - log.debug("Starting to match first brackets") - log.debug("Starting to link unique matching lines") - // Group source lines by their normalized content - val sourceLineMap = sourceLines.filter { - it.line?.lineMetrics() != LineMetrics() - }.groupBy { normalizeLine(it.line!!) } - // Group patch lines by their normalized content, excluding ADD lines - val patchLineMap = patchLines.filter { - it.line?.lineMetrics() != LineMetrics() - }.filter { - when (it.type) { - ADD -> false // ADD lines are not matched to source lines - else -> true - } - }.groupBy { normalizeLine(it.line!!) } - log.debug("Created source and patch line maps") - - // Find intersecting keys (matching lines) and link them - val matched = sourceLineMap.keys.intersect(patchLineMap.keys) - matched.forEach { key -> - val sourceGroup = sourceLineMap[key]!! - val patchGroup = patchLineMap[key]!! - for (i in 0 until min(sourceGroup.size, patchGroup.size)) { - sourceGroup[i].matchingLine = patchGroup[i] - patchGroup[i].matchingLine = sourceGroup[i] - log.debug("Linked matching lines: Source[${sourceGroup[i].index}]: ${sourceGroup[i].line} <-> Patch[${patchGroup[i].index}]: ${patchGroup[i].line}") - } + codeLine.matchingLine != null -> { + val patchLine: LineRecord = codeLine.matchingLine!! + log.debug("Patching line: {} <-> {}", codeLine, patchLine) + checkBeforeForInserts(patchLine, usedPatchLines, patchedText) + usedPatchLines.add(patchLine) + // Use the source line if it matches the patch line (ignoring whitespace) + if (normalizeLine(codeLine.line ?: "") == normalizeLine(patchLine.line ?: "")) { + patchedText.add(codeLine.line ?: "") + } else { + patchedText.add(patchLine.line ?: "") + } + checkAfterForInserts(patchLine, usedPatchLines, patchedText) + lastMatchedPatchIndex = patchLine.index } - val matchedCount = matched.sumOf { sourceLineMap[it]!!.size } - log.debug("Finished matching first brackets. Matched $matchedCount lines") - return matched.sumOf { sourceLineMap[it]!!.size } - } - /** - * Links lines between the source and the patch that are unique and match exactly. - * @param sourceLines The source lines. - * @param patchLines The patch lines. - */ - private fun linkUniqueMatchingLines(sourceLines: List, patchLines: List): Int { - log.debug("Starting to link unique matching lines. Source lines: ${sourceLines.size}, Patch lines: ${patchLines.size}") - // Group source lines by their normalized content - val sourceLineMap = sourceLines.groupBy { normalizeLine(it.line!!) } - // Group patch lines by their normalized content, excluding ADD lines - val patchLineMap = patchLines.filter { - when (it.type) { - ADD -> false // ADD lines are not matched to source lines - else -> true - } - }.groupBy { normalizeLine(it.line!!) } - log.debug("Created source and patch line maps") + else -> { + log.debug("Added unmatched source line: {}", codeLine) + patchedText.add(codeLine.line ?: "") + } - // Find intersecting keys (matching lines) and link them - val matched = sourceLineMap.keys.intersect(patchLineMap.keys).filter { - sourceLineMap[it]?.size == patchLineMap[it]?.size + } + } + if (lastMatchedPatchIndex == -1) patchLines.filter { it.type == ADD && !usedPatchLines.contains(it) } + .forEach { line -> + log.debug("Added patch line: {}", line) + patchedText.add(line.line ?: "") + } + log.debug("Generated patched text with ${patchedText.size} lines") + return patchedText + } + + private fun checkBeforeForInserts( + patchLine: LineRecord, + usedPatchLines: MutableSet, + patchedText: MutableList + ): LineRecord? { + val buffer = mutableListOf() + var prevPatchLine = patchLine.previousLine + while (null != prevPatchLine) { + if (prevPatchLine.type != ADD || usedPatchLines.contains(prevPatchLine)) { + break + } + + log.debug("Added unmatched patch line: {}", prevPatchLine) + buffer.add(prevPatchLine.line ?: "") + usedPatchLines.add(prevPatchLine) + prevPatchLine = prevPatchLine.previousLine + } + patchedText.addAll(buffer.reversed()) + return prevPatchLine + } + + private fun checkAfterForInserts( + patchLine: LineRecord, + usedPatchLines: MutableSet, + patchedText: MutableList + ): LineRecord { + var nextPatchLine = patchLine.nextLine + while (null != nextPatchLine) { + while (nextPatchLine != null && ( + normalizeLine(nextPatchLine.line ?: "").isEmpty() || + (nextPatchLine.matchingLine == null && nextPatchLine.type == CONTEXT) + ) + ) { + nextPatchLine = nextPatchLine.nextLine + } + if (nextPatchLine == null) break + if (nextPatchLine.type != ADD) break + if (usedPatchLines.contains(nextPatchLine)) break + log.debug("Added unmatched patch line: {}", nextPatchLine) + patchedText.add(nextPatchLine.line ?: "") + usedPatchLines.add(nextPatchLine) + nextPatchLine = nextPatchLine.nextLine + } + return nextPatchLine ?: patchLine + } + + private fun matchFirstBrackets(sourceLines: List, patchLines: List): Int { + log.debug("Starting to match first brackets") + log.debug("Starting to link unique matching lines") + // Group source lines by their normalized content + val sourceLineMap = sourceLines.filter { + it.line?.lineMetrics() != LineMetrics() + }.groupBy { normalizeLine(it.line!!) } + // Group patch lines by their normalized content, excluding ADD lines + val patchLineMap = patchLines.filter { + it.line?.lineMetrics() != LineMetrics() + }.filter { + when (it.type) { + ADD -> false // ADD lines are not matched to source lines + else -> true + } + }.groupBy { normalizeLine(it.line!!) } + log.debug("Created source and patch line maps") + + // Find intersecting keys (matching lines) and link them + val matched = sourceLineMap.keys.intersect(patchLineMap.keys) + matched.forEach { key -> + val sourceGroup = sourceLineMap[key]!! + val patchGroup = patchLineMap[key]!! + for (i in 0 until min(sourceGroup.size, patchGroup.size)) { + sourceGroup[i].matchingLine = patchGroup[i] + patchGroup[i].matchingLine = sourceGroup[i] + log.debug("Linked matching lines: Source[${sourceGroup[i].index}]: ${sourceGroup[i].line} <-> Patch[${patchGroup[i].index}]: ${patchGroup[i].line}") + } + } + val matchedCount = matched.sumOf { sourceLineMap[it]!!.size } + log.debug("Finished matching first brackets. Matched $matchedCount lines") + return matched.sumOf { sourceLineMap[it]!!.size } + } + + /** + * Links lines between the source and the patch that are unique and match exactly. + * @param sourceLines The source lines. + * @param patchLines The patch lines. + */ + private fun linkUniqueMatchingLines(sourceLines: List, patchLines: List): Int { + log.debug("Starting to link unique matching lines. Source lines: ${sourceLines.size}, Patch lines: ${patchLines.size}") + // Group source lines by their normalized content + val sourceLineMap = sourceLines.groupBy { normalizeLine(it.line!!) } + // Group patch lines by their normalized content, excluding ADD lines + val patchLineMap = patchLines.filter { + when (it.type) { + ADD -> false // ADD lines are not matched to source lines + else -> true + } + }.groupBy { normalizeLine(it.line!!) } + log.debug("Created source and patch line maps") + + // Find intersecting keys (matching lines) and link them + val matched = sourceLineMap.keys.intersect(patchLineMap.keys).filter { + sourceLineMap[it]?.size == patchLineMap[it]?.size + } + matched.forEach { key -> + val sourceGroup = sourceLineMap[key]!! + val patchGroup = patchLineMap[key]!! + for (i in sourceGroup.indices) { + sourceGroup[i].matchingLine = patchGroup[i] + patchGroup[i].matchingLine = sourceGroup[i] + log.debug("Linked unique matching lines: Source[${sourceGroup[i].index}]: ${sourceGroup[i].line} <-> Patch[${patchGroup[i].index}]: ${patchGroup[i].line}") + } + } + val matchedCount = matched.sumOf { sourceLineMap[it]!!.size } + log.debug("Finished linking unique matching lines. Matched $matchedCount lines") + return matched.sumOf { sourceLineMap[it]!!.size } + } + + /** + * Links lines that are adjacent to already linked lines and match exactly. + * @param sourceLines The source lines with some established links. + */ + private fun linkAdjacentMatchingLines(sourceLines: List, levenshtein: LevenshteinDistance?): Int { + log.debug("Starting to link adjacent matching lines. Source lines: ${sourceLines.size}") + var foundMatch = true + var matchedLines = 0 + // Continue linking until no more matches are found + while (foundMatch) { + log.debug("Starting new iteration to find adjacent matches") + foundMatch = false + for (sourceLine in sourceLines) { + val patchLine = sourceLine.matchingLine ?: continue // Skip if there's no matching line + + var patchPrev = patchLine.previousLine + while (patchPrev?.previousLine != null && + (patchPrev.type == ADD || normalizeLine(patchPrev.line ?: "").isEmpty()) + ) { + require(patchPrev !== patchPrev.previousLine) + patchPrev = patchPrev.previousLine!! } - matched.forEach { key -> - val sourceGroup = sourceLineMap[key]!! - val patchGroup = patchLineMap[key]!! - for (i in sourceGroup.indices) { - sourceGroup[i].matchingLine = patchGroup[i] - patchGroup[i].matchingLine = sourceGroup[i] - log.debug("Linked unique matching lines: Source[${sourceGroup[i].index}]: ${sourceGroup[i].line} <-> Patch[${patchGroup[i].index}]: ${patchGroup[i].line}") - } + + var sourcePrev = sourceLine.previousLine + while (sourcePrev?.previousLine != null && (normalizeLine(sourcePrev.line ?: "").isEmpty())) { + require(sourcePrev !== sourcePrev.previousLine) + sourcePrev = sourcePrev.previousLine!! } - val matchedCount = matched.sumOf { sourceLineMap[it]!!.size } - log.debug("Finished linking unique matching lines. Matched $matchedCount lines") - return matched.sumOf { sourceLineMap[it]!!.size } - } - /** - * Links lines that are adjacent to already linked lines and match exactly. - * @param sourceLines The source lines with some established links. - */ - private fun linkAdjacentMatchingLines(sourceLines: List, levenshtein: LevenshteinDistance?): Int { - log.debug("Starting to link adjacent matching lines. Source lines: ${sourceLines.size}") - var foundMatch = true - var matchedLines = 0 - // Continue linking until no more matches are found - while (foundMatch) { - log.debug("Starting new iteration to find adjacent matches") - foundMatch = false - for (sourceLine in sourceLines) { - val patchLine = sourceLine.matchingLine ?: continue // Skip if there's no matching line - - var patchPrev = patchLine.previousLine - while (patchPrev?.previousLine != null && - (patchPrev.type == ADD || normalizeLine(patchPrev.line ?: "").isEmpty()) - ) { - require(patchPrev !== patchPrev.previousLine) - patchPrev = patchPrev.previousLine!! - } - - var sourcePrev = sourceLine.previousLine - while (sourcePrev?.previousLine != null && (normalizeLine(sourcePrev.line ?: "").isEmpty())) { - require(sourcePrev !== sourcePrev.previousLine) - sourcePrev = sourcePrev.previousLine!! - } - - if (sourcePrev != null && sourcePrev.matchingLine == null && - patchPrev != null && patchPrev.matchingLine == null - ) { // Skip if there's already a match - if (isMatch(sourcePrev, patchPrev, levenshtein)) { // Check if the lines match exactly - sourcePrev.matchingLine = patchPrev - patchPrev.matchingLine = sourcePrev - foundMatch = true - matchedLines++ - log.debug("Linked adjacent previous lines: Source[${sourcePrev.index}]: ${sourcePrev.line} <-> Patch[${patchPrev.index}]: ${patchPrev.line}") - } - } - - var patchNext = patchLine.nextLine - while (patchNext?.nextLine != null && - (patchNext.type == ADD || normalizeLine(patchNext.line ?: "").isEmpty()) - ) { - require(patchNext !== patchNext.nextLine) - patchNext = patchNext.nextLine!! - } - - var sourceNext = sourceLine.nextLine - while (sourceNext?.nextLine != null && (normalizeLine(sourceNext.line ?: "").isEmpty())) { - require(sourceNext !== sourceNext.nextLine) - sourceNext = sourceNext.nextLine!! - } - - if (sourceNext != null && sourceNext.matchingLine == null && - patchNext != null && patchNext.matchingLine == null - ) { - if (isMatch(sourceNext, patchNext, levenshtein)) { - sourceNext.matchingLine = patchNext - patchNext.matchingLine = sourceNext - foundMatch = true - matchedLines++ - log.debug("Linked adjacent next lines: Source[${sourceNext.index}]: ${sourceNext.line} <-> Patch[${patchNext.index}]: ${patchNext.line}") - } - } - } + if (sourcePrev != null && sourcePrev.matchingLine == null && + patchPrev != null && patchPrev.matchingLine == null + ) { // Skip if there's already a match + if (isMatch(sourcePrev, patchPrev, levenshtein)) { // Check if the lines match exactly + sourcePrev.matchingLine = patchPrev + patchPrev.matchingLine = sourcePrev + foundMatch = true + matchedLines++ + log.debug("Linked adjacent previous lines: Source[${sourcePrev.index}]: ${sourcePrev.line} <-> Patch[${patchPrev.index}]: ${patchPrev.line}") + } } - log.debug("Finished linking adjacent matching lines. Matched $matchedLines lines") - return matchedLines - } - private fun isMatch( - sourcePrev: LineRecord, - patchPrev: LineRecord, - levenshteinDistance: LevenshteinDistance? - ): Boolean { - val normalizeLineSource = normalizeLine(sourcePrev.line!!) - val normalizeLinePatch = normalizeLine(patchPrev.line!!) - var isMatch = normalizeLineSource == normalizeLinePatch - val length = max(normalizeLineSource.length, normalizeLinePatch.length) - if (!isMatch && length > 5 && null != levenshteinDistance) { // Check if the lines are similar using Levenshtein distance - val distance = levenshteinDistance.apply(normalizeLineSource, normalizeLinePatch) - log.debug("Levenshtein distance: $distance") - isMatch = distance <= floor(length / 4.0).toInt() + var patchNext = patchLine.nextLine + while (patchNext?.nextLine != null && + (patchNext.type == ADD || normalizeLine(patchNext.line ?: "").isEmpty()) + ) { + require(patchNext !== patchNext.nextLine) + patchNext = patchNext.nextLine!! } - return isMatch - } - /** - * @param text The text to parse. - * @return The list of line records. - */ - private fun parseLines(text: String): List { - log.debug("Starting to parse lines") - // Create LineRecords for each line and set links between them - val lines = setLinks(text.lines().mapIndexed { index, line -> LineRecord(index, line) }) - // Calculate bracket metrics for each line - calculateLineMetrics(lines) - log.debug("Finished parsing ${lines.size} lines") - return lines - } + var sourceNext = sourceLine.nextLine + while (sourceNext?.nextLine != null && (normalizeLine(sourceNext.line ?: "").isEmpty())) { + require(sourceNext !== sourceNext.nextLine) + sourceNext = sourceNext.nextLine!! + } - /** - * Sets the previous and next line links for a list of line records. - * @return The list with links set. - */ - private fun setLinks(list: List): List { - log.debug("Starting to set links for ${list.size} lines") - for (i in list.indices) { - list[i].previousLine = if (i <= 0) null else { - require(list[i - 1] !== list[i]) - list[i - 1] - } - list[i].nextLine = if (i >= list.size - 1) null else { - require(list[i + 1] !== list[i]) - list[i + 1] - } + if (sourceNext != null && sourceNext.matchingLine == null && + patchNext != null && patchNext.matchingLine == null + ) { + if (isMatch(sourceNext, patchNext, levenshtein)) { + sourceNext.matchingLine = patchNext + patchNext.matchingLine = sourceNext + foundMatch = true + matchedLines++ + log.debug("Linked adjacent next lines: Source[${sourceNext.index}]: ${sourceNext.line} <-> Patch[${patchNext.index}]: ${patchNext.line}") + } } - log.debug("Finished setting links for ${list.size} lines") - return list + } } - - /** - * Parses the patch text into a list of line records, identifying the type of each line (ADD, DELETE, CONTEXT). - * @param text The patch text to parse. - * @return The list of line records with types set. - */ - private fun parsePatchLines(text: String, sourceLines: List): List { - log.debug("Starting to parse patch lines") - val patchLines = setLinks(text.lines().mapIndexed { index, line -> - LineRecord( - index = index, - line = line.let { - when { - it.trimStart().startsWith("+++") -> null - it.trimStart().startsWith("---") -> null - it.trimStart().startsWith("@@") -> null - sourceLines.find { patchLine -> normalizeLine(patchLine.line ?: "") == normalizeLine(it) } != null -> it - it.trimStart().startsWith("+") -> it.trimStart().substring(1) - it.trimStart().startsWith("-") -> it.trimStart().substring(1) - else -> it - } - }, - type = when { - line.trimStart().startsWith("+") -> ADD - line.trimStart().startsWith("-") -> DELETE - else -> CONTEXT - } - ) - }.filter { it.line != null }).toMutableList() - - fixPatchLineOrder(patchLines) - - calculateLineMetrics(patchLines) - log.debug("Finished parsing ${patchLines.size} patch lines") - return patchLines + log.debug("Finished linking adjacent matching lines. Matched $matchedLines lines") + return matchedLines + } + + private fun isMatch( + sourcePrev: LineRecord, + patchPrev: LineRecord, + levenshteinDistance: LevenshteinDistance? + ): Boolean { + val normalizeLineSource = normalizeLine(sourcePrev.line!!) + val normalizeLinePatch = normalizeLine(patchPrev.line!!) + var isMatch = normalizeLineSource == normalizeLinePatch + val length = max(normalizeLineSource.length, normalizeLinePatch.length) + if (!isMatch && length > 5 && null != levenshteinDistance) { // Check if the lines are similar using Levenshtein distance + val distance = levenshteinDistance.apply(normalizeLineSource, normalizeLinePatch) + log.debug("Levenshtein distance: $distance") + isMatch = distance <= floor(length / 4.0).toInt() } - - private fun fixPatchLineOrder(patchLines: MutableList) { - log.debug("Starting to fix patch line order") - // Fixup: Iterate over the patch lines and look for adjacent ADD and DELETE lines; the DELETE should come first... if needed, swap them - var swapped: Boolean - do { - swapped = false - for (i in 0 until patchLines.size - 1) { - if (patchLines[i].type == ADD && patchLines[i + 1].type == DELETE) { - swapped = true - val addLine = patchLines[i] - val deleteLine = patchLines[i + 1] - // Swap records and update pointers - val nextLine = deleteLine.nextLine - val previousLine = addLine.previousLine - - require(addLine !== deleteLine) - if (previousLine === deleteLine) { - throw RuntimeException("previousLine === deleteLine") - } - require(previousLine !== deleteLine) - require(nextLine !== addLine) - require(nextLine !== deleteLine) - deleteLine.nextLine = addLine - addLine.previousLine = deleteLine - deleteLine.previousLine = previousLine - addLine.nextLine = nextLine - patchLines[i] = deleteLine - patchLines[i + 1] = addLine - } - } - } while (swapped) - log.debug("Finished fixing patch line order") + return isMatch + } + + /** + * @param text The text to parse. + * @return The list of line records. + */ + private fun parseLines(text: String): List { + log.debug("Starting to parse lines") + // Create LineRecords for each line and set links between them + val lines = setLinks(text.lines().mapIndexed { index, line -> LineRecord(index, line) }) + // Calculate bracket metrics for each line + calculateLineMetrics(lines) + log.debug("Finished parsing ${lines.size} lines") + return lines + } + + /** + * Sets the previous and next line links for a list of line records. + * @return The list with links set. + */ + private fun setLinks(list: List): List { + log.debug("Starting to set links for ${list.size} lines") + for (i in list.indices) { + list[i].previousLine = if (i <= 0) null else { + require(list[i - 1] !== list[i]) + list[i - 1] + } + list[i].nextLine = if (i >= list.size - 1) null else { + require(list[i + 1] !== list[i]) + list[i + 1] + } } - - /** - * Calculates the metrics for each line, including bracket nesting depth. - * @param lines The list of line records to process. - */ - private fun calculateLineMetrics(lines: List) { - log.debug("Starting to calculate line metrics for ${lines.size} lines") - lines.fold( - Triple(0, 0, 0) - ) { (parenDepth, squareDepth, curlyDepth), lineRecord -> - val updatedDepth = lineRecord.line?.fold(Triple(parenDepth, squareDepth, curlyDepth)) { acc, char -> - when (char) { - '(' -> Triple(acc.first + 1, acc.second, acc.third) - ')' -> Triple(max(0, acc.first - 1), acc.second, acc.third) - '[' -> Triple(acc.first, acc.second + 1, acc.third) - ']' -> Triple(acc.first, max(0, acc.second - 1), acc.third) - '{' -> Triple(acc.first, acc.second, acc.third + 1) - '}' -> Triple(acc.first, acc.second, max(0, acc.third - 1)) - else -> acc - } - } ?: Triple(parenDepth, squareDepth, curlyDepth) - lineRecord.metrics = LineMetrics( - parenthesesDepth = updatedDepth.first, - squareBracketsDepth = updatedDepth.second, - curlyBracesDepth = updatedDepth.third - ) - updatedDepth + log.debug("Finished setting links for ${list.size} lines") + return list + } + + /** + * Parses the patch text into a list of line records, identifying the type of each line (ADD, DELETE, CONTEXT). + * @param text The patch text to parse. + * @return The list of line records with types set. + */ + private fun parsePatchLines(text: String, sourceLines: List): List { + log.debug("Starting to parse patch lines") + val patchLines = setLinks(text.lines().mapIndexed { index, line -> + LineRecord( + index = index, + line = line.let { + when { + it.trimStart().startsWith("+++") -> null + it.trimStart().startsWith("---") -> null + it.trimStart().startsWith("@@") -> null + sourceLines.find { patchLine -> normalizeLine(patchLine.line ?: "") == normalizeLine(it) } != null -> it + it.trimStart().startsWith("+") -> it.trimStart().substring(1) + it.trimStart().startsWith("-") -> it.trimStart().substring(1) + else -> it + } + }, + type = when { + line.trimStart().startsWith("+") -> ADD + line.trimStart().startsWith("-") -> DELETE + else -> CONTEXT } - log.debug("Finished calculating line metrics") - } - - private fun String.lineMetrics(): LineMetrics { - var parenthesesDepth = 0 - var squareBracketsDepth = 0 - var curlyBracesDepth = 0 - - this.forEach { char -> - when (char) { - '(' -> parenthesesDepth++ - ')' -> parenthesesDepth = maxOf(0, parenthesesDepth - 1) - '[' -> squareBracketsDepth++ - ']' -> squareBracketsDepth = maxOf(0, squareBracketsDepth - 1) - '{' -> curlyBracesDepth++ - '}' -> curlyBracesDepth = maxOf(0, curlyBracesDepth - 1) - } + ) + }.filter { it.line != null }).toMutableList() + + fixPatchLineOrder(patchLines) + + calculateLineMetrics(patchLines) + log.debug("Finished parsing ${patchLines.size} patch lines") + return patchLines + } + + private fun fixPatchLineOrder(patchLines: MutableList) { + log.debug("Starting to fix patch line order") + // Fixup: Iterate over the patch lines and look for adjacent ADD and DELETE lines; the DELETE should come first... if needed, swap them + var swapped: Boolean + do { + swapped = false + for (i in 0 until patchLines.size - 1) { + if (patchLines[i].type == ADD && patchLines[i + 1].type == DELETE) { + swapped = true + val addLine = patchLines[i] + val deleteLine = patchLines[i + 1] + // Swap records and update pointers + val nextLine = deleteLine.nextLine + val previousLine = addLine.previousLine + + require(addLine !== deleteLine) + if (previousLine === deleteLine) { + throw RuntimeException("previousLine === deleteLine") + } + require(previousLine !== deleteLine) + require(nextLine !== addLine) + require(nextLine !== deleteLine) + deleteLine.nextLine = addLine + addLine.previousLine = deleteLine + deleteLine.previousLine = previousLine + addLine.nextLine = nextLine + patchLines[i] = deleteLine + patchLines[i + 1] = addLine + } + } + } while (swapped) + log.debug("Finished fixing patch line order") + } + + /** + * Calculates the metrics for each line, including bracket nesting depth. + * @param lines The list of line records to process. + */ + private fun calculateLineMetrics(lines: List) { + log.debug("Starting to calculate line metrics for ${lines.size} lines") + lines.fold( + Triple(0, 0, 0) + ) { (parenDepth, squareDepth, curlyDepth), lineRecord -> + val updatedDepth = lineRecord.line?.fold(Triple(parenDepth, squareDepth, curlyDepth)) { acc, char -> + when (char) { + '(' -> Triple(acc.first + 1, acc.second, acc.third) + ')' -> Triple(max(0, acc.first - 1), acc.second, acc.third) + '[' -> Triple(acc.first, acc.second + 1, acc.third) + ']' -> Triple(acc.first, max(0, acc.second - 1), acc.third) + '{' -> Triple(acc.first, acc.second, acc.third + 1) + '}' -> Triple(acc.first, acc.second, max(0, acc.third - 1)) + else -> acc } - return LineMetrics( - parenthesesDepth = parenthesesDepth, - squareBracketsDepth = squareBracketsDepth, - curlyBracesDepth = curlyBracesDepth - ) + } ?: Triple(parenDepth, squareDepth, curlyDepth) + lineRecord.metrics = LineMetrics( + parenthesesDepth = updatedDepth.first, + squareBracketsDepth = updatedDepth.second, + curlyBracesDepth = updatedDepth.third + ) + updatedDepth } + log.debug("Finished calculating line metrics") + } + + private fun String.lineMetrics(): LineMetrics { + var parenthesesDepth = 0 + var squareBracketsDepth = 0 + var curlyBracesDepth = 0 + + this.forEach { char -> + when (char) { + '(' -> parenthesesDepth++ + ')' -> parenthesesDepth = maxOf(0, parenthesesDepth - 1) + '[' -> squareBracketsDepth++ + ']' -> squareBracketsDepth = maxOf(0, squareBracketsDepth - 1) + '{' -> curlyBracesDepth++ + '}' -> curlyBracesDepth = maxOf(0, curlyBracesDepth - 1) + } + } + return LineMetrics( + parenthesesDepth = parenthesesDepth, + squareBracketsDepth = squareBracketsDepth, + curlyBracesDepth = curlyBracesDepth + ) + } - private val log = LoggerFactory.getLogger(IterativePatchUtil::class.java) + private val log = LoggerFactory.getLogger(IterativePatchUtil::class.java) } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/diff/PatchResult.kt b/webui/src/main/kotlin/com/simiacryptus/diff/PatchResult.kt index 37502e2d..ebb1ddb3 100644 --- a/webui/src/main/kotlin/com/simiacryptus/diff/PatchResult.kt +++ b/webui/src/main/kotlin/com/simiacryptus/diff/PatchResult.kt @@ -1,6 +1,6 @@ package com.simiacryptus.diff data class PatchResult( - val newCode: String, - val isValid: Boolean, + val newCode: String, + val isValid: Boolean, ) \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/AgentPatterns.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/AgentPatterns.kt index c3605cf6..f2e90fc8 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/AgentPatterns.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/AgentPatterns.kt @@ -5,51 +5,51 @@ import java.util.* object AgentPatterns { - fun displayMapInTabs( - map: Map, - ui: ApplicationInterface? = null, - split: Boolean = map.entries.map { it.value.length + it.key.length }.sum() > 10000 - ): String = if (split && ui != null) { - val tasks = map.entries.map { (key, value) -> - key to ui.newTask(root = false) - }.toMap() - ui.socketManager?.scheduledThreadPoolExecutor?.schedule({ - tasks.forEach { (key, task) -> - task.complete(map[key]!!) - } - }, 200, java.util.concurrent.TimeUnit.MILLISECONDS) - displayMapInTabs(tasks.mapValues { it.value.placeholder }, ui = ui, split = false) - } else { - """ + fun displayMapInTabs( + map: Map, + ui: ApplicationInterface? = null, + split: Boolean = map.entries.map { it.value.length + it.key.length }.sum() > 10000 + ): String = if (split && ui != null) { + val tasks = map.entries.map { (key, value) -> + key to ui.newTask(root = false) + }.toMap() + ui.socketManager?.scheduledThreadPoolExecutor?.schedule({ + tasks.forEach { (key, task) -> + task.complete(map[key]!!) + } + }, 200, java.util.concurrent.TimeUnit.MILLISECONDS) + displayMapInTabs(tasks.mapValues { it.value.placeholder }, ui = ui, split = false) + } else { + """
${ - map.keys.joinToString("\n") { key -> - """""" - } - } + map.keys.joinToString("\n") { key -> + """""" + } + }
${ - map.entries.withIndex().joinToString("\n") { (idx, t) -> - val (key, value) = t - """ + map.entries.withIndex().joinToString("\n") { (idx, t) -> + val (key, value) = t + """
"" - } - }" data-tab="$key"> + when { + idx == 0 -> " active" + else -> "" + } + }" data-tab="$key"> ${value/*.indent(" ")*/}
""" - } - } + } + }
""" - } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/Discussable.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/Discussable.kt index d935531a..494109bd 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/Discussable.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/Discussable.kt @@ -10,227 +10,227 @@ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicReference class Discussable( - private val task: SessionTask, - private val userMessage: () -> String, - private val initialResponse: (String) -> T, - private val outputFn: (T) -> String, - private val ui: ApplicationInterface, - private val reviseResponse: (List>) -> T, - private val atomicRef: AtomicReference = AtomicReference(), - private val semaphore: Semaphore = Semaphore(0), - private val heading: String + private val task: SessionTask, + private val userMessage: () -> String, + private val initialResponse: (String) -> T, + private val outputFn: (T) -> String, + private val ui: ApplicationInterface, + private val reviseResponse: (List>) -> T, + private val atomicRef: AtomicReference = AtomicReference(), + private val semaphore: Semaphore = Semaphore(0), + private val heading: String ) : Callable { - val tabs = object : TabbedDisplay(task) { - override fun renderTabButtons() = """ + val tabs = object : TabbedDisplay(task) { + override fun renderTabButtons() = """
${ - tabs.withIndex().joinToString("\n") - { (index: Int, t: Pair) -> - """""" - } - } + tabs.withIndex().joinToString("\n") + { (index: Int, t: Pair) -> + """""" + } + } ${ - ui.hrefLink("♻") { - val newTask = ui.newTask(false) - val header = newTask.header("Retrying...") - val idx: Int = size - this.set(label(idx), newTask.placeholder) - main(idx, newTask) - //this.selectedTab = idx - header?.clear() - newTask.complete() - } - } + ui.hrefLink("♻") { + val newTask = ui.newTask(false) + val header = newTask.header("Retrying...") + val idx: Int = size + this.set(label(idx), newTask.placeholder) + main(idx, newTask) + //this.selectedTab = idx + header?.clear() + newTask.complete() + } + }
""" - } - private val acceptGuard = AtomicBoolean(false) + } + private val acceptGuard = AtomicBoolean(false) - private fun main(tabIndex: Int, task: SessionTask) { - log.info("Starting main function for tabIndex: $tabIndex") - try { - val history = mutableListOf>() - val userMessage = userMessage() - log.info("User message: $userMessage") - history.add(userMessage to Role.user) - val design = initialResponse(userMessage) - log.info("Initial response generated: $design") - val rendered = outputFn(design) - log.info("Rendered output: $rendered") - history.add(rendered to Role.assistant) - val tabContent = task.add(rendered)!! - val feedbackForm = feedbackForm(tabIndex, tabContent, design, history, task) - tabContent.append("\n" + feedbackForm.placeholder) - task.complete() - } catch (e: Throwable) { - log.error("Error in discussable", e) - task.error(ui, e) - task.complete(ui.hrefLink("🔄 Retry") { - main(tabIndex = tabIndex, task = task) - }) - } + private fun main(tabIndex: Int, task: SessionTask) { + log.info("Starting main function for tabIndex: $tabIndex") + try { + val history = mutableListOf>() + val userMessage = userMessage() + log.info("User message: $userMessage") + history.add(userMessage to Role.user) + val design = initialResponse(userMessage) + log.info("Initial response generated: $design") + val rendered = outputFn(design) + log.info("Rendered output: $rendered") + history.add(rendered to Role.assistant) + val tabContent = task.add(rendered)!! + val feedbackForm = feedbackForm(tabIndex, tabContent, design, history, task) + tabContent.append("\n" + feedbackForm.placeholder) + task.complete() + } catch (e: Throwable) { + log.error("Error in discussable", e) + task.error(ui, e) + task.complete(ui.hrefLink("🔄 Retry") { + main(tabIndex = tabIndex, task = task) + }) } + } - private fun feedbackForm( - tabIndex: Int?, - tabContent: StringBuilder, - design: T, - history: List>, - task: SessionTask, - ) = ui.newTask(false).apply { - log.info("Creating feedback form for tabIndex: $tabIndex") - val feedbackSB = add("
")!! - feedbackSB.clear() - feedbackSB.append( - """ + private fun feedbackForm( + tabIndex: Int?, + tabContent: StringBuilder, + design: T, + history: List>, + task: SessionTask, + ) = ui.newTask(false).apply { + log.info("Creating feedback form for tabIndex: $tabIndex") + val feedbackSB = add("
")!! + feedbackSB.clear() + feedbackSB.append( + """
${acceptLink(tabIndex, tabContent, design, feedbackSB, feedbackTask = this)}
${textInput(design, tabContent, history, task, feedbackSB, feedbackTask = this)} """ - ) - complete() - } + ) + complete() + } + + private fun acceptLink( + tabIndex: Int?, + tabContent: StringBuilder, + design: T, + feedbackSB: StringBuilder, + feedbackTask: SessionTask, + ) = ui.hrefLink("Accept", classname = "href-link cmd-button") { + log.info("Accept link clicked for tabIndex: $tabIndex") + feedbackSB.clear() + feedbackTask.complete() + accept(tabIndex, tabContent, design) + } - private fun acceptLink( - tabIndex: Int?, - tabContent: StringBuilder, - design: T, - feedbackSB: StringBuilder, - feedbackTask: SessionTask, - ) = ui.hrefLink("Accept", classname = "href-link cmd-button") { - log.info("Accept link clicked for tabIndex: $tabIndex") + private fun textInput( + design: T, + tabContent: StringBuilder, + history: List>, + task: SessionTask, + feedbackSB: StringBuilder, + feedbackTask: SessionTask, + ): String { + val feedbackGuard = AtomicBoolean(false) + return ui.textInput { userResponse -> + log.info("User response received: $userResponse") + if (feedbackGuard.getAndSet(true)) return@textInput + val prev = feedbackSB.toString() + try { feedbackSB.clear() feedbackTask.complete() - accept(tabIndex, tabContent, design) + feedback(tabContent, userResponse, history, design, task) + } catch (e: Exception) { + log.error("Error processing user feedback", e) + task.error(ui, e) + feedbackSB.set(prev) + feedbackTask.complete() + throw e + } finally { + feedbackGuard.set(false) + } } + } - private fun textInput( - design: T, - tabContent: StringBuilder, - history: List>, - task: SessionTask, - feedbackSB: StringBuilder, - feedbackTask: SessionTask, - ): String { - val feedbackGuard = AtomicBoolean(false) - return ui.textInput { userResponse -> - log.info("User response received: $userResponse") - if (feedbackGuard.getAndSet(true)) return@textInput - val prev = feedbackSB.toString() - try { - feedbackSB.clear() - feedbackTask.complete() - feedback(tabContent, userResponse, history, design, task) - } catch (e: Exception) { - log.error("Error processing user feedback", e) - task.error(ui, e) - feedbackSB.set(prev) - feedbackTask.complete() - throw e - } finally { - feedbackGuard.set(false) - } - } + private fun feedback( + tabContent: StringBuilder, + userResponse: String, + history: List>, + design: T, + task: SessionTask, + ) { + log.info("Processing feedback for user response: $userResponse") + var history = history + history = history + (userResponse to Role.user) + val newValue = (tabContent.toString() + + "
" + + renderMarkdown(userResponse, ui = ui) + + "
") + tabContent.set(newValue) + val stringBuilder = task.add("Processing...") + tabs.update() + val newDesign = reviseResponse(history) + log.info("Revised design: $newDesign") + val newTask = ui.newTask(root = false) + tabContent.set(newValue + "\n" + newTask.placeholder) + tabs.update() + stringBuilder?.clear() + task.complete() + Retryable(ui, newTask) { + outputFn(newDesign) + "\n" + feedbackForm( + tabIndex = null, + tabContent = it, + design = newDesign, + history = history, + task = newTask + ).placeholder } + } - private fun feedback( - tabContent: StringBuilder, - userResponse: String, - history: List>, - design: T, - task: SessionTask, - ) { - log.info("Processing feedback for user response: $userResponse") - var history = history - history = history + (userResponse to Role.user) - val newValue = (tabContent.toString() - + "
" - + renderMarkdown(userResponse, ui = ui) - + "
") - tabContent.set(newValue) - val stringBuilder = task.add("Processing...") - tabs.update() - val newDesign = reviseResponse(history) - log.info("Revised design: $newDesign") - val newTask = ui.newTask(root = false) - tabContent.set(newValue + "\n" + newTask.placeholder) - tabs.update() - stringBuilder?.clear() - task.complete() - Retryable(ui, newTask) { - outputFn(newDesign) + "\n" + feedbackForm( - tabIndex = null, - tabContent = it, - design = newDesign, - history = history, - task = newTask - ).placeholder - } + private fun accept(tabIndex: Int?, tabContent: StringBuilder, design: T) { + log.info("Accepting design for tabIndex: $tabIndex") + if (acceptGuard.getAndSet(true)) { + return } - - private fun accept(tabIndex: Int?, tabContent: StringBuilder, design: T) { - log.info("Accepting design for tabIndex: $tabIndex") - if (acceptGuard.getAndSet(true)) { - return - } - try { - //if (null != tabIndex) tabs.selectedTab = tabIndex - tabContent.apply { - val prevTab = toString() - set(prevTab) - tabs.update() - } - } catch (e: Exception) { - log.error("Error accepting design", e) - task.error(ui, e) - acceptGuard.set(false) - throw e - } - atomicRef.set(design) - semaphore.release() + try { + //if (null != tabIndex) tabs.selectedTab = tabIndex + tabContent.apply { + val prevTab = toString() + set(prevTab) + tabs.update() + } + } catch (e: Exception) { + log.error("Error accepting design", e) + task.error(ui, e) + acceptGuard.set(false) + throw e } + atomicRef.set(design) + semaphore.release() + } - override fun call(): T { - try { - //log.info("Calling Discussable with heading: $heading") - task.echo(heading) - val idx = tabs.size - val newTask = ui.newTask(false) - val header = newTask.header("Processing...") - tabs[tabs.label(idx)] = newTask.placeholder - try { - main(idx, newTask) - //tabs.selectedTab = idx - semaphore.acquire() - } catch (e: Throwable) { - log.error("Error in main function", e) - task.error(ui, e) - } finally { - header?.clear() - newTask.complete() - } - log.info("Returning result from Discussable") - return atomicRef.get() - } catch (e: Exception) { - log.warn( - """ + override fun call(): T { + try { + //log.info("Calling Discussable with heading: $heading") + task.echo(heading) + val idx = tabs.size + val newTask = ui.newTask(false) + val header = newTask.header("Processing...") + tabs[tabs.label(idx)] = newTask.placeholder + try { + main(idx, newTask) + //tabs.selectedTab = idx + semaphore.acquire() + } catch (e: Throwable) { + log.error("Error in main function", e) + task.error(ui, e) + } finally { + header?.clear() + newTask.complete() + } + log.info("Returning result from Discussable") + return atomicRef.get() + } catch (e: Exception) { + log.warn( + """ Error in Discussable ${e.message} """, e - ) - task.error(ui, e) - return null as T - } + ) + task.error(ui, e) + return null as T } + } - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(Discussable::class.java) - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(Discussable::class.java) + } } fun java.lang.StringBuilder.set(newValue: String) { - clear() - append(newValue) + clear() + append(newValue) } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/Retryable.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/Retryable.kt index 0985d65b..77376467 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/Retryable.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/Retryable.kt @@ -4,40 +4,40 @@ import com.simiacryptus.skyenet.webui.application.ApplicationInterface import com.simiacryptus.skyenet.webui.session.SessionTask open class Retryable( - val ui: ApplicationInterface, - task: SessionTask, - val process: (StringBuilder) -> String + val ui: ApplicationInterface, + task: SessionTask, + val process: (StringBuilder) -> String ) : TabbedDisplay(task) { - init { - init() - } + init { + init() + } - open fun init() { - val tabLabel = label(size) - set(tabLabel, SessionTask.spinner) - set(tabLabel, process(container)) - } + open fun init() { + val tabLabel = label(size) + set(tabLabel, SessionTask.spinner) + set(tabLabel, process(container)) + } - override fun renderTabButtons(): String = """ + override fun renderTabButtons(): String = """
${ - tabs.withIndex().joinToString("\n") { (index, _) -> - val tabId = "$index" - """""" - } + tabs.withIndex().joinToString("\n") { (index, _) -> + val tabId = "$index" + """""" } + } ${ - ui.hrefLink("♻") { - val idx = tabs.size - val label = label(idx) - val content = StringBuilder("Retrying..." + SessionTask.spinner) - tabs.add(label to content) - update() - val newResult = process(content) - content.clear() - set(label, newResult) - } + ui.hrefLink("♻") { + val idx = tabs.size + val label = label(idx) + val content = StringBuilder("Retrying..." + SessionTask.spinner) + tabs.add(label to content) + update() + val newResult = process(content) + content.clear() + set(label, newResult) } + }
""" diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/TabbedDisplay.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/TabbedDisplay.kt index d4755b80..c4069767 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/TabbedDisplay.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/TabbedDisplay.kt @@ -4,95 +4,95 @@ import com.simiacryptus.skyenet.webui.session.SessionTask import java.util.* open class TabbedDisplay( - val task: SessionTask, - val tabs: MutableList> = mutableListOf(), + val task: SessionTask, + val tabs: MutableList> = mutableListOf(), ) { - var selectedTab: Int = 0 + var selectedTab: Int = 0 - companion object { - val log = org.slf4j.LoggerFactory.getLogger(TabbedDisplay::class.java) - } + companion object { + val log = org.slf4j.LoggerFactory.getLogger(TabbedDisplay::class.java) + } - val size: Int get() = tabs.size - open fun render() = if (tabs.isEmpty()) "
" else """ + val size: Int get() = tabs.size + open fun render() = if (tabs.isEmpty()) "
" else """
${renderTabButtons()} ${ - tabs.toTypedArray().withIndex().joinToString("\n") - { (idx, t) -> renderContentTab(t, idx) } - } + tabs.toTypedArray().withIndex().joinToString("\n") + { (idx, t) -> renderContentTab(t, idx) } + }
""" - val container: StringBuilder by lazy { - log.debug("Initializing container with rendered content") - task.add(render())!! - } + val container: StringBuilder by lazy { + log.debug("Initializing container with rendered content") + task.add(render())!! + } - open fun renderTabButtons() = """ + open fun renderTabButtons() = """
${ - tabs.toTypedArray().withIndex().joinToString("\n") { (idx, pair) -> - if (idx == selectedTab) { - """""" - } else { - """""" - } - } - }
+ tabs.toTypedArray().withIndex().joinToString("\n") { (idx, pair) -> + if (idx == selectedTab) { + """""" + } else { + """""" + } + } + }
""" - open fun renderContentTab(t: Pair, idx: Int) = """ + open fun renderContentTab(t: Pair, idx: Int) = """
"" - } - }" data-tab="$idx">${t.second}
""" + when { + idx == selectedTab -> "active" + else -> "" + } + }" data-tab="$idx">${t.second}
""" - operator fun get(i: String) = tabs.toMap()[i] - operator fun set(name: String, content: String) = - when (val index = find(name)) { - null -> { - log.debug("Adding new tab: $name") - val stringBuilder = StringBuilder(content) - tabs.add(name to stringBuilder) - update() - stringBuilder - } + operator fun get(i: String) = tabs.toMap()[i] + operator fun set(name: String, content: String) = + when (val index = find(name)) { + null -> { + log.debug("Adding new tab: $name") + val stringBuilder = StringBuilder(content) + tabs.add(name to stringBuilder) + update() + stringBuilder + } - else -> { - log.debug("Updating existing tab: $name") - val stringBuilder = tabs[index].second - stringBuilder.clear() - stringBuilder.append(content) - update() - stringBuilder - } - } + else -> { + log.debug("Updating existing tab: $name") + val stringBuilder = tabs[index].second + stringBuilder.clear() + stringBuilder.append(content) + update() + stringBuilder + } + } - fun find(name: String) = tabs.withIndex().firstOrNull { it.value.first == name }?.index + fun find(name: String) = tabs.withIndex().firstOrNull { it.value.first == name }?.index - open fun label(i: Int): String { - return "${tabs.size + 1}" - } + open fun label(i: Int): String { + return "${tabs.size + 1}" + } - open fun clear() { - log.debug("Clearing all tabs") - tabs.clear() - update() - } + open fun clear() { + log.debug("Clearing all tabs") + tabs.clear() + update() + } - open fun update() { - log.debug("Updating container content") - synchronized(container) { - if (tabs.isNotEmpty() && (selectedTab < 0 || selectedTab >= tabs.size)) { - selectedTab = 0 - } - container.clear() - container.append(render()) - } - task.complete() + open fun update() { + log.debug("Updating container content") + synchronized(container) { + if (tabs.isNotEmpty() && (selectedTab < 0 || selectedTab >= tabs.size)) { + selectedTab = 0 + } + container.clear() + container.append(render()) } + task.complete() + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/code/CodingAgent.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/code/CodingAgent.kt index bb4ce956..e400e97e 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/code/CodingAgent.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/code/CodingAgent.kt @@ -25,314 +25,314 @@ import kotlin.reflect.KClass open class CodingAgent( - val api: API, - dataStorage: StorageInterface, - session: Session, - user: User?, - val ui: ApplicationInterface, - interpreter: KClass, - val symbols: Map, - temperature: Double = 0.1, - val details: String? = null, - val model: TextModel, - private val mainTask: SessionTask, - val actorMap: Map = mapOf( - ActorTypes.CodingActor to CodingActor( - interpreter, - symbols = symbols, - temperature = temperature, - details = details, - model = model - ) - ), + val api: API, + dataStorage: StorageInterface, + session: Session, + user: User?, + val ui: ApplicationInterface, + interpreter: KClass, + val symbols: Map, + temperature: Double = 0.1, + val details: String? = null, + val model: TextModel, + private val mainTask: SessionTask, + val actorMap: Map = mapOf( + ActorTypes.CodingActor to CodingActor( + interpreter, + symbols = symbols, + temperature = temperature, + details = details, + model = model + ) + ), ) : ActorSystem(actorMap.map { it.key.name to it.value }.toMap(), dataStorage, user, session) { - enum class ActorTypes { - CodingActor - } + enum class ActorTypes { + CodingActor + } - open val actor by lazy { - getActor(ActorTypes.CodingActor) as CodingActor - } + open val actor by lazy { + getActor(ActorTypes.CodingActor) as CodingActor + } - open val canPlay by lazy { - ApplicationServices.authorizationManager.isAuthorized( - this::class.java, - user, - OperationType.Execute - ) - } + open val canPlay by lazy { + ApplicationServices.authorizationManager.isAuthorized( + this::class.java, + user, + OperationType.Execute + ) + } - fun start( - userMessage: String, - ) { - try { - mainTask.echo(renderMarkdown(userMessage, ui = ui)) - val codeRequest = codeRequest(listOf(userMessage to ApiModel.Role.user)) - start(codeRequest, mainTask) - } catch (e: Throwable) { - log.warn("Error", e) - mainTask.error(ui, e) - } + fun start( + userMessage: String, + ) { + try { + mainTask.echo(renderMarkdown(userMessage, ui = ui)) + val codeRequest = codeRequest(listOf(userMessage to ApiModel.Role.user)) + start(codeRequest, mainTask) + } catch (e: Throwable) { + log.warn("Error", e) + mainTask.error(ui, e) } + } - fun start( - codeRequest: CodingActor.CodeRequest, - task: SessionTask = mainTask, - ) { - val newTask = ui.newTask(root = false) - task.complete(newTask.placeholder) - Retryable(ui, newTask) { - val newTask = ui.newTask(root = false) - ui.socketManager?.scheduledThreadPoolExecutor!!.schedule({ - ui.socketManager.pool?.submit { - val statusSB = newTask.add("Running...") - displayCode(newTask, codeRequest) - statusSB?.clear() - newTask.complete() - } - }, 100, TimeUnit.MILLISECONDS) - newTask.placeholder + fun start( + codeRequest: CodingActor.CodeRequest, + task: SessionTask = mainTask, + ) { + val newTask = ui.newTask(root = false) + task.complete(newTask.placeholder) + Retryable(ui, newTask) { + val newTask = ui.newTask(root = false) + ui.socketManager?.scheduledThreadPoolExecutor!!.schedule({ + ui.socketManager.pool?.submit { + val statusSB = newTask.add("Running...") + displayCode(newTask, codeRequest) + statusSB?.clear() + newTask.complete() } + }, 100, TimeUnit.MILLISECONDS) + newTask.placeholder } + } - open fun codeRequest(messages: List>) = - CodingActor.CodeRequest(messages) + open fun codeRequest(messages: List>) = + CodingActor.CodeRequest(messages) - fun displayCode( - task: SessionTask, - codeRequest: CodingActor.CodeRequest, - ) { - try { - val lastUserMessage = codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() - val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { - actor.CodeResultImpl( - messages = actor.chatMessages(codeRequest), - input = codeRequest, - api = api as ChatClient, - givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") - ) - } else { - actor.answer(codeRequest, api = api) - } - displayCodeAndFeedback(task, codeRequest, codeResponse) - } catch (e: Throwable) { - log.warn("Error", e) - } + fun displayCode( + task: SessionTask, + codeRequest: CodingActor.CodeRequest, + ) { + try { + val lastUserMessage = codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() + val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { + actor.CodeResultImpl( + messages = actor.chatMessages(codeRequest), + input = codeRequest, + api = api as ChatClient, + givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") + ) + } else { + actor.answer(codeRequest, api = api) + } + displayCodeAndFeedback(task, codeRequest, codeResponse) + } catch (e: Throwable) { + log.warn("Error", e) } + } - protected fun displayCodeAndFeedback( - task: SessionTask, - codeRequest: CodingActor.CodeRequest, - response: CodeResult, - ) { - try { - displayCode(task, response) - displayFeedback(task, append(codeRequest, response), response) - } catch (e: Throwable) { - task.error(ui, e) - log.warn("Error", e) - } + protected fun displayCodeAndFeedback( + task: SessionTask, + codeRequest: CodingActor.CodeRequest, + response: CodeResult, + ) { + try { + displayCode(task, response) + displayFeedback(task, append(codeRequest, response), response) + } catch (e: Throwable) { + task.error(ui, e) + log.warn("Error", e) } + } - fun append( - codeRequest: CodingActor.CodeRequest, - response: CodeResult - ) = codeRequest( - messages = codeRequest.messages + - listOf( - response.code to ApiModel.Role.assistant, - ).filter { it.first.isNotBlank() } - ) + fun append( + codeRequest: CodingActor.CodeRequest, + response: CodeResult + ) = codeRequest( + messages = codeRequest.messages + + listOf( + response.code to ApiModel.Role.assistant, + ).filter { it.first.isNotBlank() } + ) - fun displayCode( - task: SessionTask, - response: CodeResult - ) { - task.hideable( - ui, - renderMarkdown( - response.renderedResponse ?: - //language=Markdown - "```${actor.language.lowercase(Locale.getDefault())}\n${response.code.trim()}\n```", ui = ui - ) - ) - } + fun displayCode( + task: SessionTask, + response: CodeResult + ) { + task.hideable( + ui, + renderMarkdown( + response.renderedResponse ?: + //language=Markdown + "```${actor.language.lowercase(Locale.getDefault())}\n${response.code.trim()}\n```", ui = ui + ) + ) + } - open fun displayFeedback( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult - ) { - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ + open fun displayFeedback( + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult + ) { + val formText = StringBuilder() + var formHandle: StringBuilder? = null + formHandle = task.add( + """ |
|${if (!canPlay) "" else playButton(task, request, response, formText) { formHandle!! }} |
|${reviseMsg(task, request, response, formText) { formHandle!! }} """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } + ) + formText.append(formHandle.toString()) + formHandle.toString() + task.complete() + } - protected fun reviseMsg( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult, - formText: StringBuilder, - formHandle: () -> StringBuilder - ) = ui.textInput { feedback -> - responseAction(task, "Revising...", formHandle(), formText) { - feedback(task, feedback, request, response) - } + protected fun reviseMsg( + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult, + formText: StringBuilder, + formHandle: () -> StringBuilder + ) = ui.textInput { feedback -> + responseAction(task, "Revising...", formHandle(), formText) { + feedback(task, feedback, request, response) } + } - protected fun regenButton( - task: SessionTask, - request: CodingActor.CodeRequest, - formText: StringBuilder, - formHandle: () -> StringBuilder - ) = "" - - protected fun playButton( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult, - formText: StringBuilder, - formHandle: () -> StringBuilder - ) = if (!canPlay) "" else - ui.hrefLink("▶", "href-link play-button") { - responseAction(task, "Running...", formHandle(), formText) { - execute(task, response, request) - } - } + protected fun regenButton( + task: SessionTask, + request: CodingActor.CodeRequest, + formText: StringBuilder, + formHandle: () -> StringBuilder + ) = "" - protected open fun responseAction( - task: SessionTask, - message: String, - formHandle: StringBuilder?, - formText: StringBuilder, - fn: () -> Unit = {} - ) { - formHandle?.clear() - val header = task.header(message) - try { - fn() - } finally { - header?.clear() - revertButton(task, formHandle, formText) - } + protected fun playButton( + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult, + formText: StringBuilder, + formHandle: () -> StringBuilder + ) = if (!canPlay) "" else + ui.hrefLink("▶", "href-link play-button") { + responseAction(task, "Running...", formHandle(), formText) { + execute(task, response, request) + } } - protected open fun revertButton( - task: SessionTask, - formHandle: StringBuilder?, - formText: StringBuilder - ): StringBuilder? { - var revertButton: StringBuilder? = null - revertButton = task.complete(ui.hrefLink("↩", "href-link regen-button") { - revertButton?.clear() - formHandle?.append(formText) - task.complete() - }) - return revertButton + protected open fun responseAction( + task: SessionTask, + message: String, + formHandle: StringBuilder?, + formText: StringBuilder, + fn: () -> Unit = {} + ) { + formHandle?.clear() + val header = task.header(message) + try { + fn() + } finally { + header?.clear() + revertButton(task, formHandle, formText) } + } - protected open fun feedback( - task: SessionTask, - feedback: String, - request: CodingActor.CodeRequest, - response: CodeResult - ) { - try { - task.echo(renderMarkdown(feedback, ui = ui)) - start(codeRequest = codeRequest( - messages = request.messages + - listOf( - response.code to ApiModel.Role.assistant, - feedback to ApiModel.Role.user, - ).filter { it.first.isNotBlank() }.map { it.first to it.second } - ), task = task) - } catch (e: Throwable) { - log.warn("Error", e) - task.error(ui, e) - } + protected open fun revertButton( + task: SessionTask, + formHandle: StringBuilder?, + formText: StringBuilder + ): StringBuilder? { + var revertButton: StringBuilder? = null + revertButton = task.complete(ui.hrefLink("↩", "href-link regen-button") { + revertButton?.clear() + formHandle?.append(formText) + task.complete() + }) + return revertButton + } + + protected open fun feedback( + task: SessionTask, + feedback: String, + request: CodingActor.CodeRequest, + response: CodeResult + ) { + try { + task.echo(renderMarkdown(feedback, ui = ui)) + start(codeRequest = codeRequest( + messages = request.messages + + listOf( + response.code to ApiModel.Role.assistant, + feedback to ApiModel.Role.user, + ).filter { it.first.isNotBlank() }.map { it.first to it.second } + ), task = task) + } catch (e: Throwable) { + log.warn("Error", e) + task.error(ui, e) } + } - protected open fun execute( - task: SessionTask, - response: CodeResult, - request: CodingActor.CodeRequest, - ) { - try { - val result = execute(task, response) - displayFeedback(task, codeRequest( - messages = request.messages + - listOf( - "Running...\n\n$result" to ApiModel.Role.assistant, - ).filter { it.first.isNotBlank() } - ), response) - } catch (e: Throwable) { - handleExecutionError(e, task, request, response) - } + protected open fun execute( + task: SessionTask, + response: CodeResult, + request: CodingActor.CodeRequest, + ) { + try { + val result = execute(task, response) + displayFeedback(task, codeRequest( + messages = request.messages + + listOf( + "Running...\n\n$result" to ApiModel.Role.assistant, + ).filter { it.first.isNotBlank() } + ), response) + } catch (e: Throwable) { + handleExecutionError(e, task, request, response) } + } - protected open fun handleExecutionError( - e: Throwable, - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult - ) { - val message = when { - e is ValidatedObject.ValidationError -> renderMarkdown(e.message ?: "", ui = ui) - e is CodingActor.FailedToImplementException -> renderMarkdown( - """ + protected open fun handleExecutionError( + e: Throwable, + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult + ) { + val message = when { + e is ValidatedObject.ValidationError -> renderMarkdown(e.message ?: "", ui = ui) + e is CodingActor.FailedToImplementException -> renderMarkdown( + """ |**Failed to Implement** | |${e.message} | |""".trimMargin(), ui = ui - ) + ) - else -> renderMarkdown( - """ + else -> renderMarkdown( + """ |**Error `${e.javaClass.name}`** | |```text |${e.stackTraceToString()/*.indent(" ")*/} |``` |""".trimMargin(), ui = ui - ) - } - task.add(message, true, "div", "error") - displayCode(task, CodingActor.CodeRequest( - messages = request.messages + - listOf( - response.code to ApiModel.Role.assistant, - message to ApiModel.Role.system, - ).filter { it.first.isNotBlank() } - )) + ) } + task.add(message, true, "div", "error") + displayCode(task, CodingActor.CodeRequest( + messages = request.messages + + listOf( + response.code to ApiModel.Role.assistant, + message to ApiModel.Role.system, + ).filter { it.first.isNotBlank() } + )) + } - fun execute( - task: SessionTask, - response: CodeResult - ): String { - val resultValue = response.result.resultValue - val resultOutput = response.result.resultOutput - val result = when { - resultValue.isBlank() || resultValue.trim().lowercase() == "null" -> """ + fun execute( + task: SessionTask, + response: CodeResult + ): String { + val resultValue = response.result.resultValue + val resultOutput = response.result.resultOutput + val result = when { + resultValue.isBlank() || resultValue.trim().lowercase() == "null" -> """ |# Output |```text |${resultOutput.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} |``` """.trimMargin() - else -> """ + else -> """ |# Result |``` |${resultValue.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} @@ -343,12 +343,12 @@ open class CodingAgent( |${resultOutput.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }} |``` """.trimMargin() - } - task.add(renderMarkdown(result, ui = ui)) - return result } + task.add(renderMarkdown(result, ui = ui)) + return result + } - companion object { - private val log = LoggerFactory.getLogger(CodingAgent::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(CodingAgent::class.java) + } } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/code/ShellToolAgent.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/code/ShellToolAgent.kt index f004fef9..11ac73fe 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/code/ShellToolAgent.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/code/ShellToolAgent.kt @@ -35,54 +35,54 @@ import java.io.File import kotlin.reflect.KClass private val String.escapeQuotedString: String - get() = replace("\\", "\\\\") - .replace("\"", "\\\"") - .replace("\n", "\\n") - .replace("\r", "\\r") - .replace("$", "\\$") + get() = replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("$", "\\$") abstract class ShellToolAgent( - api: API, - dataStorage: StorageInterface, - session: Session, - user: User?, - ui: ApplicationInterface, - interpreter: KClass, - symbols: Map, - temperature: Double = 0.1, - details: String? = null, - model: ChatModel, - actorMap: Map = mapOf( - ActorTypes.CodingActor to CodingActor( - interpreter, - symbols = symbols, - temperature = temperature, - details = details, - model = model - ) - ), - mainTask: SessionTask = ui.newTask(), + api: API, + dataStorage: StorageInterface, + session: Session, + user: User?, + ui: ApplicationInterface, + interpreter: KClass, + symbols: Map, + temperature: Double = 0.1, + details: String? = null, + model: ChatModel, + actorMap: Map = mapOf( + ActorTypes.CodingActor to CodingActor( + interpreter, + symbols = symbols, + temperature = temperature, + details = details, + model = model + ) + ), + mainTask: SessionTask = ui.newTask(), ) : CodingAgent( - api, - dataStorage, - session, - user, - ui, - interpreter, - symbols, - temperature, - details, - model, - mainTask, - actorMap + api, + dataStorage, + session, + user, + ui, + interpreter, + symbols, + temperature, + details, + model, + mainTask, + actorMap ) { - override fun displayFeedback(task: SessionTask, request: CodingActor.CodeRequest, response: CodeResult) { - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ + override fun displayFeedback(task: SessionTask, request: CodingActor.CodeRequest, response: CodeResult) { + val formText = StringBuilder() + var formHandle: StringBuilder? = null + formHandle = task.add( + """ |
|${if (!canPlay) "" else playButton(task, request, response, formText) { formHandle!! }} |${super.regenButton(task, request, formText) { formHandle!! }} @@ -90,41 +90,41 @@ abstract class ShellToolAgent( |
|${super.reviseMsg(task, request, response, formText) { formHandle!! }} """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } + ) + formText.append(formHandle.toString()) + formHandle.toString() + task.complete() + } - private var lastResult: String? = null + private var lastResult: String? = null - private fun createToolButton( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodeResult, - formText: StringBuilder, - formHandle: () -> StringBuilder - ) = ui.hrefLink("\uD83D\uDCE4", "href-link regen-button") { - val task = ui.newTask() - responseAction(task, "Exporting...", formHandle(), formText) { - displayCodeFeedback( - task, schemaActor(), request.copy( - messages = listOf( - response.code to ApiModel.Role.assistant, - "From the given code prototype, identify input out output data structures and generate Kotlin data classes to define this schema" to ApiModel.Role.user - ) - ) - ) { schemaCode -> - val command = actor.symbols.get("command")?.let { command -> - when (command) { - is String -> command.split(" ") - is List<*> -> command.map { it.toString() } - else -> throw IllegalArgumentException("Invalid command: $command") - } - } ?: listOf("bash") - val cwd = actor.symbols.get("workingDir")?.toString()?.let { File(it) } ?: File(".") - val env = actor.symbols.get("env")?.let { env -> (env as Map) } ?: mapOf() - val codePrefix = """ + private fun createToolButton( + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodeResult, + formText: StringBuilder, + formHandle: () -> StringBuilder + ) = ui.hrefLink("\uD83D\uDCE4", "href-link regen-button") { + val task = ui.newTask() + responseAction(task, "Exporting...", formHandle(), formText) { + displayCodeFeedback( + task, schemaActor(), request.copy( + messages = listOf( + response.code to ApiModel.Role.assistant, + "From the given code prototype, identify input out output data structures and generate Kotlin data classes to define this schema" to ApiModel.Role.user + ) + ) + ) { schemaCode -> + val command = actor.symbols.get("command")?.let { command -> + when (command) { + is String -> command.split(" ") + is List<*> -> command.map { it.toString() } + else -> throw IllegalArgumentException("Invalid command: $command") + } + } ?: listOf("bash") + val cwd = actor.symbols.get("workingDir")?.toString()?.let { File(it) } ?: File(".") + val env = actor.symbols.get("env")?.let { env -> (env as Map) } ?: mapOf() + val codePrefix = """ fun execute() : Pair { val command = "${command.joinToString(" ").escapeQuotedString}".split(" ") val cwd = java.io.File("${cwd.absolutePath.escapeQuotedString}") @@ -145,390 +145,390 @@ abstract class ShellToolAgent( } } """.trimIndent() - val messages = listOf( - "Shell Code: \n```${actor.language}\n${(response.code)/*.indent(" ")*/}\n```" to ApiModel.Role.assistant, - ) + (lastResult?.let { - listOf( - "Example Output:\n\n```text\n${it/*.indent(" ")*/}\n```" to ApiModel.Role.assistant - ) - } ?: listOf()) + listOf( - "Schema: \n```kotlin\n${schemaCode/*.indent(" ")*/}\n```" to ApiModel.Role.assistant, - "Implement a parsing method to convert the shell output to the requested data structure" to ApiModel.Role.user + val messages = listOf( + "Shell Code: \n```${actor.language}\n${(response.code)/*.indent(" ")*/}\n```" to ApiModel.Role.assistant, + ) + (lastResult?.let { + listOf( + "Example Output:\n\n```text\n${it/*.indent(" ")*/}\n```" to ApiModel.Role.assistant + ) + } ?: listOf()) + listOf( + "Schema: \n```kotlin\n${schemaCode/*.indent(" ")*/}\n```" to ApiModel.Role.assistant, + "Implement a parsing method to convert the shell output to the requested data structure" to ApiModel.Role.user + ) + displayCodeFeedback( + task, parsedActor(), request.copy( + messages = messages, + codePrefix = codePrefix + ) + ) { parsedCode -> + displayCodeFeedback( + task, servletActor(), request.copy( + messages = listOf( + (codePrefix + "\n\n" + parsedCode) to ApiModel.Role.assistant, + "Reprocess this code prototype into a servlet. " + + "The last line should instantiate the new servlet class and return it via the returnBuffer collection." to ApiModel.Role.user + ), + codePrefix = schemaCode + ) + ) { servletHandler -> + val servletImpl = (schemaCode + "\n\n" + servletHandler).sortCode() + val toolsPrefix = "/tools" + var openAPI = openAPIParsedActor().getParser(api).apply(servletImpl).let { openApi -> + openApi.copy(paths = openApi.paths?.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) + } + task.add(renderMarkdown("```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", ui = ui)) + for (i in 0..5) { + try { + OpenAPIGenerator.main( + arrayOf( + "generate", + "-i", + File.createTempFile("openapi", ".json").apply { + writeText(JsonUtil.toJson(openAPI)) + deleteOnExit() + }.absolutePath, + "-g", + "html2", + "-o", + File( + dataStorage.getSessionDir(user, session), + "openapi/html2" + ).apply { mkdirs() }.absolutePath, + ) ) - displayCodeFeedback( - task, parsedActor(), request.copy( - messages = messages, - codePrefix = codePrefix - ) - ) { parsedCode -> - displayCodeFeedback( - task, servletActor(), request.copy( - messages = listOf( - (codePrefix + "\n\n" + parsedCode) to ApiModel.Role.assistant, - "Reprocess this code prototype into a servlet. " + - "The last line should instantiate the new servlet class and return it via the returnBuffer collection." to ApiModel.Role.user - ), - codePrefix = schemaCode - ) - ) { servletHandler -> - val servletImpl = (schemaCode + "\n\n" + servletHandler).sortCode() - val toolsPrefix = "/tools" - var openAPI = openAPIParsedActor().getParser(api).apply(servletImpl).let { openApi -> - openApi.copy(paths = openApi.paths?.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) - } - task.add(renderMarkdown("```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", ui = ui)) - for (i in 0..5) { - try { - OpenAPIGenerator.main( - arrayOf( - "generate", - "-i", - File.createTempFile("openapi", ".json").apply { - writeText(JsonUtil.toJson(openAPI)) - deleteOnExit() - }.absolutePath, - "-g", - "html2", - "-o", - File( - dataStorage.getSessionDir(user, session), - "openapi/html2" - ).apply { mkdirs() }.absolutePath, - ) - ) - task.add("Validated OpenAPI Descriptor - Documentation Saved") - break - } catch (e: SpecValidationException) { - val error = """ + task.add("Validated OpenAPI Descriptor - Documentation Saved") + break + } catch (e: SpecValidationException) { + val error = """ |${e.message} |${e.errors.joinToString("\n") { "ERROR:" + it.toString() }} |${e.warnings.joinToString("\n") { "WARN:" + it.toString() }} """.trimIndent() - task.hideable( - ui, - renderMarkdown( - "```\n${error.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```", - ui = ui - ) - ) - openAPI = openAPIParsedActor().answer( - listOf( - servletImpl, - JsonUtil.toJson(openAPI), - error - ), api - ).obj.let { openApi -> - val paths = HashMap(openApi.paths) - openApi.copy(paths = paths.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) - } - task.hideable( - ui, - renderMarkdown( - "```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", - ui = ui - ) - ) - } - } - if (ApplicationServices.authorizationManager.isAuthorized( - ShellToolAgent.javaClass, - user, - AuthorizationInterface.OperationType.Admin - ) - ) { - /* - ToolServlet.addTool( - ToolServlet.Tool( - path = openAPI.paths?.entries?.first()?.key?.removePrefix(toolsPrefix) ?: "unknown", - openApiDescription = openAPI, - interpreterString = getInterpreterString(), - servletCode = servletImpl - ) - ) - */ - } - buildTestPage(openAPI, servletImpl, task) - } + task.hideable( + ui, + renderMarkdown( + "```\n${error.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```", + ui = ui + ) + ) + openAPI = openAPIParsedActor().answer( + listOf( + servletImpl, + JsonUtil.toJson(openAPI), + error + ), api + ).obj.let { openApi -> + val paths = HashMap(openApi.paths) + openApi.copy(paths = paths.mapKeys { toolsPrefix + it.key.removePrefix(toolsPrefix) }) } + task.hideable( + ui, + renderMarkdown( + "```json\n${JsonUtil.toJson(openAPI)/*.indent(" ")*/}\n```", + ui = ui + ) + ) + } } + if (ApplicationServices.authorizationManager.isAuthorized( + ShellToolAgent.javaClass, + user, + AuthorizationInterface.OperationType.Admin + ) + ) { + /* + ToolServlet.addTool( + ToolServlet.Tool( + path = openAPI.paths?.entries?.first()?.key?.removePrefix(toolsPrefix) ?: "unknown", + openApiDescription = openAPI, + interpreterString = getInterpreterString(), + servletCode = servletImpl + ) + ) + */ + } + buildTestPage(openAPI, servletImpl, task) + } } + } } + } - private fun openAPIParsedActor() = object : ParsedActor( + private fun openAPIParsedActor() = object : ParsedActor( // parserClass = OpenApiParser::class.java, - resultClass = OpenAPI::class.java, - model = model, - prompt = "You are a code documentation assistant. You will create the OpenAPI definition for a servlet handler written in kotlin", - parsingModel = model, - ) { - override val describer: TypeDescriber - get() = object : AbbrevWhitelistYamlDescriber( - //"com.simiacryptus", "com.github.simiacryptus" - ) { - override val includeMethods: Boolean get() = false - } - } + resultClass = OpenAPI::class.java, + model = model, + prompt = "You are a code documentation assistant. You will create the OpenAPI definition for a servlet handler written in kotlin", + parsingModel = model, + ) { + override val describer: TypeDescriber + get() = object : AbbrevWhitelistYamlDescriber( + //"com.simiacryptus", "com.github.simiacryptus" + ) { + override val includeMethods: Boolean get() = false + } + } - private fun servletActor() = object : CodingActor( - interpreterClass = KotlinInterpreter::class, - symbols = actor.symbols + mapOf( - "returnBuffer" to ServletBuffer(), - "json" to JsonUtil, - "req" to Request(null, null), - "resp" to Response(null, null), - ), - describer = object : AbbrevWhitelistYamlDescriber( - "com.simiacryptus", - "com.github.simiacryptus" - ) { - override fun describe( - rawType: Class, - stackMax: Int, - describedTypes: MutableSet - ): String = when (rawType) { - Request::class.java -> describe(HttpServletRequest::class.java) - Response::class.java -> describe(HttpServletResponse::class.java) - else -> super.describe(rawType, stackMax, describedTypes) - } - }, - details = actor.details, - model = actor.model, - fallbackModel = actor.fallbackModel, - temperature = actor.temperature, - runtimeSymbols = actor.runtimeSymbols + private fun servletActor() = object : CodingActor( + interpreterClass = KotlinInterpreter::class, + symbols = actor.symbols + mapOf( + "returnBuffer" to ServletBuffer(), + "json" to JsonUtil, + "req" to Request(null, null), + "resp" to Response(null, null), + ), + describer = object : AbbrevWhitelistYamlDescriber( + "com.simiacryptus", + "com.github.simiacryptus" ) { - override val prompt: String - get() = super.prompt - } + override fun describe( + rawType: Class, + stackMax: Int, + describedTypes: MutableSet + ): String = when (rawType) { + Request::class.java -> describe(HttpServletRequest::class.java) + Response::class.java -> describe(HttpServletResponse::class.java) + else -> super.describe(rawType, stackMax, describedTypes) + } + }, + details = actor.details, + model = actor.model, + fallbackModel = actor.fallbackModel, + temperature = actor.temperature, + runtimeSymbols = actor.runtimeSymbols + ) { + override val prompt: String + get() = super.prompt + } - private fun schemaActor() = object : CodingActor( - interpreterClass = KotlinInterpreter::class, - symbols = mapOf(), - details = actor.details, - model = actor.model, - fallbackModel = actor.fallbackModel, - temperature = actor.temperature, - runtimeSymbols = actor.runtimeSymbols - ) { - override val prompt: String - get() = super.prompt - } + private fun schemaActor() = object : CodingActor( + interpreterClass = KotlinInterpreter::class, + symbols = mapOf(), + details = actor.details, + model = actor.model, + fallbackModel = actor.fallbackModel, + temperature = actor.temperature, + runtimeSymbols = actor.runtimeSymbols + ) { + override val prompt: String + get() = super.prompt + } - private fun parsedActor() = object : CodingActor( - interpreterClass = KotlinInterpreter::class, - symbols = mapOf(), - details = actor.details, - model = actor.model, - fallbackModel = actor.fallbackModel, - temperature = actor.temperature, - runtimeSymbols = actor.runtimeSymbols - ) { - override val prompt: String - get() = super.prompt - } + private fun parsedActor() = object : CodingActor( + interpreterClass = KotlinInterpreter::class, + symbols = mapOf(), + details = actor.details, + model = actor.model, + fallbackModel = actor.fallbackModel, + temperature = actor.temperature, + runtimeSymbols = actor.runtimeSymbols + ) { + override val prompt: String + get() = super.prompt + } - /** - * - * TODO: This method seems redundant. - * - * */ - private fun displayCodeFeedback( - task: SessionTask, - actor: CodingActor, - request: CodingActor.CodeRequest, - response: CodeResult = execWrap { actor.answer(request, api = api) }, - onComplete: (String) -> Unit - ) { - task.hideable( - ui, - renderMarkdown("```kotlin\n${response.code.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```", ui = ui) - ) - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ + /** + * + * TODO: This method seems redundant. + * + * */ + private fun displayCodeFeedback( + task: SessionTask, + actor: CodingActor, + request: CodingActor.CodeRequest, + response: CodeResult = execWrap { actor.answer(request, api = api) }, + onComplete: (String) -> Unit + ) { + task.hideable( + ui, + renderMarkdown("```kotlin\n${response.code.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```", ui = ui) + ) + val formText = StringBuilder() + var formHandle: StringBuilder? = null + formHandle = task.add( + """ |
|${ - super.ui.hrefLink("\uD83D\uDC4D", "href-link play-button") { - super.responseAction(task, "Accepted...", formHandle!!, formText) { - onComplete(response.code) - } - } - } + super.ui.hrefLink("\uD83D\uDC4D", "href-link play-button") { + super.responseAction(task, "Accepted...", formHandle!!, formText) { + onComplete(response.code) + } + } + } |${ - if (!super.canPlay) "" else - ui.hrefLink("▶", "href-link play-button") { - execute(ui.newTask(), response) - } - } + if (!super.canPlay) "" else + ui.hrefLink("▶", "href-link play-button") { + execute(ui.newTask(), response) + } + } |${ - super.ui.hrefLink("♻", "href-link regen-button") { - super.responseAction(task, "Regenerating...", formHandle!!, formText) { - //val task = super.ui.newTask() - val codeRequest = - request.copy(messages = request.messages.dropLastWhile { it.second == ApiModel.Role.assistant }) - try { - val lastUserMessage = - codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() - val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { - actor.CodeResultImpl( - messages = actor.chatMessages(codeRequest), - input = codeRequest, - api = super.api as ChatClient, - givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") - ) - } else { - actor.answer(codeRequest, api = super.api) - } - super.displayCode(task, codeResponse) - displayCodeFeedback( - task, - actor, - super.append(codeRequest, codeResponse), - codeResponse, - onComplete - ) - } catch (e: Throwable) { - log.warn("Error", e) - val error = task.error(super.ui, e) - var regenButton: StringBuilder? = null - regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button") { - regenButton?.clear() - val header = task.header("Regenerating...") - super.displayCode(task, codeRequest) - header?.clear() - error?.clear() - task.complete() - }) - } - } - } + super.ui.hrefLink("♻", "href-link regen-button") { + super.responseAction(task, "Regenerating...", formHandle!!, formText) { + //val task = super.ui.newTask() + val codeRequest = + request.copy(messages = request.messages.dropLastWhile { it.second == ApiModel.Role.assistant }) + try { + val lastUserMessage = + codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() + val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { + actor.CodeResultImpl( + messages = actor.chatMessages(codeRequest), + input = codeRequest, + api = super.api as ChatClient, + givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") + ) + } else { + actor.answer(codeRequest, api = super.api) + } + super.displayCode(task, codeResponse) + displayCodeFeedback( + task, + actor, + super.append(codeRequest, codeResponse), + codeResponse, + onComplete + ) + } catch (e: Throwable) { + log.warn("Error", e) + val error = task.error(super.ui, e) + var regenButton: StringBuilder? = null + regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button") { + regenButton?.clear() + val header = task.header("Regenerating...") + super.displayCode(task, codeRequest) + header?.clear() + error?.clear() + task.complete() + }) } + } + } + } |
|${ - super.ui.textInput { feedback -> - super.responseAction(task, "Revising...", formHandle!!, formText) { - //val task = super.ui.newTask() - try { - task.echo(renderMarkdown(feedback, ui = ui)) - val codeRequest = CodingActor.CodeRequest( - messages = request.messages + - listOf( - response.code to ApiModel.Role.assistant, - feedback to ApiModel.Role.user, - ).filter { it.first.isNotBlank() }.map { it.first to it.second } - ) - try { - val lastUserMessage = - codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() - val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { - actor.CodeResultImpl( - messages = actor.chatMessages(codeRequest), - input = codeRequest, - api = super.api as ChatClient, - givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") - ) - } else { - actor.answer(codeRequest, api = super.api) - } - displayCodeFeedback( - task, - actor, - super.append(codeRequest, codeResponse), - codeResponse, - onComplete - ) - } catch (e: Throwable) { - log.warn("Error", e) - val error = task.error(super.ui, e) - var regenButton: StringBuilder? = null - regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button") { - regenButton?.clear() - val header = task.header("Regenerating...") - super.displayCode(task, codeRequest) - header?.clear() - error?.clear() - task.complete() - }) - } - } catch (e: Throwable) { - log.warn("Error", e) - task.error(ui, e) - } - } + super.ui.textInput { feedback -> + super.responseAction(task, "Revising...", formHandle!!, formText) { + //val task = super.ui.newTask() + try { + task.echo(renderMarkdown(feedback, ui = ui)) + val codeRequest = CodingActor.CodeRequest( + messages = request.messages + + listOf( + response.code to ApiModel.Role.assistant, + feedback to ApiModel.Role.user, + ).filter { it.first.isNotBlank() }.map { it.first to it.second } + ) + try { + val lastUserMessage = + codeRequest.messages.last { it.second == ApiModel.Role.user }.first.trim() + val codeResponse: CodeResult = if (lastUserMessage.startsWith("```")) { + actor.CodeResultImpl( + messages = actor.chatMessages(codeRequest), + input = codeRequest, + api = super.api as ChatClient, + givenCode = lastUserMessage.removePrefix("```").removeSuffix("```") + ) + } else { + actor.answer(codeRequest, api = super.api) } + displayCodeFeedback( + task, + actor, + super.append(codeRequest, codeResponse), + codeResponse, + onComplete + ) + } catch (e: Throwable) { + log.warn("Error", e) + val error = task.error(super.ui, e) + var regenButton: StringBuilder? = null + regenButton = task.complete(super.ui.hrefLink("♻", "href-link regen-button") { + regenButton?.clear() + val header = task.header("Regenerating...") + super.displayCode(task, codeRequest) + header?.clear() + error?.clear() + task.complete() + }) + } + } catch (e: Throwable) { + log.warn("Error", e) + task.error(ui, e) } + } + } + } """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } + ) + formText.append(formHandle.toString()) + formHandle.toString() + task.complete() + } - class ServletBuffer : ArrayList() + class ServletBuffer : ArrayList() - private fun buildTestPage( - openAPI: OpenAPI, - servletImpl: String, - task: SessionTask - ) { - var testPage = SimpleActor( - prompt = "Given the definition for a servlet handler, create a test page that can be used to test the servlet", - model = model, - ).answer( - listOf( - JsonUtil.toJson(openAPI), - servletImpl - ), api + private fun buildTestPage( + openAPI: OpenAPI, + servletImpl: String, + task: SessionTask + ) { + var testPage = SimpleActor( + prompt = "Given the definition for a servlet handler, create a test page that can be used to test the servlet", + model = model, + ).answer( + listOf( + JsonUtil.toJson(openAPI), + servletImpl + ), api + ) + // if ```html unwrap + if (testPage.contains("```html")) testPage = testPage.substringAfter("```html").substringBefore("```") + task.add(renderMarkdown("```html\n${testPage.let { /*escapeHtml4*/(it)/*.indent(" ")*/ }}\n```", ui = ui)) + task.complete( + "Test Page for ${openAPI.paths?.entries?.first()?.key ?: "unknown"} Saved" - ) - } + }'>Test Page for ${openAPI.paths?.entries?.first()?.key ?: "unknown"} Saved" + ) + } - abstract fun getInterpreterString(): String + abstract fun getInterpreterString(): String - private fun answer( - actor: CodingActor, - request: CodingActor.CodeRequest, - task: SessionTask = ui.newTask(), - feedback: Boolean = true, - ): CodeResult { - val response = actor.answer(request, api = api) - if (feedback) displayCodeAndFeedback(task, request, response) - else displayCode(task, response) - return response - } + private fun answer( + actor: CodingActor, + request: CodingActor.CodeRequest, + task: SessionTask = ui.newTask(), + feedback: Boolean = true, + ): CodeResult { + val response = actor.answer(request, api = api) + if (feedback) displayCodeAndFeedback(task, request, response) + else displayCode(task, response) + return response + } - companion object { - val log = LoggerFactory.getLogger(ShellToolAgent::class.java) - fun execWrap(fn: () -> T): T { - val classLoader = Thread.currentThread().contextClassLoader - val prevCL = KotlinInterpreter.classLoader - KotlinInterpreter.classLoader = classLoader //req.javaClass.classLoader - return try { - WebAppClassLoader.runWithServerClassAccess { - require(null != classLoader.loadClass("org.eclipse.jetty.server.Response")) - require(null != classLoader.loadClass("org.eclipse.jetty.server.Request")) - // com.simiacryptus.jopenai.OpenAIClient - require(null != classLoader.loadClass("com.simiacryptus.jopenai.OpenAIClient")) - require(null != classLoader.loadClass("com.simiacryptus.jopenai.API")) - fn() - } - } finally { - KotlinInterpreter.classLoader = prevCL - } + companion object { + val log = LoggerFactory.getLogger(ShellToolAgent::class.java) + fun execWrap(fn: () -> T): T { + val classLoader = Thread.currentThread().contextClassLoader + val prevCL = KotlinInterpreter.classLoader + KotlinInterpreter.classLoader = classLoader //req.javaClass.classLoader + return try { + WebAppClassLoader.runWithServerClassAccess { + require(null != classLoader.loadClass("org.eclipse.jetty.server.Response")) + require(null != classLoader.loadClass("org.eclipse.jetty.server.Request")) + // com.simiacryptus.jopenai.OpenAIClient + require(null != classLoader.loadClass("com.simiacryptus.jopenai.OpenAIClient")) + require(null != classLoader.loadClass("com.simiacryptus.jopenai.API")) + fn() } + } finally { + KotlinInterpreter.classLoader = prevCL + } } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/AutoPlanChatApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/AutoPlanChatApp.kt index 0dd2b725..0978feae 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/AutoPlanChatApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/AutoPlanChatApp.kt @@ -26,388 +26,390 @@ import java.util.concurrent.Future import java.util.concurrent.atomic.AtomicReference open class AutoPlanChatApp( - applicationName: String = "Auto Plan Chat App", - path: String = "/autoPlanChat", - planSettings: PlanSettings, - model: ChatModel, - parsingModel: ChatModel, - domainName: String = "localhost", - showMenubar: Boolean = true, - api: API? = null, - api2: OpenAIClient, - val maxTaskHistoryChars: Int = 20000, - val maxTasksPerIteration: Int = 3, - val maxIterations: Int = 100 + applicationName: String = "Auto Plan Chat App", + path: String = "/autoPlanChat", + planSettings: PlanSettings, + model: ChatModel, + parsingModel: ChatModel, + domainName: String = "localhost", + showMenubar: Boolean = true, + api: API? = null, + api2: OpenAIClient, + val maxTaskHistoryChars: Int = 20000, + val maxTasksPerIteration: Int = 3, + val maxIterations: Int = 100 ) : PlanChatApp( - applicationName = applicationName, - path = path, - planSettings = planSettings, - model = model, - parsingModel = parsingModel, - domainName = domainName, - showMenubar = showMenubar, - api = api, - api2 = api2 + applicationName = applicationName, + path = path, + planSettings = planSettings, + model = model, + parsingModel = parsingModel, + domainName = domainName, + showMenubar = showMenubar, + api = api, + api2 = api2 ) { - override val stickyInput = true - override val singleInput = false - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(AutoPlanChatApp::class.java) - } + override val stickyInput = true + override val singleInput = false - data class ThinkingStatus( - var initialPrompt: String? = null, - val goals: Goals? = null, - val knowledge: Knowledge? = null, - val executionContext: ExecutionContext? = null - ) + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(AutoPlanChatApp::class.java) + } - data class Goals( - val shortTerm: MutableList? = null, - val longTerm: MutableList? = null - ) + data class ThinkingStatus( + var initialPrompt: String? = null, + val goals: Goals? = null, + val knowledge: Knowledge? = null, + val executionContext: ExecutionContext? = null + ) - data class Knowledge( - val facts: MutableList? = null, - val hypotheses: MutableList? = null, - val openQuestions: MutableList? = null - ) + data class Goals( + val shortTerm: MutableList? = null, + val longTerm: MutableList? = null + ) - data class ExecutionContext( - val completedTasks: MutableList? = null, - val currentTask: CurrentTask? = null, - val nextSteps: MutableList? = null - ) + data class Knowledge( + val facts: MutableList? = null, + val hypotheses: MutableList? = null, + val openQuestions: MutableList? = null + ) - data class CurrentTask( - val taskId: String? = null, - val description: String? = null - ) + data class ExecutionContext( + val completedTasks: MutableList? = null, + val currentTask: CurrentTask? = null, + val nextSteps: MutableList? = null + ) - data class ExecutionRecord( - val time: Date? = Date(), - val task: PlanTaskBase? = null, - val result: String? = null - ) + data class CurrentTask( + val taskId: String? = null, + val description: String? = null + ) - data class Tasks( - val tasks: MutableList? = null - ) + data class ExecutionRecord( + val time: Date? = Date(), + val task: PlanTaskBase? = null, + val result: String? = null + ) - private val currentUserMessage = AtomicReference(null) - private var isRunning = false - val executionRecords = mutableListOf() + data class Tasks( + val tasks: MutableList? = null + ) - override fun userMessage( - session: Session, - user: User?, - userMessage: String, - ui: ApplicationInterface, - api: API - ) { - try { - log.info("Received user message: $userMessage") - if (!isRunning) { - isRunning = true - log.info("Starting new auto plan chat") - startAutoPlanChat(session, user, userMessage, ui, api) - } else { - log.info("Injecting user message into ongoing chat") - val userMessageTask = ui.newTask() - userMessageTask.echo(renderMarkdown("User: $userMessage", ui = ui)) - currentUserMessage.set(userMessage) - } - } catch (e: Exception) { - log.error("Error processing user message", e) - ui.newTask().add(renderMarkdown("An error occurred while processing your message: ${e.message}", ui = ui)) - } - } + private val currentUserMessage = AtomicReference(null) + private var isRunning = false + val executionRecords = mutableListOf() - protected open fun startAutoPlanChat( - session: Session, - user: User?, - userMessage: String, - ui: ApplicationInterface, - api: API - ) { - val thinkingStatus = AtomicReference(null); - val task = ui.newTask(true) - val api = (api as ChatClient).getChildClient().apply { - val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - task.verbose("API log: $this") - } - } - - val tabbedDisplay = TabbedDisplay(task) - val executor = ui.socketManager!!.pool - var continueLoop = true - executor.execute { - try { - ui.newTask(false).let { task -> - tabbedDisplay["Controls"] = task.placeholder - lateinit var stopLink: StringBuilder - stopLink = task.add(ui.hrefLink("Stop") { - continueLoop = false - executor.shutdown() - stopLink.set("Stopped") - task.complete() - })!! - } - tabbedDisplay.update() - task.complete() - - val initialPromptTask = ui.newTask(false) - initialPromptTask.add(renderMarkdown("Starting Auto Plan Chat for prompt: $userMessage")) - tabbedDisplay["Initial Prompt"] = initialPromptTask.placeholder - val planSettings = getSettings(session, user, PlanSettings::class.java) ?: planSettings.copy(allowBlocking = false) - api.budget = planSettings.budget - val coordinator = PlanCoordinator( - user = user, - session = session, - dataStorage = dataStorage, - ui = ui, - root = planSettings.workingDir?.let { File(it).toPath() } ?: dataStorage.getDataDir(user, session).toPath(), - planSettings = planSettings - ) - val initialStatus = initThinking(planSettings, userMessage) - initialStatus.initialPrompt = userMessage - thinkingStatus.set(initialStatus) - initialPromptTask.complete(renderMarkdown("Initial Thinking Status:\n${formatThinkingStatus(thinkingStatus.get()!!)}")) - - var iteration = 0 - while (iteration++ < maxIterations && continueLoop) { - task.complete() - val task = ui.newTask(false).apply { tabbedDisplay["Iteration $iteration"] = placeholder } - val api = api.getChildClient().apply { - val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - task.verbose("API log: $this") - } - } - val tabbedDisplay = TabbedDisplay(task) - ui.newTask(false).apply { - tabbedDisplay["Inputs"] = placeholder - header("Project Info") - contextData().forEach { add(renderMarkdown(it)) } - header("Evaluation Records") - formatEvalRecords().forEach { add(renderMarkdown(it)) } - header("Current Thinking Status") - formatThinkingStatus(thinkingStatus.get()!!).let { add(renderMarkdown(it)) } - } - val nextTask = try { - getNextTask(api, planSettings, coordinator, userMessage, thinkingStatus.get()) - } catch (e: Exception) { - log.error("Error choosing next task", e) - tabbedDisplay["Errors"]?.append(renderMarkdown("Error choosing next task: ${e.message}")) - break - } - if (nextTask?.isEmpty() != false) { - task.add(renderMarkdown("No more tasks to execute. Finishing Auto Plan Chat.")) - break - } + override fun userMessage( + session: Session, + user: User?, + userMessage: String, + ui: ApplicationInterface, + api: API + ) { + try { + log.info("Received user message: $userMessage") + if (!isRunning) { + isRunning = true + log.info("Starting new auto plan chat") + startAutoPlanChat(session, user, userMessage, ui, api) + } else { + log.info("Injecting user message into ongoing chat") + val userMessageTask = ui.newTask() + userMessageTask.echo(renderMarkdown("User: $userMessage", ui = ui)) + currentUserMessage.set(userMessage) + } + } catch (e: Exception) { + log.error("Error processing user message", e) + ui.newTask().add(renderMarkdown("An error occurred while processing your message: ${e.message}", ui = ui)) + } + } - val taskResults = mutableListOf>>() - for ((index, currentTask) in nextTask.withIndex()) { - val currentTaskId = "task_${index + 1}" - val taskExecutionTask = ui.newTask(false) - taskExecutionTask.add( - renderMarkdown( - "Executing task: `$currentTaskId` - ${currentTask.task_description}\n```json\n${ - JsonUtil.toJson(currentTask) - }\n```" - ) - ) - tabbedDisplay["Task Execution $currentTaskId"] = taskExecutionTask.placeholder - val future = executor.submit { - try { - runTask(api, api2, task, coordinator, currentTask, currentTaskId, userMessage, taskExecutionTask, thinkingStatus.get()) - } catch (e: Exception) { - taskExecutionTask.error(ui, e) - log.error("Error executing task", e) - "Error executing task: ${e.message}" - } - } - taskResults.add(Pair(currentTask, future)) - } - val completedTasks = taskResults.map { (task, future) -> - val result = future.get() - ExecutionRecord( - time = Date(), - task = task, - result = result - ) - } - executionRecords.addAll(completedTasks) + protected open fun startAutoPlanChat( + session: Session, + user: User?, + userMessage: String, + ui: ApplicationInterface, + api: API + ) { + val thinkingStatus = AtomicReference(null); + val task = ui.newTask(true) + val api = (api as ChatClient).getChildClient().apply { + val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + task.verbose("API log: $this") + } + } - val thinkingStatusTask = ui.newTask(false).apply { tabbedDisplay["Thinking Status"] = placeholder } - thinkingStatus.set( - updateThinking(api, planSettings, thinkingStatus.get(), completedTasks) - ) - thinkingStatusTask.complete(renderMarkdown("Updated Thinking Status:\n${formatThinkingStatus(thinkingStatus.get()!!)}")) - } - task.complete("Auto Plan Chat completed.") - } catch (e: Throwable) { - task.error(ui, e) - log.error("Error in startAutoPlanChat", e) - } finally { - val summaryTask = ui.newTask(false).apply { tabbedDisplay["Summary"] = placeholder } - summaryTask.add( - renderMarkdown( - "Auto Plan Chat completed. Final thinking status:\n${thinkingStatus.get()?.let { - formatThinkingStatus(it) - } ?: "null" - }")) - task.complete() - } + val tabbedDisplay = TabbedDisplay(task) + val executor = ui.socketManager!!.pool + var continueLoop = true + executor.execute { + try { + ui.newTask(false).let { task -> + tabbedDisplay["Controls"] = task.placeholder + lateinit var stopLink: StringBuilder + stopLink = task.add(ui.hrefLink("Stop") { + continueLoop = false + executor.shutdown() + stopLink.set("Stopped") + task.complete() + })!! } + tabbedDisplay.update() + task.complete() - } + val initialPromptTask = ui.newTask(false) + initialPromptTask.add(renderMarkdown("Starting Auto Plan Chat for prompt: $userMessage")) + tabbedDisplay["Initial Prompt"] = initialPromptTask.placeholder + val planSettings = getSettings(session, user, PlanSettings::class.java) ?: planSettings.copy(allowBlocking = false) + api.budget = planSettings.budget + val coordinator = PlanCoordinator( + user = user, + session = session, + dataStorage = dataStorage, + ui = ui, + root = planSettings.workingDir?.let { File(it).toPath() } ?: dataStorage.getDataDir(user, session).toPath(), + planSettings = planSettings + ) + val initialStatus = initThinking(planSettings, userMessage) + initialStatus.initialPrompt = userMessage + thinkingStatus.set(initialStatus) + initialPromptTask.complete(renderMarkdown("Initial Thinking Status:\n${formatThinkingStatus(thinkingStatus.get()!!)}")) - protected open fun runTask( - api: ChatClient, - api2: OpenAIClient, - task: SessionTask, - coordinator: PlanCoordinator, - currentTask: PlanTaskBase, - currentTaskId: String, - userMessage: String, - taskExecutionTask: SessionTask, - thinkingStatus: ThinkingStatus? - ): String { - val api = api.getChildClient().apply { + var iteration = 0 + while (iteration++ < maxIterations && continueLoop) { + task.complete() + val task = ui.newTask(false).apply { tabbedDisplay["Iteration $iteration"] = placeholder } + val api = api.getChildClient().apply { val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") createFile.second?.apply { - logStreams += this.outputStream().buffered() - task.verbose("API log: $this") + logStreams += this.outputStream().buffered() + task.verbose("API log: $this") } + } + val tabbedDisplay = TabbedDisplay(task) + ui.newTask(false).apply { + tabbedDisplay["Inputs"] = placeholder + header("Project Info") + contextData().forEach { add(renderMarkdown(it)) } + header("Evaluation Records") + formatEvalRecords().forEach { add(renderMarkdown(it)) } + header("Current Thinking Status") + formatThinkingStatus(thinkingStatus.get()!!).let { add(renderMarkdown(it)) } + } + val nextTask = try { + getNextTask(api, planSettings, coordinator, userMessage, thinkingStatus.get()) + } catch (e: Exception) { + log.error("Error choosing next task", e) + tabbedDisplay["Errors"]?.append(renderMarkdown("Error choosing next task: ${e.message}")) + break + } + if (nextTask?.isEmpty() != false) { + task.add(renderMarkdown("No more tasks to execute. Finishing Auto Plan Chat.")) + break + } + + val taskResults = mutableListOf>>() + for ((index, currentTask) in nextTask.withIndex()) { + val currentTaskId = "task_${index + 1}" + val taskExecutionTask = ui.newTask(false) + taskExecutionTask.add( + renderMarkdown( + "Executing task: `$currentTaskId` - ${currentTask.task_description}\n```json\n${ + JsonUtil.toJson(currentTask) + }\n```" + ) + ) + tabbedDisplay["Task Execution $currentTaskId"] = taskExecutionTask.placeholder + val future = executor.submit { + try { + runTask(api, api2, task, coordinator, currentTask, currentTaskId, userMessage, taskExecutionTask, thinkingStatus.get()) + } catch (e: Exception) { + taskExecutionTask.error(ui, e) + log.error("Error executing task", e) + "Error executing task: ${e.message}" + } + } + taskResults.add(Pair(currentTask, future)) + } + val completedTasks = taskResults.map { (task, future) -> + val result = future.get() + ExecutionRecord( + time = Date(), + task = task, + result = result + ) + } + executionRecords.addAll(completedTasks) + + val thinkingStatusTask = ui.newTask(false).apply { tabbedDisplay["Thinking Status"] = placeholder } + thinkingStatus.set( + updateThinking(api, planSettings, thinkingStatus.get(), completedTasks) + ) + thinkingStatusTask.complete(renderMarkdown("Updated Thinking Status:\n${formatThinkingStatus(thinkingStatus.get()!!)}")) } - val taskImpl = TaskType.Companion.getImpl(coordinator.planSettings, currentTask) - val result = StringBuilder() - taskImpl.run( - agent = coordinator.copy( - planSettings = coordinator.planSettings.copy( - taskSettings = coordinator.planSettings.taskSettings.toList().toTypedArray().toMap().toMutableMap().apply { - this["TaskPlanning"] = TaskSettings(enabled = false, model = null) - } - ) - ), - messages = listOf( - userMessage, - "Current thinking status:\n${formatThinkingStatus(thinkingStatus!!)}" - ) + formatEvalRecords(), - task = taskExecutionTask, - api = api, - resultFn ={ result.append(it) }, - api2 = api2, - planSettings = planSettings, - ) - return result.toString() + task.complete("Auto Plan Chat completed.") + } catch (e: Throwable) { + task.error(ui, e) + log.error("Error in startAutoPlanChat", e) + } finally { + val summaryTask = ui.newTask(false).apply { tabbedDisplay["Summary"] = placeholder } + summaryTask.add( + renderMarkdown( + "Auto Plan Chat completed. Final thinking status:\n${ + thinkingStatus.get()?.let { + formatThinkingStatus(it) + } ?: "null" + }")) + task.complete() + } + } + + } + + protected open fun runTask( + api: ChatClient, + api2: OpenAIClient, + task: SessionTask, + coordinator: PlanCoordinator, + currentTask: PlanTaskBase, + currentTaskId: String, + userMessage: String, + taskExecutionTask: SessionTask, + thinkingStatus: ThinkingStatus? + ): String { + val api = api.getChildClient().apply { + val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + task.verbose("API log: $this") + } } + val taskImpl = TaskType.Companion.getImpl(coordinator.planSettings, currentTask) + val result = StringBuilder() + taskImpl.run( + agent = coordinator.copy( + planSettings = coordinator.planSettings.copy( + taskSettings = coordinator.planSettings.taskSettings.toList().toTypedArray().toMap().toMutableMap().apply { + this["TaskPlanning"] = TaskSettings(enabled = false, model = null) + } + ) + ), + messages = listOf( + userMessage, + "Current thinking status:\n${formatThinkingStatus(thinkingStatus!!)}" + ) + formatEvalRecords(), + task = taskExecutionTask, + api = api, + resultFn = { result.append(it) }, + api2 = api2, + planSettings = planSettings, + ) + return result.toString() + } - protected open fun getNextTask( - api: ChatClient, - planSettings: PlanSettings, - coordinator: PlanCoordinator, - userMessage: String, - thinkingStatus: ThinkingStatus? - ): List? { - val describer1 = planSettings.describer() - val tasks = ParsedActor( - name = "SingleTaskChooser", - resultClass = Tasks::class.java, - exampleInstance = Tasks( - listOf( - FileModificationTaskData( - task_description = "Modify the file 'example.txt' to include the given input." - ) - ).toMutableList() - ), - prompt = """ + protected open fun getNextTask( + api: ChatClient, + planSettings: PlanSettings, + coordinator: PlanCoordinator, + userMessage: String, + thinkingStatus: ThinkingStatus? + ): List? { + val describer1 = planSettings.describer() + val tasks = ParsedActor( + name = "SingleTaskChooser", + resultClass = Tasks::class.java, + exampleInstance = Tasks( + listOf( + FileModificationTaskData( + task_description = "Modify the file 'example.txt' to include the given input." + ) + ).toMutableList() + ), + prompt = """ Given the following input, choose up to ${maxTasksPerIteration} tasks to execute. Do not create a full plan, just select the most appropriate task types for the given input. Available task types: ${ - TaskType.Companion.getAvailableTaskTypes(coordinator.planSettings).joinToString>("\n") { taskType -> - "* ${TaskType.Companion.getImpl(coordinator.planSettings, taskType).promptSegment()}" - } - } + TaskType.Companion.getAvailableTaskTypes(coordinator.planSettings).joinToString>("\n") { taskType -> + "* ${TaskType.Companion.getImpl(coordinator.planSettings, taskType).promptSegment()}" + } + } Choose the most suitable task types and provide details of how they should be executed. """.trimIndent(), - model = coordinator.planSettings.defaultModel, - parsingModel = coordinator.planSettings.parsingModel, - temperature = coordinator.planSettings.temperature, - describer = describer1, - parserPrompt = """ + model = coordinator.planSettings.defaultModel, + parsingModel = coordinator.planSettings.parsingModel, + temperature = coordinator.planSettings.temperature, + describer = describer1, + parserPrompt = """ Task Subtype Schema: ${ - TaskType.Companion.getAvailableTaskTypes(coordinator.planSettings).joinToString>("\n\n") { taskType -> - """ + TaskType.Companion.getAvailableTaskTypes(coordinator.planSettings).joinToString>("\n\n") { taskType -> + """ ${taskType.name}: ${describer1.describe(taskType.taskDataClass).replace("\n", "\n ")} """.trim() - } - } + } + } """.trimIndent() - ).answer( - listOf(userMessage) + contextData() + - listOf( - """ + ).answer( + listOf(userMessage) + contextData() + + listOf( + """ Current thinking status: ${formatThinkingStatus(thinkingStatus!!)} Please choose the next single task to execute based on the current status. If there are no tasks to execute, return {}. """.trimIndent() - ) - + formatEvalRecords(), api - ).obj.tasks?.map { task -> - task to (if (task.task_type == null) { - null - } else { - TaskType.Companion.getImpl(coordinator.planSettings, task) - })?.planTask - } - if (tasks.isNullOrEmpty()) { - log.info("No tasks selected from: ${tasks?.map { it.first }}") - return null - } else if (tasks.mapNotNull { it.second }.isEmpty()) { - log.warn("No tasks selected from: ${tasks.map { it.first }}") - return null - } else { - return tasks.mapNotNull { it.second }.take(maxTasksPerIteration) - } + ) + + formatEvalRecords(), api + ).obj.tasks?.map { task -> + task to (if (task.task_type == null) { + null + } else { + TaskType.Companion.getImpl(coordinator.planSettings, task) + })?.planTask } + if (tasks.isNullOrEmpty()) { + log.info("No tasks selected from: ${tasks?.map { it.first }}") + return null + } else if (tasks.mapNotNull { it.second }.isEmpty()) { + log.warn("No tasks selected from: ${tasks.map { it.first }}") + return null + } else { + return tasks.mapNotNull { it.second }.take(maxTasksPerIteration) + } + } - protected open fun updateThinking( - api: ChatClient, - planSettings: PlanSettings, - thinkingStatus: ThinkingStatus?, - completedTasks: List - ): ThinkingStatus = ParsedActor( - name = "UpdateQuestionsActor", - resultClass = ThinkingStatus::class.java, - exampleInstance = ThinkingStatus( - initialPrompt = "Example prompt", - goals = Goals( - shortTerm = mutableListOf("Analyze task results"), - longTerm = mutableListOf("Complete the user's request") - ), - knowledge = Knowledge( - facts = mutableListOf( - "Initial Context: User's request received", - "Task 1 Result: Analyzed user's request" - ), - openQuestions = mutableListOf("What is the next task?", "Are there any remaining tasks?") - ), - executionContext = ExecutionContext( - completedTasks = mutableListOf("task_1"), - nextSteps = mutableListOf("Analyze task results", "Determine next action"), - ) + protected open fun updateThinking( + api: ChatClient, + planSettings: PlanSettings, + thinkingStatus: ThinkingStatus?, + completedTasks: List + ): ThinkingStatus = ParsedActor( + name = "UpdateQuestionsActor", + resultClass = ThinkingStatus::class.java, + exampleInstance = ThinkingStatus( + initialPrompt = "Example prompt", + goals = Goals( + shortTerm = mutableListOf("Analyze task results"), + longTerm = mutableListOf("Complete the user's request") + ), + knowledge = Knowledge( + facts = mutableListOf( + "Initial Context: User's request received", + "Task 1 Result: Analyzed user's request" ), - prompt = """ + openQuestions = mutableListOf("What is the next task?", "Are there any remaining tasks?") + ), + executionContext = ExecutionContext( + completedTasks = mutableListOf("task_1"), + nextSteps = mutableListOf("Analyze task results", "Determine next action"), + ) + ), + prompt = """ Given the current thinking status, the last completed task, and its result, update the open questions to guide the next steps of the planning process. Consider what information is still needed and what new questions arise from the task result. @@ -416,50 +418,50 @@ open class AutoPlanChatApp( Update the estimated time remaining and adjust the confidence level based on progress. Reassess challenges, available resources, and alternative approaches. """.trimIndent(), - model = planSettings.defaultModel, - parsingModel = planSettings.parsingModel, - temperature = planSettings.temperature, - describer = planSettings.describer() - ).answer( - listOf("Current thinking status: ${formatThinkingStatus(thinkingStatus!!)}") + contextData() + - completedTasks.flatMap { record -> - listOf( - "Completed task: ${record.task?.task_description}", - "Task result: ${record.result}" - ) - } + (currentUserMessage.get()?.let> { listOf("User message: $it") } ?: listOf()), - api - ).obj.apply { - this@AutoPlanChatApp.currentUserMessage.set(null) - knowledge?.facts?.apply { - this.addAll(completedTasks.mapIndexed { index, (task, result) -> - "Task ${(executionContext?.completedTasks?.size ?: 0) + index + 1} Result: $result" - }) - } + model = planSettings.defaultModel, + parsingModel = planSettings.parsingModel, + temperature = planSettings.temperature, + describer = planSettings.describer() + ).answer( + listOf("Current thinking status: ${formatThinkingStatus(thinkingStatus!!)}") + contextData() + + completedTasks.flatMap { record -> + listOf( + "Completed task: ${record.task?.task_description}", + "Task result: ${record.result}" + ) + } + (currentUserMessage.get()?.let> { listOf("User message: $it") } ?: listOf()), + api + ).obj.apply { + this@AutoPlanChatApp.currentUserMessage.set(null) + knowledge?.facts?.apply { + this.addAll(completedTasks.mapIndexed { index, (task, result) -> + "Task ${(executionContext?.completedTasks?.size ?: 0) + index + 1} Result: $result" + }) } + } - protected open fun initThinking( - planSettings: PlanSettings, - userMessage: String - ): ThinkingStatus { - val initialStatus = ParsedActor( - name = "ThinkingStatusInitializer", - resultClass = ThinkingStatus::class.java, - exampleInstance = ThinkingStatus( - initialPrompt = "Example prompt", - goals = Goals( - shortTerm = mutableListOf("Understand the user's request"), - longTerm = mutableListOf("Complete the user's task") - ), - knowledge = Knowledge( - facts = mutableListOf("Initial Context: User's request received"), - openQuestions = mutableListOf("What is the first task?") - ), - executionContext = ExecutionContext( - nextSteps = mutableListOf("Analyze the initial prompt", "Identify key objectives"), - ) - ), - prompt = """ + protected open fun initThinking( + planSettings: PlanSettings, + userMessage: String + ): ThinkingStatus { + val initialStatus = ParsedActor( + name = "ThinkingStatusInitializer", + resultClass = ThinkingStatus::class.java, + exampleInstance = ThinkingStatus( + initialPrompt = "Example prompt", + goals = Goals( + shortTerm = mutableListOf("Understand the user's request"), + longTerm = mutableListOf("Complete the user's task") + ), + knowledge = Knowledge( + facts = mutableListOf("Initial Context: User's request received"), + openQuestions = mutableListOf("What is the first task?") + ), + executionContext = ExecutionContext( + nextSteps = mutableListOf("Analyze the initial prompt", "Identify key objectives"), + ) + ), + prompt = """ Given the user's initial prompt, initialize the thinking status for an AI assistant. Set short-term and long-term goals. Generate relevant open questions and hypotheses to guide the planning process. @@ -467,19 +469,19 @@ open class AutoPlanChatApp( Set up the execution context with initial next steps, progress (0-100), estimated time remaining, and confidence level (0-100). Identify potential challenges and available resources. """.trimIndent(), - model = planSettings.defaultModel, - parsingModel = planSettings.parsingModel, - temperature = planSettings.temperature, - describer = planSettings.describer() - ).answer(listOf(userMessage) + contextData(), this.api!!).obj - return initialStatus - } + model = planSettings.defaultModel, + parsingModel = planSettings.parsingModel, + temperature = planSettings.temperature, + describer = planSettings.describer() + ).answer(listOf(userMessage) + contextData(), this.api!!).obj + return initialStatus + } - protected open fun formatEvalRecords(maxTotalLength: Int = maxTaskHistoryChars): List { - var currentLength = 0 - val formattedRecords = mutableListOf() - for (record in executionRecords.reversed()) { - val formattedRecord = """ + protected open fun formatEvalRecords(maxTotalLength: Int = maxTaskHistoryChars): List { + var currentLength = 0 + val formattedRecords = mutableListOf() + for (record in executionRecords.reversed()) { + val formattedRecord = """ # Task ${executionRecords.indexOf(record) + 1} ## Task: @@ -488,33 +490,35 @@ ${JsonUtil.toJson(record.task!!)} ``` ## Result: -${record.result?.let { - // Add 2 levels of header level to each header - it.split("\n").joinToString("\n") { line -> - if (line.startsWith("#")) { - "##$line" - } else { - line - } - } -}} -""" - if (currentLength + formattedRecord.length > maxTotalLength) { - formattedRecords.add("... (earlier records truncated)") - break +${ + record.result?.let { + // Add 2 levels of header level to each header + it.split("\n").joinToString("\n") { line -> + if (line.startsWith("#")) { + "##$line" + } else { + line } - formattedRecords.add(0, formattedRecord) - currentLength += formattedRecord.length + } } - return formattedRecords + } +""" + if (currentLength + formattedRecord.length > maxTotalLength) { + formattedRecords.add("... (earlier records truncated)") + break + } + formattedRecords.add(0, formattedRecord) + currentLength += formattedRecord.length } + return formattedRecords + } - protected open fun formatThinkingStatus(thinkingStatus: ThinkingStatus) = """ + protected open fun formatThinkingStatus(thinkingStatus: ThinkingStatus) = """ ```json ${JsonUtil.toJson(thinkingStatus)} ``` """ - protected open fun contextData(): List = emptyList() + protected open fun contextData(): List = emptyList() } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/CmdPatchApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/CmdPatchApp.kt index a319ea74..6b15f918 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/CmdPatchApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/CmdPatchApp.kt @@ -14,151 +14,151 @@ import java.nio.file.Path import java.util.concurrent.TimeUnit class CmdPatchApp( - root: Path, - session: Session, - settings: Settings, - api: ChatClient, - val files: Array?, - model: ChatModel + root: Path, + session: Session, + settings: Settings, + api: ChatClient, + val files: Array?, + model: ChatModel ) : PatchApp(root.toFile(), session, settings, api, model) { - companion object { - private val log = LoggerFactory.getLogger(CmdPatchApp::class.java) + companion object { + private val log = LoggerFactory.getLogger(CmdPatchApp::class.java) - val String.htmlEscape: String - get() = this.replace("&", "&") - .replace("<", "<") - .replace(">", ">") - .replace("\"", """) - .replace("'", "'") + val String.htmlEscape: String + get() = this.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace("\"", """) + .replace("'", "'") - fun truncate(output: String, kb: Int = 32): String { - var returnVal = output - if (returnVal.length > 1024 * 2 * kb) { - returnVal = returnVal.substring(0, 1024 * kb) + - "\n\n... Output truncated ...\n\n" + - returnVal.substring(returnVal.length - 1024 * kb) - } - return returnVal - } + fun truncate(output: String, kb: Int = 32): String { + var returnVal = output + if (returnVal.length > 1024 * 2 * kb) { + returnVal = returnVal.substring(0, 1024 * kb) + + "\n\n... Output truncated ...\n\n" + + returnVal.substring(returnVal.length - 1024 * kb) + } + return returnVal + } - } + } - private fun getFiles( - virtualFiles: Array? - ): MutableSet { - val codeFiles = mutableSetOf() // Set to avoid duplicates - virtualFiles?.forEach { file -> - if (file.isDirectory) { - if (file.name.startsWith(".")) return@forEach - if (FileValidationUtils.isGitignore(file.toPath())) return@forEach - codeFiles.addAll(getFiles(file.listFiles())) - } else { - codeFiles.add((file.toPath())) - } - } - return codeFiles + private fun getFiles( + virtualFiles: Array? + ): MutableSet { + val codeFiles = mutableSetOf() // Set to avoid duplicates + virtualFiles?.forEach { file -> + if (file.isDirectory) { + if (file.name.startsWith(".")) return@forEach + if (FileValidationUtils.isGitignore(file.toPath())) return@forEach + codeFiles.addAll(getFiles(file.listFiles())) + } else { + codeFiles.add((file.toPath())) + } } + return codeFiles + } - override fun codeFiles() = getFiles(files) - .filter { it.toFile().length() < 1024 * 1024 / 2 } // Limit to 0.5MB - .map { root.toPath().relativize(it) ?: it }.toSet() + override fun codeFiles() = getFiles(files) + .filter { it.toFile().length() < 1024 * 1024 / 2 } // Limit to 0.5MB + .map { root.toPath().relativize(it) ?: it }.toSet() - override fun codeSummary(paths: List): String = paths - .filter { - val file = settings.workingDirectory?.resolve(it.toFile()) - file?.exists() == true && !file.isDirectory && file.length() < (256 * 1024) - } - .joinToString("\n\n") { path -> - try { - """ + override fun codeSummary(paths: List): String = paths + .filter { + val file = settings.workingDirectory?.resolve(it.toFile()) + file?.exists() == true && !file.isDirectory && file.length() < (256 * 1024) + } + .joinToString("\n\n") { path -> + try { + """ |# ${path} |${tripleTilde}${path.toString().split('.').lastOrNull()} |${settings.workingDirectory?.resolve(path.toFile())?.readText(Charsets.UTF_8)} |${tripleTilde} """.trimMargin() - } catch (e: Exception) { - log.warn("Error reading file", e) - "Error reading file `${path}` - ${e.message}" - } - } - - override fun projectSummary(): String { - val codeFiles = codeFiles() - val str = codeFiles - .asSequence() - .filter { settings.workingDirectory?.toPath()?.resolve(it)?.toFile()?.exists() == true } - .distinct().sorted() - .joinToString("\n") { path -> - "* ${path} - ${ - settings.workingDirectory?.toPath()?.resolve(path)?.toFile()?.length() ?: "?" - } bytes".trim() - } - return str + } catch (e: Exception) { + log.warn("Error reading file", e) + "Error reading file `${path}` - ${e.message}" + } } - override fun output(task: SessionTask): OutputResult = run { - val command = - listOf(settings.executable.absolutePath) + settings.arguments.split(" ").filter(String::isNotBlank) - val processBuilder = ProcessBuilder(command).directory(settings.workingDirectory) - // Pass the current environment to the subprocess - processBuilder.environment().putAll(System.getenv()) - val buffer = StringBuilder() - val taskOutput = task.add("") - val process = processBuilder.start() - Thread { - var lastUpdate = 0L - process.errorStream.bufferedReader().use { reader -> - var line: String? - while (reader.readLine().also { line = it } != null) { - buffer.append(line).append("\n") - if (lastUpdate + TimeUnit.SECONDS.toMillis(15) < System.currentTimeMillis()) { - taskOutput?.set("
\n${truncate(buffer.toString()).htmlEscape}\n
") - task.append("", true) - lastUpdate = System.currentTimeMillis() - } - } - task.append("", true) - } - }.start() - process.inputStream.bufferedReader().use { reader -> - var line: String? - var lastUpdate = 0L - while (reader.readLine().also { line = it } != null) { - buffer.append(line).append("\n") - if (lastUpdate + TimeUnit.SECONDS.toMillis(15) < System.currentTimeMillis()) { - taskOutput?.set("
\n${outputString(buffer).htmlEscape}\n
") - task.append("", true) - lastUpdate = System.currentTimeMillis() - } - } + override fun projectSummary(): String { + val codeFiles = codeFiles() + val str = codeFiles + .asSequence() + .filter { settings.workingDirectory?.toPath()?.resolve(it)?.toFile()?.exists() == true } + .distinct().sorted() + .joinToString("\n") { path -> + "* ${path} - ${ + settings.workingDirectory?.toPath()?.resolve(path)?.toFile()?.length() ?: "?" + } bytes".trim() + } + return str + } + + override fun output(task: SessionTask): OutputResult = run { + val command = + listOf(settings.executable.absolutePath) + settings.arguments.split(" ").filter(String::isNotBlank) + val processBuilder = ProcessBuilder(command).directory(settings.workingDirectory) + // Pass the current environment to the subprocess + processBuilder.environment().putAll(System.getenv()) + val buffer = StringBuilder() + val taskOutput = task.add("") + val process = processBuilder.start() + Thread { + var lastUpdate = 0L + process.errorStream.bufferedReader().use { reader -> + var line: String? + while (reader.readLine().also { line = it } != null) { + buffer.append(line).append("\n") + if (lastUpdate + TimeUnit.SECONDS.toMillis(15) < System.currentTimeMillis()) { + taskOutput?.set("
\n${truncate(buffer.toString()).htmlEscape}\n
") task.append("", true) + lastUpdate = System.currentTimeMillis() + } } - task.append("", false) - if (!process.waitFor(5, TimeUnit.MINUTES)) { - process.destroy() - throw RuntimeException("Process timed out") + task.append("", true) + } + }.start() + process.inputStream.bufferedReader().use { reader -> + var line: String? + var lastUpdate = 0L + while (reader.readLine().also { line = it } != null) { + buffer.append(line).append("\n") + if (lastUpdate + TimeUnit.SECONDS.toMillis(15) < System.currentTimeMillis()) { + taskOutput?.set("
\n${outputString(buffer).htmlEscape}\n
") + task.append("", true) + lastUpdate = System.currentTimeMillis() } - val exitCode = process.exitValue() - var output = outputString(buffer) - taskOutput?.clear() - OutputResult(exitCode, output) + } + task.append("", true) } - - private fun outputString(buffer: StringBuilder): String { - var output = buffer.toString() - output = output.replace(Regex("\\x1B\\[[0-?]*[ -/]*[@-~]"), "") // Remove terminal escape codes - output = truncate(output) - return output + task.append("", false) + if (!process.waitFor(5, TimeUnit.MINUTES)) { + process.destroy() + throw RuntimeException("Process timed out") } + val exitCode = process.exitValue() + var output = outputString(buffer) + taskOutput?.clear() + OutputResult(exitCode, output) + } - override fun searchFiles(searchStrings: List): Set { - return searchStrings.flatMap { searchString -> - FileValidationUtils.filteredWalk(settings.workingDirectory!!) { !FileValidationUtils.isGitignore(it.toPath()) } - .filter { FileValidationUtils.isLLMIncludableFile(it) } - .filter { it.readText().contains(searchString, ignoreCase = true) } - .map { it.toPath() } - .toList() - }.toSet() - } + private fun outputString(buffer: StringBuilder): String { + var output = buffer.toString() + output = output.replace(Regex("\\x1B\\[[0-?]*[ -/]*[@-~]"), "") // Remove terminal escape codes + output = truncate(output) + return output + } + + override fun searchFiles(searchStrings: List): Set { + return searchStrings.flatMap { searchString -> + FileValidationUtils.filteredWalk(settings.workingDirectory!!) { !FileValidationUtils.isGitignore(it.toPath()) } + .filter { FileValidationUtils.isLLMIncludableFile(it) } + .filter { it.readText().contains(searchString, ignoreCase = true) } + .map { it.toPath() } + .toList() + }.toSet() + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/CommandPatchApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/CommandPatchApp.kt index fec3027a..e1797b87 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/CommandPatchApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/CommandPatchApp.kt @@ -10,74 +10,74 @@ import java.io.File import java.nio.file.Path class CommandPatchApp( - root: File, - session: Session, - settings: Settings, - api: ChatClient, - model: ChatModel, - private val files: Array?, - val command: String, + root: File, + session: Session, + settings: Settings, + api: ChatClient, + model: ChatModel, + private val files: Array?, + val command: String, ) : PatchApp(root, session, settings, api, model) { - override fun codeFiles() = getFiles(files) - .filter { it.toFile().length() < 1024 * 1024 / 2 } // Limit to 0.5MB - .map { root.toPath().relativize(it) ?: it }.toSet() + override fun codeFiles() = getFiles(files) + .filter { it.toFile().length() < 1024 * 1024 / 2 } // Limit to 0.5MB + .map { root.toPath().relativize(it) ?: it }.toSet() - override fun codeSummary(paths: List): String = paths - .filter { it.toFile().exists() } - .joinToString("\n\n") { path -> - """ + override fun codeSummary(paths: List): String = paths + .filter { it.toFile().exists() } + .joinToString("\n\n") { path -> + """ |# ${settings.workingDirectory?.toPath()?.relativize(path)} |$tripleTilde${path.toString().split('.').lastOrNull()} |${path.toFile().readText(Charsets.UTF_8)} |$tripleTilde """.trimMargin() - } + } - override fun output(task: SessionTask) = OutputResult( - exitCode = 1, - output = command - ) + override fun output(task: SessionTask) = OutputResult( + exitCode = 1, + output = command + ) - override fun projectSummary(): String { - val codeFiles = codeFiles() - return codeFiles - .asSequence() - .filter { settings.workingDirectory?.toPath()?.resolve(it)?.toFile()?.exists() == true } - .distinct().sorted() - .joinToString("\n") { path -> - "* ${path} - ${ - settings.workingDirectory?.toPath()?.resolve(path)?.toFile()?.length() ?: "?" - } bytes".trim() - } - } + override fun projectSummary(): String { + val codeFiles = codeFiles() + return codeFiles + .asSequence() + .filter { settings.workingDirectory?.toPath()?.resolve(it)?.toFile()?.exists() == true } + .distinct().sorted() + .joinToString("\n") { path -> + "* ${path} - ${ + settings.workingDirectory?.toPath()?.resolve(path)?.toFile()?.length() ?: "?" + } bytes".trim() + } + } - override fun searchFiles(searchStrings: List): Set { - return searchStrings.flatMap { searchString -> - FileValidationUtils.filteredWalk(settings.workingDirectory!!) { !FileValidationUtils.isGitignore(it.toPath()) } - .filter { FileValidationUtils.isLLMIncludableFile(it) } - .filter { it.readText().contains(searchString, ignoreCase = true) } - .map { it.toPath() } - .toList() - }.toSet() - } + override fun searchFiles(searchStrings: List): Set { + return searchStrings.flatMap { searchString -> + FileValidationUtils.filteredWalk(settings.workingDirectory!!) { !FileValidationUtils.isGitignore(it.toPath()) } + .filter { FileValidationUtils.isLLMIncludableFile(it) } + .filter { it.readText().contains(searchString, ignoreCase = true) } + .map { it.toPath() } + .toList() + }.toSet() + } - companion object { - fun getFiles( - files: Array? - ): MutableSet { - val codeFiles = mutableSetOf() // Set to avoid duplicates - files?.forEach { file -> - if (file.isDirectory) { - if (file.name.startsWith(".")) return@forEach - if (FileValidationUtils.isGitignore(file.toPath())) return@forEach - if (file.name.endsWith(".png")) return@forEach - if (file.length() > 1024 * 256) return@forEach - codeFiles.addAll(getFiles(file.listFiles())) - } else { - codeFiles.add((file.toPath())) - } - } - return codeFiles + companion object { + fun getFiles( + files: Array? + ): MutableSet { + val codeFiles = mutableSetOf() // Set to avoid duplicates + files?.forEach { file -> + if (file.isDirectory) { + if (file.name.startsWith(".")) return@forEach + if (FileValidationUtils.isGitignore(file.toPath())) return@forEach + if (file.name.endsWith(".png")) return@forEach + if (file.length() > 1024 * 256) return@forEach + codeFiles.addAll(getFiles(file.listFiles())) + } else { + codeFiles.add((file.toPath())) } + } + return codeFiles } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PatchApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PatchApp.kt index 43799165..03b41af1 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PatchApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PatchApp.kt @@ -24,178 +24,180 @@ import java.nio.file.Path import java.util.UUID abstract class PatchApp( - override val root: File, - val session: Session, - val settings: Settings, - val api: ChatClient, - val model: ChatModel, - val promptPrefix: String = """The following command was run and produced an error:""" + override val root: File, + val session: Session, + val settings: Settings, + val api: ChatClient, + val model: ChatModel, + val promptPrefix: String = """The following command was run and produced an error:""" ) : ApplicationServer( - applicationName = "Magic Code Fixer", - path = "/fixCmd", - showMenubar = false, + applicationName = "Magic Code Fixer", + path = "/fixCmd", + showMenubar = false, ) { - companion object { - private val log = LoggerFactory.getLogger(PatchApp::class.java) - const val tripleTilde = "`" + "``" // This is a workaround for the markdown parser when editing this file - } + companion object { + private val log = LoggerFactory.getLogger(PatchApp::class.java) + const val tripleTilde = "`" + "``" // This is a workaround for the markdown parser when editing this file + } - data class OutputResult(val exitCode: Int, val output: String) + data class OutputResult(val exitCode: Int, val output: String) - abstract fun codeFiles(): Set - abstract fun codeSummary(paths: List): String - abstract fun output(task: SessionTask): OutputResult - abstract fun searchFiles(searchStrings: List): Set - override val singleInput = true - override val stickyInput = false - override fun newSession(user: User?, session: Session): SocketManager { - val socketManager = super.newSession(user, session) - val ui = (socketManager as ApplicationSocketManager).applicationInterface - val task = ui.newTask() - Retryable( - ui = ui, - task = task, - process = { content -> - val newTask = ui.newTask(false) - newTask.add("Running Command") - Thread { - run(ui, newTask) - }.start() - newTask.placeholder - } - ) - return socketManager - } + abstract fun codeFiles(): Set + abstract fun codeSummary(paths: List): String + abstract fun output(task: SessionTask): OutputResult + abstract fun searchFiles(searchStrings: List): Set + override val singleInput = true + override val stickyInput = false + override fun newSession(user: User?, session: Session): SocketManager { + val socketManager = super.newSession(user, session) + val ui = (socketManager as ApplicationSocketManager).applicationInterface + val task = ui.newTask() + Retryable( + ui = ui, + task = task, + process = { content -> + val newTask = ui.newTask(false) + newTask.add("Running Command") + Thread { + run(ui, newTask) + }.start() + newTask.placeholder + } + ) + return socketManager + } - abstract fun projectSummary(): String + abstract fun projectSummary(): String - private fun prunePaths(paths: List, maxSize: Int): List { - val sortedPaths = paths.sortedByDescending { it.toFile().length() } - var totalSize = 0 - val prunedPaths = mutableListOf() - for (path in sortedPaths) { - val fileSize = path.toFile().length().toInt() - if (totalSize + fileSize > maxSize) break - prunedPaths.add(path) - totalSize += fileSize - } - return prunedPaths + private fun prunePaths(paths: List, maxSize: Int): List { + val sortedPaths = paths.sortedByDescending { it.toFile().length() } + var totalSize = 0 + val prunedPaths = mutableListOf() + for (path in sortedPaths) { + val fileSize = path.toFile().length().toInt() + if (totalSize + fileSize > maxSize) break + prunedPaths.add(path) + totalSize += fileSize } + return prunedPaths + } - data class ParsedErrors( - val errors: List? = null - ) + data class ParsedErrors( + val errors: List? = null + ) - data class ParsedError( - @Description("The error message") - val message: String? = null, - @Description("Files identified as needing modification and issue-related files") - val relatedFiles: List? = null, - @Description("Files identified as needing modification and issue-related files") - val fixFiles: List? = null, - @Description("Search strings to find relevant files") - val searchStrings: List? = null - ) + data class ParsedError( + @Description("The error message") + val message: String? = null, + @Description("Files identified as needing modification and issue-related files") + val relatedFiles: List? = null, + @Description("Files identified as needing modification and issue-related files") + val fixFiles: List? = null, + @Description("Search strings to find relevant files") + val searchStrings: List? = null + ) - data class Settings( - var executable: File, - var arguments: String = "", - var workingDirectory: File? = null, - var exitCodeOption: String = "nonzero", - var additionalInstructions: String = "", - val autoFix: Boolean, - ) + data class Settings( + var executable: File, + var arguments: String = "", + var workingDirectory: File? = null, + var exitCodeOption: String = "nonzero", + var additionalInstructions: String = "", + val autoFix: Boolean, + ) - fun run( - ui: ApplicationInterface, - task: SessionTask, - ): OutputResult { - val output = output(task) - if (output.exitCode == 0 && settings.exitCodeOption == "nonzero") { - task.complete( - """ + fun run( + ui: ApplicationInterface, + task: SessionTask, + ): OutputResult { + val output = output(task) + if (output.exitCode == 0 && settings.exitCodeOption == "nonzero") { + task.complete( + """ |
|
Command executed successfully
|${MarkdownUtil.renderMarkdown("${tripleTilde}\n${output.output}\n${tripleTilde}")} |
|""".trimMargin() - ) - return output - } - if (settings.exitCodeOption == "zero" && output.exitCode != 0) { - task.complete( - """ + ) + return output + } + if (settings.exitCodeOption == "zero" && output.exitCode != 0) { + task.complete( + """ |
|
Command failed
|${MarkdownUtil.renderMarkdown("${tripleTilde}\n${output.output}\n${tripleTilde}")} |
|""".trimMargin() - ) - return output - } - try { - task.add( - """ + ) + return output + } + try { + task.add( + """ |
|
Command exit code: ${output.exitCode}
|${MarkdownUtil.renderMarkdown("${tripleTilde}\n${output.output}\n${tripleTilde}")} |
""".trimMargin() - ) - fixAll(settings, output, task, ui, api) - } catch (e: Exception) { - task.error(ui, e) - } - return output + ) + fixAll(settings, output, task, ui, api) + } catch (e: Exception) { + task.error(ui, e) } + return output + } - private fun fixAll( - settings: Settings, - output: OutputResult, - task: SessionTask, - ui: ApplicationInterface, - api: ChatClient, - ) { - Retryable(ui, task) { content -> - fixAllInternal( - settings = settings, - output = output, - task = task, - ui = ui, - changed = mutableSetOf(), - api = api - ) - content.clear() - "" - } + private fun fixAll( + settings: Settings, + output: OutputResult, + task: SessionTask, + ui: ApplicationInterface, + api: ChatClient, + ) { + Retryable(ui, task) { content -> + fixAllInternal( + settings = settings, + output = output, + task = task, + ui = ui, + changed = mutableSetOf(), + api = api + ) + content.clear() + "" } + } - private fun fixAllInternal( - settings: Settings, - output: OutputResult, - task: SessionTask, - ui: ApplicationInterface, - changed: MutableSet, - api: ChatClient, - ) { - val api = api.getChildClient().apply { - val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - task.verbose("API log: $this") - } - } - val plan = ParsedActor( - resultClass = ParsedErrors::class.java, - exampleInstance = ParsedErrors(listOf( - ParsedError( - message = "Error message", - relatedFiles = listOf("src/main/java/com/example/Example.java"), - fixFiles = listOf("src/main/java/com/example/Example.java"), - searchStrings = listOf("def exampleFunction", "TODO") - ) - )), - prompt = """ + private fun fixAllInternal( + settings: Settings, + output: OutputResult, + task: SessionTask, + ui: ApplicationInterface, + changed: MutableSet, + api: ChatClient, + ) { + val api = api.getChildClient().apply { + val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + task.verbose("API log: $this") + } + } + val plan = ParsedActor( + resultClass = ParsedErrors::class.java, + exampleInstance = ParsedErrors( + listOf( + ParsedError( + message = "Error message", + relatedFiles = listOf("src/main/java/com/example/Example.java"), + fixFiles = listOf("src/main/java/com/example/Example.java"), + searchStrings = listOf("def exampleFunction", "TODO") + ) + ) + ), + prompt = """ |You are a helpful AI that helps people with coding. | |You will be answering questions about the following project: @@ -212,89 +214,89 @@ abstract class PatchApp( | 3) specify a search string to find relevant files - be as specific as possible |${if (settings.additionalInstructions.isNotBlank()) "Additional Instructions:\n ${settings.additionalInstructions}\n" else ""} """.trimMargin(), - model = model - ).answer( - listOf( - """ + model = model + ).answer( + listOf( + """ |$promptPrefix | |${tripleTilde} |${output.output} |${tripleTilde} """.trimMargin() - ), api = api - ) - task.add( - AgentPatterns.displayMapInTabs( - mapOf( - "Text" to MarkdownUtil.renderMarkdown(plan.text, ui = ui), - "JSON" to MarkdownUtil.renderMarkdown( - "${tripleTilde}json\n${JsonUtil.toJson(plan.obj)}\n${tripleTilde}", - ui = ui - ), - ) - ) + ), api = api + ) + task.add( + AgentPatterns.displayMapInTabs( + mapOf( + "Text" to MarkdownUtil.renderMarkdown(plan.text, ui = ui), + "JSON" to MarkdownUtil.renderMarkdown( + "${tripleTilde}json\n${JsonUtil.toJson(plan.obj)}\n${tripleTilde}", + ui = ui + ), ) - val progressHeader = task.header("Processing tasks") - plan.obj.errors?.forEach { error -> - task.header("Processing error: ${error.message}") - task.verbose(MarkdownUtil.renderMarkdown("```json\n${JsonUtil.toJson(error)}\n```", tabs = false, ui = ui)) - // Search for files using the provided search strings - val searchResults = error.searchStrings?.flatMap { searchString -> - FileValidationUtils.filteredWalk(settings.workingDirectory!!) { !FileValidationUtils.isGitignore(it.toPath()) } - .filter { FileValidationUtils.isLLMIncludableFile(it) } - .filter { it.readText().contains(searchString, ignoreCase = true) } - .map { it.toPath() } - .toList() - }?.toSet() ?: emptySet() - task.verbose( - MarkdownUtil.renderMarkdown( - """ + ) + ) + val progressHeader = task.header("Processing tasks") + plan.obj.errors?.forEach { error -> + task.header("Processing error: ${error.message}") + task.verbose(MarkdownUtil.renderMarkdown("```json\n${JsonUtil.toJson(error)}\n```", tabs = false, ui = ui)) + // Search for files using the provided search strings + val searchResults = error.searchStrings?.flatMap { searchString -> + FileValidationUtils.filteredWalk(settings.workingDirectory!!) { !FileValidationUtils.isGitignore(it.toPath()) } + .filter { FileValidationUtils.isLLMIncludableFile(it) } + .filter { it.readText().contains(searchString, ignoreCase = true) } + .map { it.toPath() } + .toList() + }?.toSet() ?: emptySet() + task.verbose( + MarkdownUtil.renderMarkdown( + """ |Search results: | |${searchResults.joinToString("\n") { "* `$it`" }} """.trimMargin(), tabs = false, ui = ui - ) - ) - Retryable(ui, task) { content -> - fix( - error, searchResults.toList().map { it.toFile().absolutePath }, - output, ui, content, settings.autoFix, changed, api - ) - content.toString() - } - } - progressHeader?.clear() - task.append("", false) + ) + ) + Retryable(ui, task) { content -> + fix( + error, searchResults.toList().map { it.toFile().absolutePath }, + output, ui, content, settings.autoFix, changed, api + ) + content.toString() + } } + progressHeader?.clear() + task.append("", false) + } - private fun fix( - error: ParsedError, - additionalFiles: List? = null, - output: OutputResult, - ui: ApplicationInterface, - content: StringBuilder, - autoFix: Boolean, - changed: MutableSet, - api: ChatClient, - ) { - val paths = - ( - (error.fixFiles ?: emptyList()) + - (error.relatedFiles ?: emptyList()) + - (additionalFiles ?: emptyList()) - ).map { - try { - File(it).toPath() - } catch (e: Throwable) { - log.warn("Error: root=${root} ", e) - null - } - }.filterNotNull() - val prunedPaths = prunePaths(paths, 50 * 1024) - val summary = codeSummary(prunedPaths) - val response = SimpleActor( - prompt = """ + private fun fix( + error: ParsedError, + additionalFiles: List? = null, + output: OutputResult, + ui: ApplicationInterface, + content: StringBuilder, + autoFix: Boolean, + changed: MutableSet, + api: ChatClient, + ) { + val paths = + ( + (error.fixFiles ?: emptyList()) + + (error.relatedFiles ?: emptyList()) + + (additionalFiles ?: emptyList()) + ).map { + try { + File(it).toPath() + } catch (e: Throwable) { + log.warn("Error: root=${root} ", e) + null + } + }.filterNotNull() + val prunedPaths = prunePaths(paths, 50 * 1024) + val summary = codeSummary(prunedPaths) + val response = SimpleActor( + prompt = """ |You are a helpful AI that helps people with coding. | |You will be answering questions about the following code: @@ -337,10 +339,10 @@ abstract class PatchApp( | |If needed, new files can be created by using code blocks labeled with the filename in the same manner. """.trimMargin(), - model = model - ).answer( - listOf( - """ + model = model + ).answer( + listOf( + """ |$promptPrefix | |${tripleTilde} @@ -351,26 +353,26 @@ abstract class PatchApp( | ${error.message?.replace("\n", "\n ") ?: ""} |${if (settings.additionalInstructions.isNotBlank()) "Additional Instructions:\n ${settings.additionalInstructions}\n" else ""} """.trimMargin() - ), api = api - ) - var markdown = ui.socketManager?.addApplyFileDiffLinks( - root = root.toPath(), - response = response, - ui = ui, - api = api, - shouldAutoApply = { path -> - if (autoFix && !changed.contains(path)) { - changed.add(path) - true - } else { - false - } - }, - model = model, - ) - content.clear() - content.append("
${MarkdownUtil.renderMarkdown(markdown!!)}
") - } + ), api = api + ) + var markdown = ui.socketManager?.addApplyFileDiffLinks( + root = root.toPath(), + response = response, + ui = ui, + api = api, + shouldAutoApply = { path -> + if (autoFix && !changed.contains(path)) { + changed.add(path) + true + } else { + false + } + }, + model = model, + ) + content.clear() + content.append("
${MarkdownUtil.renderMarkdown(markdown!!)}
") + } } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PlanAheadApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PlanAheadApp.kt index ab725922..07dbfe64 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PlanAheadApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PlanAheadApp.kt @@ -21,93 +21,93 @@ import org.slf4j.LoggerFactory import java.io.File open class PlanAheadApp( - applicationName: String = "Task Planning v1.1", - path: String = "/taskDev", - val planSettings: PlanSettings, - val model: ChatModel, - val parsingModel: ChatModel, - val domainName: String = "localhost", - showMenubar: Boolean = true, - val initialPlan: TaskBreakdownWithPrompt? = null, - val api: API? = null, - val api2: OpenAIClient, + applicationName: String = "Task Planning v1.1", + path: String = "/taskDev", + val planSettings: PlanSettings, + val model: ChatModel, + val parsingModel: ChatModel, + val domainName: String = "localhost", + showMenubar: Boolean = true, + val initialPlan: TaskBreakdownWithPrompt? = null, + val api: API? = null, + val api2: OpenAIClient, ) : ApplicationServer( - applicationName = applicationName, - path = path, - showMenubar = showMenubar, - root = planSettings.workingDir?.let { File(it) } ?: dataStorageRoot, + applicationName = applicationName, + path = path, + showMenubar = showMenubar, + root = planSettings.workingDir?.let { File(it) } ?: dataStorageRoot, ) { - override val singleInput = true + override val singleInput = true - @Suppress("UNCHECKED_CAST") - override fun initSettings(session: Session): T = planSettings.let { - if (null == root) it.copy(workingDir = root.absolutePath) else - it - } as T + @Suppress("UNCHECKED_CAST") + override fun initSettings(session: Session): T = planSettings.let { + if (null == root) it.copy(workingDir = root.absolutePath) else + it + } as T - override fun newSession(user: User?, session: Session): SocketManager { - val socketManager = super.newSession(user, session) - val ui = (socketManager as ApplicationSocketManager).applicationInterface - if (initialPlan != null) { - socketManager.pool.submit { - try { - val planSettings = getSettings(session, user, PlanSettings::class.java) - if (api is ChatClient) api.budget = planSettings?.budget - val coordinator = PlanCoordinator( - user = user, - session = session, - dataStorage = dataStorage, - ui = ui, - root = planSettings?.workingDir?.let { File(it).toPath() } ?: dataStorage.getDataDir(user, session).toPath(), - planSettings = planSettings!! - ) - coordinator.executeTaskBreakdownWithPrompt(JsonUtil.toJson(initialPlan), api!!, api2, ui.newTask()) - } catch (e: Throwable) { - ui.newTask().error(ui, e) - log.warn("Error", e) - } - } - } - return socketManager - } - - override fun userMessage( - session: Session, - user: User?, - userMessage: String, - ui: ApplicationInterface, - api: API - ) { + override fun newSession(user: User?, session: Session): SocketManager { + val socketManager = super.newSession(user, session) + val ui = (socketManager as ApplicationSocketManager).applicationInterface + if (initialPlan != null) { + socketManager.pool.submit { try { - val planSettings = getSettings(session, user, PlanSettings::class.java) - if (api is ChatClient) api.budget = planSettings?.budget ?: 2.0 - val coordinator = PlanCoordinator( - user = user, - session = session, - dataStorage = dataStorage, - ui = ui, - root = planSettings?.workingDir?.let { File(it).toPath() } ?: dataStorage.getDataDir(user, session).toPath(), - planSettings = planSettings!! - ) - val task = ui.newTask() - val plan = initialPlan( - codeFiles = coordinator.codeFiles, - files = coordinator.files, - root = coordinator.root, - task = task, - userMessage = userMessage, - ui = coordinator.ui, - planSettings = coordinator.planSettings, - api = api - ) - coordinator.executePlan(plan.plan, task, userMessage = userMessage, api = api, api2 = api2) + val planSettings = getSettings(session, user, PlanSettings::class.java) + if (api is ChatClient) api.budget = planSettings?.budget + val coordinator = PlanCoordinator( + user = user, + session = session, + dataStorage = dataStorage, + ui = ui, + root = planSettings?.workingDir?.let { File(it).toPath() } ?: dataStorage.getDataDir(user, session).toPath(), + planSettings = planSettings!! + ) + coordinator.executeTaskBreakdownWithPrompt(JsonUtil.toJson(initialPlan), api!!, api2, ui.newTask()) } catch (e: Throwable) { - ui.newTask().error(ui, e) - log.warn("Error", e) + ui.newTask().error(ui, e) + log.warn("Error", e) } + } } + return socketManager + } - companion object { - private val log = LoggerFactory.getLogger(PlanAheadApp::class.java) + override fun userMessage( + session: Session, + user: User?, + userMessage: String, + ui: ApplicationInterface, + api: API + ) { + try { + val planSettings = getSettings(session, user, PlanSettings::class.java) + if (api is ChatClient) api.budget = planSettings?.budget ?: 2.0 + val coordinator = PlanCoordinator( + user = user, + session = session, + dataStorage = dataStorage, + ui = ui, + root = planSettings?.workingDir?.let { File(it).toPath() } ?: dataStorage.getDataDir(user, session).toPath(), + planSettings = planSettings!! + ) + val task = ui.newTask() + val plan = initialPlan( + codeFiles = coordinator.codeFiles, + files = coordinator.files, + root = coordinator.root, + task = task, + userMessage = userMessage, + ui = coordinator.ui, + planSettings = coordinator.planSettings, + api = api + ) + coordinator.executePlan(plan.plan, task, userMessage = userMessage, api = api, api2 = api2) + } catch (e: Throwable) { + ui.newTask().error(ui, e) + log.warn("Error", e) } + } + + companion object { + private val log = LoggerFactory.getLogger(PlanAheadApp::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PlanChatApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PlanChatApp.kt index a6649382..c8fd2c16 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PlanChatApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/PlanChatApp.kt @@ -16,146 +16,146 @@ import java.io.File import java.util.* open class PlanChatApp( - applicationName: String = "Task Planning Chat v1.0", - path: String = "/taskChat", - planSettings: PlanSettings, - model: ChatModel, - parsingModel: ChatModel, - domainName: String = "localhost", - showMenubar: Boolean = true, - initialPlan: TaskBreakdownWithPrompt? = null, - api: API? = null, - api2: OpenAIClient, + applicationName: String = "Task Planning Chat v1.0", + path: String = "/taskChat", + planSettings: PlanSettings, + model: ChatModel, + parsingModel: ChatModel, + domainName: String = "localhost", + showMenubar: Boolean = true, + initialPlan: TaskBreakdownWithPrompt? = null, + api: API? = null, + api2: OpenAIClient, ) : PlanAheadApp( - applicationName = applicationName, - path = path, - planSettings = planSettings, - model = model, - parsingModel = parsingModel, - domainName = domainName, - showMenubar = showMenubar, - initialPlan = initialPlan, - api = api, - api2 = api2, + applicationName = applicationName, + path = path, + planSettings = planSettings, + model = model, + parsingModel = parsingModel, + domainName = domainName, + showMenubar = showMenubar, + initialPlan = initialPlan, + api = api, + api2 = api2, ) { - override val stickyInput = true - override val singleInput = false + override val stickyInput = true + override val singleInput = false - private val sessionHandlers: MutableMap = mutableMapOf() - override fun userMessage( - session: Session, - user: User?, - userMessage: String, - ui: ApplicationInterface, - api: API - ) { - ui.socketManager?.pool!!.submit { - val sessionHandler = sessionHandlers.getOrPut(session.sessionId) { - ChatSessionHandler( - ui = ui, - session = session, - user = user, - api = api, - api2 = api2, - ) - } - sessionHandler.handleUserMessage( - userMessage = userMessage, - ) - } + private val sessionHandlers: MutableMap = mutableMapOf() + override fun userMessage( + session: Session, + user: User?, + userMessage: String, + ui: ApplicationInterface, + api: API + ) { + ui.socketManager?.pool!!.submit { + val sessionHandler = sessionHandlers.getOrPut(session.sessionId) { + ChatSessionHandler( + ui = ui, + session = session, + user = user, + api = api, + api2 = api2, + ) + } + sessionHandler.handleUserMessage( + userMessage = userMessage, + ) } + } - private inner class ChatSessionHandler( - val ui: ApplicationInterface, - val session: Session, - val user: User?, - val api: API, - val api2: OpenAIClient, - ) { - val messageHistory: MutableList = mutableListOf() + private inner class ChatSessionHandler( + val ui: ApplicationInterface, + val session: Session, + val user: User?, + val api: API, + val api2: OpenAIClient, + ) { + val messageHistory: MutableList = mutableListOf() - fun handleUserMessage(userMessage: String) { - try { - messageHistory.add(userMessage) - val planSettings = (getSettings(session, user, PlanSettings::class.java) ?: PlanSettings( - defaultModel = model, - parsingModel = parsingModel, - command = planSettings.command, - temperature = planSettings.temperature, - workingDir = planSettings.workingDir, - env = planSettings.env, - githubToken = planSettings.githubToken, - googleApiKey = planSettings.googleApiKey, - googleSearchEngineId = planSettings.googleSearchEngineId, - )).copy( - allowBlocking = false, - ) - if (api is ChatClient) api.budget = planSettings.budget - val coordinator = PlanCoordinator( - user = user, - session = session, - dataStorage = dataStorage, - ui = ui, - root = planSettings?.workingDir?.let { File(it).toPath() } ?: dataStorage.getDataDir(user, session).toPath(), - planSettings = planSettings - ) - val mainTask = ui.newTask() - val sessionTask = ui.newTask(false).apply { mainTask.verbose(placeholder) } - val api = (api as ChatClient).getChildClient().apply { - val createFile = sessionTask.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - sessionTask.verbose("API log: $this") - } - } - val plan = PlanCoordinator.initialPlan( - codeFiles = coordinator.codeFiles, - files = coordinator.files, - root = dataStorage.getDataDir(user, session).toPath(), - task = sessionTask, - userMessage = userMessage, - ui = coordinator.ui, - planSettings = coordinator.planSettings, - api = api - ) - val modifiedPlan = addRespondToChatTask(plan.plan) - val planProcessingState = coordinator.executePlan( - plan = modifiedPlan, - task = sessionTask, - userMessage = userMessage, - api = api, - api2 = api2, - ) - val response = planProcessingState.taskResult["respond_to_chat"] as? String - if (response != null) { - mainTask.add(MarkdownUtil.renderMarkdown(response, ui = ui)) - messageHistory.add(response) - } else { - mainTask.add("Sorry, I couldn't generate a response.") - messageHistory.add("Sorry, I couldn't generate a response.") - } - mainTask.complete() - } catch (e: Throwable) { - ui.newTask().error(ui, e) - log.warn("Error", e) - } + fun handleUserMessage(userMessage: String) { + try { + messageHistory.add(userMessage) + val planSettings = (getSettings(session, user, PlanSettings::class.java) ?: PlanSettings( + defaultModel = model, + parsingModel = parsingModel, + command = planSettings.command, + temperature = planSettings.temperature, + workingDir = planSettings.workingDir, + env = planSettings.env, + githubToken = planSettings.githubToken, + googleApiKey = planSettings.googleApiKey, + googleSearchEngineId = planSettings.googleSearchEngineId, + )).copy( + allowBlocking = false, + ) + if (api is ChatClient) api.budget = planSettings.budget + val coordinator = PlanCoordinator( + user = user, + session = session, + dataStorage = dataStorage, + ui = ui, + root = planSettings?.workingDir?.let { File(it).toPath() } ?: dataStorage.getDataDir(user, session).toPath(), + planSettings = planSettings + ) + val mainTask = ui.newTask() + val sessionTask = ui.newTask(false).apply { mainTask.verbose(placeholder) } + val api = (api as ChatClient).getChildClient().apply { + val createFile = sessionTask.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + sessionTask.verbose("API log: $this") + } } - + val plan = PlanCoordinator.initialPlan( + codeFiles = coordinator.codeFiles, + files = coordinator.files, + root = dataStorage.getDataDir(user, session).toPath(), + task = sessionTask, + userMessage = userMessage, + ui = coordinator.ui, + planSettings = coordinator.planSettings, + api = api + ) + val modifiedPlan = addRespondToChatTask(plan.plan) + val planProcessingState = coordinator.executePlan( + plan = modifiedPlan, + task = sessionTask, + userMessage = userMessage, + api = api, + api2 = api2, + ) + val response = planProcessingState.taskResult["respond_to_chat"] as? String + if (response != null) { + mainTask.add(MarkdownUtil.renderMarkdown(response, ui = ui)) + messageHistory.add(response) + } else { + mainTask.add("Sorry, I couldn't generate a response.") + messageHistory.add("Sorry, I couldn't generate a response.") + } + mainTask.complete() + } catch (e: Throwable) { + ui.newTask().error(ui, e) + log.warn("Error", e) + } } - protected open fun addRespondToChatTask(plan: Map): Map { - val tasksByID = plan?.toMutableMap() ?: mutableMapOf() - val respondTaskId = "respond_to_chat" + } - tasksByID[respondTaskId] = InquiryTaskData( - task_description = "Respond to the user's chat message based on the executed plan", - task_dependencies = tasksByID.keys.toList() - ) + protected open fun addRespondToChatTask(plan: Map): Map { + val tasksByID = plan?.toMutableMap() ?: mutableMapOf() + val respondTaskId = "respond_to_chat" - return tasksByID - } + tasksByID[respondTaskId] = InquiryTaskData( + task_description = "Respond to the user's chat message based on the executed plan", + task_dependencies = tasksByID.keys.toList() + ) - companion object { - private val log = LoggerFactory.getLogger(PlanChatApp::class.java) - } + return tasksByID + } + + companion object { + private val log = LoggerFactory.getLogger(PlanChatApp::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/StressTestApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/StressTestApp.kt index d9641799..b5ae9ff9 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/StressTestApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/StressTestApp.kt @@ -12,59 +12,59 @@ import org.slf4j.LoggerFactory import kotlin.random.Random class StressTestApp( - applicationName: String = "UI Stress Test", - path: String = "/stressTest", + applicationName: String = "UI Stress Test", + path: String = "/stressTest", ) : ApplicationServer( - applicationName = applicationName, - path = path, - showMenubar = true + applicationName = applicationName, + path = path, + showMenubar = true ) { - override fun userMessage( - session: Session, - user: User?, - userMessage: String, - ui: ApplicationInterface, - api: API - ) { - val task = ui.newTask() - task.add(MarkdownUtil.renderMarkdown("# UI Stress Test", ui = ui)) + override fun userMessage( + session: Session, + user: User?, + userMessage: String, + ui: ApplicationInterface, + api: API + ) { + val task = ui.newTask() + task.add(MarkdownUtil.renderMarkdown("# UI Stress Test", ui = ui)) - // Create nested tabs - createNestedTabs(task, ui, 3) + // Create nested tabs + createNestedTabs(task, ui, 3) - } + } - private fun createNestedTabs(task: SessionTask, ui: ApplicationInterface, depth: Int) { - if (depth <= 0) { - // Create a complex diagram - createComplexDiagram(task, ui) + private fun createNestedTabs(task: SessionTask, ui: ApplicationInterface, depth: Int) { + if (depth <= 0) { + // Create a complex diagram + createComplexDiagram(task, ui) - // Create multiple placeholders and update them - createAndUpdatePlaceholders(task, ui) - return - } + // Create multiple placeholders and update them + createAndUpdatePlaceholders(task, ui) + return + } - val tabDisplay = object : TabbedDisplay(task) { - override fun renderTabButtons(): String { - return buildString { - append("
\n") - (1..3).forEach { i -> - append("\n") - } - append("
") - } - } + val tabDisplay = object : TabbedDisplay(task) { + override fun renderTabButtons(): String { + return buildString { + append("
\n") + (1..3).forEach { i -> + append("\n") + } + append("
") } + } + } - (1..3).forEach { i -> - val subTask = ui.newTask(false) - tabDisplay["Tab $i"] = subTask.placeholder - createNestedTabs(subTask, ui, depth - 1) - } + (1..3).forEach { i -> + val subTask = ui.newTask(false) + tabDisplay["Tab $i"] = subTask.placeholder + createNestedTabs(subTask, ui, depth - 1) } + } - private fun createComplexDiagram(task: SessionTask, ui: ApplicationInterface) { - val mermaidDiagram = """ + private fun createComplexDiagram(task: SessionTask, ui: ApplicationInterface) { + val mermaidDiagram = """ ```mermaid graph TD A[Start] --> B{Is it?} @@ -75,27 +75,27 @@ class StressTestApp( ``` """.trimIndent() - task.add(MarkdownUtil.renderMarkdown("## Complex Diagram\n$mermaidDiagram", ui = ui)) - } - - private fun createAndUpdatePlaceholders(task: SessionTask, ui: ApplicationInterface) { - val placeholders = (1..5).map { ui.newTask(false) } + task.add(MarkdownUtil.renderMarkdown("## Complex Diagram\n$mermaidDiagram", ui = ui)) + } - placeholders.forEach { placeholder -> - task.add(placeholder.placeholder) - } + private fun createAndUpdatePlaceholders(task: SessionTask, ui: ApplicationInterface) { + val placeholders = (1..5).map { ui.newTask(false) } - repeat(10) { iteration -> - placeholders.forEach { placeholder -> - val content = "Placeholder content: Iteration $iteration, Random: ${Random.nextInt(100)}" - placeholder.add(MarkdownUtil.renderMarkdown(content, ui = ui)) - //Thread.sleep(50) - } - } - placeholders.forEach { it.complete() } + placeholders.forEach { placeholder -> + task.add(placeholder.placeholder) } - companion object { - private val log = LoggerFactory.getLogger(StressTestApp::class.java) + repeat(10) { iteration -> + placeholders.forEach { placeholder -> + val content = "Placeholder content: Iteration $iteration, Random: ${Random.nextInt(100)}" + placeholder.add(MarkdownUtil.renderMarkdown(content, ui = ui)) + //Thread.sleep(50) + } } + placeholders.forEach { it.complete() } + } + + companion object { + private val log = LoggerFactory.getLogger(StressTestApp::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/WebDevApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/WebDevApp.kt index cfe33195..5c582fe2 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/WebDevApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/general/WebDevApp.kt @@ -33,64 +33,64 @@ import javax.imageio.ImageIO import kotlin.io.path.name open class WebDevApp( - applicationName: String = "Web Dev Assistant v1.2", - open val symbols: Map = mapOf(), - val temperature: Double = 0.1, + applicationName: String = "Web Dev Assistant v1.2", + open val symbols: Map = mapOf(), + val temperature: Double = 0.1, ) : ApplicationServer( - applicationName = applicationName, - path = "/webdev", + applicationName = applicationName, + path = "/webdev", ) { - override fun userMessage( - session: Session, - user: User?, - userMessage: String, - ui: ApplicationInterface, - api: API - ) { - val settings = getSettings(session, user) ?: Settings() - (api as ChatClient).budget = settings.budget ?: 2.00 - WebDevAgent( - api = api, - dataStorage = dataStorage, - session = session, - user = user, - ui = ui, - tools = settings.tools, - model = settings.model, - parsingModel = settings.parsingModel, - root = root, - ).start( - userMessage = userMessage, - ) - } - - data class Settings( - val budget: Double? = 2.00, - val tools: List = emptyList(), - val model: ChatModel = OpenAIModels.GPT4o, - val parsingModel: ChatModel = OpenAIModels.GPT4oMini, + override fun userMessage( + session: Session, + user: User?, + userMessage: String, + ui: ApplicationInterface, + api: API + ) { + val settings = getSettings(session, user) ?: Settings() + (api as ChatClient).budget = settings.budget ?: 2.00 + WebDevAgent( + api = api, + dataStorage = dataStorage, + session = session, + user = user, + ui = ui, + tools = settings.tools, + model = settings.model, + parsingModel = settings.parsingModel, + root = root, + ).start( + userMessage = userMessage, ) + } - override val settingsClass: Class<*> get() = Settings::class.java + data class Settings( + val budget: Double? = 2.00, + val tools: List = emptyList(), + val model: ChatModel = OpenAIModels.GPT4o, + val parsingModel: ChatModel = OpenAIModels.GPT4oMini, + ) - @Suppress("UNCHECKED_CAST") - override fun initSettings(session: Session): T? = Settings() as T + override val settingsClass: Class<*> get() = Settings::class.java + + @Suppress("UNCHECKED_CAST") + override fun initSettings(session: Session): T? = Settings() as T } class WebDevAgent( - val api: API, - dataStorage: StorageInterface, - session: Session, - user: User?, - val ui: ApplicationInterface, - val model: ChatModel, - val parsingModel: ChatModel, - val tools: List = emptyList(), - @Language("Markdown") val actorMap: Map> = mapOf( - ActorTypes.ArchitectureDiscussionActor to ParsedActor( + val api: API, + dataStorage: StorageInterface, + session: Session, + user: User?, + val ui: ApplicationInterface, + val model: ChatModel, + val parsingModel: ChatModel, + val tools: List = emptyList(), + @Language("Markdown") val actorMap: Map> = mapOf( + ActorTypes.ArchitectureDiscussionActor to ParsedActor( // parserClass = PageResourceListParser::class.java, - resultClass = ProjectSpec::class.java, - prompt = """ + resultClass = ProjectSpec::class.java, + prompt = """ |Translate the user's idea into a detailed architecture for a simple web application. | | List all html, css, javascript, and image files to be created, and for each file: @@ -101,11 +101,11 @@ class WebDevAgent( |Specify user interactions and how the application will respond to them. |Identify key HTML classes and element IDs that will be used to bind the application to the HTML. """.trimMargin(), - model = model, - parsingModel = parsingModel, - ), - ActorTypes.CodeReviewer to SimpleActor( - prompt = """ + model = model, + parsingModel = parsingModel, + ), + ActorTypes.CodeReviewer to SimpleActor( + prompt = """ |Analyze the code summarized in the user's header-labeled code blocks. |Review, look for bugs, and provide fixes. |Provide implementations for missing functions. @@ -143,447 +143,447 @@ class WebDevAgent( | }); |``` """.trimMargin(), - model = model, - ), - ActorTypes.HtmlCodingActor to SimpleActor( - prompt = """ + model = model, + ), + ActorTypes.HtmlCodingActor to SimpleActor( + prompt = """ |You will translate the user request into a skeleton HTML file for a rich javascript application. |The html file can reference needed CSS and JS files, which are will be located in the same directory as the html file. |Do not output the content of the resource files, only the html file. """.trimMargin(), model = model - ), - ActorTypes.JavascriptCodingActor to SimpleActor( - prompt = """ + ), + ActorTypes.JavascriptCodingActor to SimpleActor( + prompt = """ |You will translate the user request into a javascript file for use in a rich javascript application. """.trimMargin(), model = model - ), - ActorTypes.CssCodingActor to SimpleActor( - prompt = """ + ), + ActorTypes.CssCodingActor to SimpleActor( + prompt = """ |You will translate the user request into a CSS file for use in a rich javascript application. """.trimMargin(), model = model - ), - ActorTypes.EtcCodingActor to SimpleActor( - prompt = """ + ), + ActorTypes.EtcCodingActor to SimpleActor( + prompt = """ |You will translate the user request into a file for use in a web application. """.trimMargin(), - model = model, - ), - ActorTypes.ImageActor to ImageActor( - prompt = """ + model = model, + ), + ActorTypes.ImageActor to ImageActor( + prompt = """ |You will translate the user request into an image file for use in a web application. """.trimMargin(), - textModel = model, - imageModel = ImageModels.DallE3, - ), + textModel = model, + imageModel = ImageModels.DallE3, ), - val root: File, + ), + val root: File, ) : ActorSystem(actorMap.map { it.key.name to it.value }.toMap(), dataStorage, user, session) { - enum class ActorTypes { - HtmlCodingActor, - JavascriptCodingActor, - CssCodingActor, - ArchitectureDiscussionActor, - CodeReviewer, - EtcCodingActor, - ImageActor, - } + enum class ActorTypes { + HtmlCodingActor, + JavascriptCodingActor, + CssCodingActor, + ArchitectureDiscussionActor, + CodeReviewer, + EtcCodingActor, + ImageActor, + } - private val architectureDiscussionActor by lazy { getActor(ActorTypes.ArchitectureDiscussionActor) as ParsedActor } - private val htmlActor by lazy { getActor(ActorTypes.HtmlCodingActor) as SimpleActor } - private val imageActor by lazy { getActor(ActorTypes.ImageActor) as ImageActor } - private val javascriptActor by lazy { getActor(ActorTypes.JavascriptCodingActor) as SimpleActor } - private val cssActor by lazy { getActor(ActorTypes.CssCodingActor) as SimpleActor } - private val codeReviewer by lazy { getActor(ActorTypes.CodeReviewer) as SimpleActor } - private val etcActor by lazy { getActor(ActorTypes.EtcCodingActor) as SimpleActor } + private val architectureDiscussionActor by lazy { getActor(ActorTypes.ArchitectureDiscussionActor) as ParsedActor } + private val htmlActor by lazy { getActor(ActorTypes.HtmlCodingActor) as SimpleActor } + private val imageActor by lazy { getActor(ActorTypes.ImageActor) as ImageActor } + private val javascriptActor by lazy { getActor(ActorTypes.JavascriptCodingActor) as SimpleActor } + private val cssActor by lazy { getActor(ActorTypes.CssCodingActor) as SimpleActor } + private val codeReviewer by lazy { getActor(ActorTypes.CodeReviewer) as SimpleActor } + private val etcActor by lazy { getActor(ActorTypes.EtcCodingActor) as SimpleActor } - private val codeFiles = mutableSetOf() + private val codeFiles = mutableSetOf() - fun start( - userMessage: String, - ) { - val task = ui.newTask() - val toInput = { it: String -> listOf(it) } - val architectureResponse = Discussable( - task = task, - userMessage = { userMessage }, - initialResponse = { it: String -> architectureDiscussionActor.answer(toInput(it), api = api) }, - outputFn = { design: ParsedResponse -> - // renderMarkdown("${design.text}\n\n```json\n${JsonUtil.toJson(design.obj)/*.indent(" ")*/}\n```") - AgentPatterns.displayMapInTabs( - mapOf( - "Text" to renderMarkdown(design.text, ui = ui), - "JSON" to renderMarkdown( - "```json\n${JsonUtil.toJson(design.obj)/*.indent(" ")*/}\n```", - ui = ui - ), - ) - ) - }, - ui = ui, - reviseResponse = { userMessages: List> -> - architectureDiscussionActor.respond( - messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } - .toTypedArray()), - input = toInput(userMessage), - api = api - ) - }, - atomicRef = AtomicReference(), - semaphore = Semaphore(0), - heading = userMessage - ).call() + fun start( + userMessage: String, + ) { + val task = ui.newTask() + val toInput = { it: String -> listOf(it) } + val architectureResponse = Discussable( + task = task, + userMessage = { userMessage }, + initialResponse = { it: String -> architectureDiscussionActor.answer(toInput(it), api = api) }, + outputFn = { design: ParsedResponse -> + // renderMarkdown("${design.text}\n\n```json\n${JsonUtil.toJson(design.obj)/*.indent(" ")*/}\n```") + AgentPatterns.displayMapInTabs( + mapOf( + "Text" to renderMarkdown(design.text, ui = ui), + "JSON" to renderMarkdown( + "```json\n${JsonUtil.toJson(design.obj)/*.indent(" ")*/}\n```", + ui = ui + ), + ) + ) + }, + ui = ui, + reviseResponse = { userMessages: List> -> + architectureDiscussionActor.respond( + messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } + .toTypedArray()), + input = toInput(userMessage), + api = api + ) + }, + atomicRef = AtomicReference(), + semaphore = Semaphore(0), + heading = userMessage + ).call() - try { + try { // val toolSpecs = tools.map { ToolServlet.tools.find { t -> t.path == it } } // .joinToString("\n\n") { it?.let { JsonUtil.toJson(it.openApiDescription) } ?: "" } // var messageWithTools = userMessage // if (toolSpecs.isNotBlank()) messageWithTools += "\n\nThese services are available:\n$toolSpecs" - task.echo( - renderMarkdown( - "```json\n${JsonUtil.toJson(architectureResponse.obj)/*.indent(" ")*/}\n```", - ui = ui - ) - ) - val fileTabs = TabbedDisplay(task) - architectureResponse.obj.files.filter { - !it.name!!.startsWith("http") - }.map { (path, description) -> - val task = ui.newTask(false).apply { fileTabs[path.toString()] = placeholder } - task.header("Drafting $path") - codeFiles.add(File(path).toPath()) - pool.submit { - when (path!!.split(".").last().lowercase()) { + task.echo( + renderMarkdown( + "```json\n${JsonUtil.toJson(architectureResponse.obj)/*.indent(" ")*/}\n```", + ui = ui + ) + ) + val fileTabs = TabbedDisplay(task) + architectureResponse.obj.files.filter { + !it.name!!.startsWith("http") + }.map { (path, description) -> + val task = ui.newTask(false).apply { fileTabs[path.toString()] = placeholder } + task.header("Drafting $path") + codeFiles.add(File(path).toPath()) + pool.submit { + when (path!!.split(".").last().lowercase()) { - "js" -> draftResourceCode( - task = task, - request = javascriptActor.chatMessages( - listOf( + "js" -> draftResourceCode( + task = task, + request = javascriptActor.chatMessages( + listOf( // messageWithTools, - architectureResponse.text, - "Render $path - $description" - ) - ), - actor = javascriptActor, - path = File(path).toPath(), "js", "javascript" - ) + architectureResponse.text, + "Render $path - $description" + ) + ), + actor = javascriptActor, + path = File(path).toPath(), "js", "javascript" + ) - "css" -> draftResourceCode( - task = task, - request = cssActor.chatMessages( - listOf( + "css" -> draftResourceCode( + task = task, + request = cssActor.chatMessages( + listOf( // messageWithTools, - architectureResponse.text, - "Render $path - $description" - ) - ), - actor = cssActor, - path = File(path).toPath() - ) + architectureResponse.text, + "Render $path - $description" + ) + ), + actor = cssActor, + path = File(path).toPath() + ) - "html" -> draftResourceCode( - task = task, - request = htmlActor.chatMessages( - listOf( + "html" -> draftResourceCode( + task = task, + request = htmlActor.chatMessages( + listOf( // messageWithTools, - architectureResponse.text, - "Render $path - $description" - ) - ), - actor = htmlActor, - path = File(path).toPath() - ) + architectureResponse.text, + "Render $path - $description" + ) + ), + actor = htmlActor, + path = File(path).toPath() + ) - "png" -> draftImage( - task = task, - request = etcActor.chatMessages( - listOf( + "png" -> draftImage( + task = task, + request = etcActor.chatMessages( + listOf( // messageWithTools, - architectureResponse.text, - "Render $path - $description" - ) - ), - actor = imageActor, - path = File(path).toPath() - ) + architectureResponse.text, + "Render $path - $description" + ) + ), + actor = imageActor, + path = File(path).toPath() + ) - "jpg" -> draftImage( - task = task, - request = etcActor.chatMessages( - listOf( + "jpg" -> draftImage( + task = task, + request = etcActor.chatMessages( + listOf( // messageWithTools, - architectureResponse.text, - "Render $path - $description" - ) - ), - actor = imageActor, - path = File(path).toPath() - ) + architectureResponse.text, + "Render $path - $description" + ) + ), + actor = imageActor, + path = File(path).toPath() + ) - else -> draftResourceCode( - task = task, - request = etcActor.chatMessages( - listOf( + else -> draftResourceCode( + task = task, + request = etcActor.chatMessages( + listOf( // messageWithTools, - architectureResponse.text, - "Render $path - $description" - ) - ), - actor = etcActor, - path = File(path).toPath() - ) - - } - } - }.toTypedArray().forEach { it.get() } - // Apply codeReviewer + architectureResponse.text, + "Render $path - $description" + ) + ), + actor = etcActor, + path = File(path).toPath() + ) - iterateCode(task) - } catch (e: Throwable) { - log.warn("Error", e) - task.error(ui, e) + } } - } + }.toTypedArray().forEach { it.get() } + // Apply codeReviewer - fun codeSummary() = codeFiles.filter { - if (it.name.lowercase().endsWith(".png")) return@filter false - if (it.name.lowercase().endsWith(".jpg")) return@filter false - true - }.joinToString("\n\n") { path -> - "# $path\n```${path.toString().split('.').last()}\n${root.resolve(path.toFile()).readText()}\n```" + iterateCode(task) + } catch (e: Throwable) { + log.warn("Error", e) + task.error(ui, e) } + } - private fun iterateCode( - task: SessionTask - ) { - Discussable( - task = task, - heading = "Code Refinement", - userMessage = { codeSummary() }, - initialResponse = { - codeReviewer.answer(listOf(it), api = api) - }, - outputFn = { code -> - renderMarkdown( - ui.socketManager!!.addApplyFileDiffLinks( - root = root.toPath(), - response = code, - handle = { newCodeMap -> - newCodeMap.forEach { (path, newCode) -> - task.complete("$path Updated") - } - }, - ui = ui, - api = api, - model = model, - ) - ) + fun codeSummary() = codeFiles.filter { + if (it.name.lowercase().endsWith(".png")) return@filter false + if (it.name.lowercase().endsWith(".jpg")) return@filter false + true + }.joinToString("\n\n") { path -> + "# $path\n```${path.toString().split('.').last()}\n${root.resolve(path.toFile()).readText()}\n```" + } + + private fun iterateCode( + task: SessionTask + ) { + Discussable( + task = task, + heading = "Code Refinement", + userMessage = { codeSummary() }, + initialResponse = { + codeReviewer.answer(listOf(it), api = api) + }, + outputFn = { code -> + renderMarkdown( + ui.socketManager!!.addApplyFileDiffLinks( + root = root.toPath(), + response = code, + handle = { newCodeMap -> + newCodeMap.forEach { (path, newCode) -> + task.complete("$path Updated") + } }, ui = ui, - reviseResponse = { userMessages -> - val userMessages = userMessages.toMutableList() - userMessages.set(0, userMessages.get(0).copy(first = codeSummary())) - val combinedMessages = - userMessages.map { ApiModel.ChatMessage(Role.user, it.first.toContentList()) } - codeReviewer.respond( - input = listOf(element = combinedMessages.joinToString("\n")), - api = api, - messages = combinedMessages.toTypedArray(), - ) - }, - ).call() - } + api = api, + model = model, + ) + ) + }, + ui = ui, + reviseResponse = { userMessages -> + val userMessages = userMessages.toMutableList() + userMessages.set(0, userMessages.get(0).copy(first = codeSummary())) + val combinedMessages = + userMessages.map { ApiModel.ChatMessage(Role.user, it.first.toContentList()) } + codeReviewer.respond( + input = listOf(element = combinedMessages.joinToString("\n")), + api = api, + messages = combinedMessages.toTypedArray(), + ) + }, + ).call() + } - private fun draftImage( - task: SessionTask, - request: Array, - actor: ImageActor, - path: Path, - ) { - try { - var code = Discussable( - task = task, - userMessage = { "" }, - heading = "Drafting $path", - initialResponse = { - val messages = (request + ApiModel.ChatMessage(Role.user, "Draft $path".toContentList())) - .toList().toTypedArray() - actor.respond( - listOf(request.joinToString("\n") { it.content?.joinToString() ?: "" }), - api, - *messages - ) + private fun draftImage( + task: SessionTask, + request: Array, + actor: ImageActor, + path: Path, + ) { + try { + var code = Discussable( + task = task, + userMessage = { "" }, + heading = "Drafting $path", + initialResponse = { + val messages = (request + ApiModel.ChatMessage(Role.user, "Draft $path".toContentList())) + .toList().toTypedArray() + actor.respond( + listOf(request.joinToString("\n") { it.content?.joinToString() ?: "" }), + api, + *messages + ) - }, - outputFn = { img -> - renderMarkdown( - "", ui = ui - ) - }, - ui = ui, - reviseResponse = { userMessages: List> -> - actor.respond( - messages = (request.toList() + userMessages.map { - ApiModel.ChatMessage( - it.second, - it.first.toContentList() - ) - }) - .toTypedArray(), - input = listOf(element = (request.toList() + userMessages.map { - ApiModel.ChatMessage( - it.second, - it.first.toContentList() - ) - }) - .joinToString("\n") { it.content?.joinToString() ?: "" }), - api = api, - ) - }, - ).call() - task.complete( - renderMarkdown( - "", ui = ui - ) - ) - } catch (e: Throwable) { - val error = task.error(ui, e) - task.complete(ui.hrefLink("♻", "href-link regen-button") { - error?.clear() - draftImage(task, request, actor, path) + }, + outputFn = { img -> + renderMarkdown( + "", ui = ui + ) + }, + ui = ui, + reviseResponse = { userMessages: List> -> + actor.respond( + messages = (request.toList() + userMessages.map { + ApiModel.ChatMessage( + it.second, + it.first.toContentList() + ) }) - } - } - - private fun write( - code: ImageResponse, - path: Path - ): ByteArray { - val byteArrayOutputStream = ByteArrayOutputStream() - ImageIO.write( - code.image, - path.toString().split(".").last(), - byteArrayOutputStream + .toTypedArray(), + input = listOf(element = (request.toList() + userMessages.map { + ApiModel.ChatMessage( + it.second, + it.first.toContentList() + ) + }) + .joinToString("\n") { it.content?.joinToString() ?: "" }), + api = api, + ) + }, + ).call() + task.complete( + renderMarkdown( + "", ui = ui ) - val bytes = byteArrayOutputStream.toByteArray() - return bytes + ) + } catch (e: Throwable) { + val error = task.error(ui, e) + task.complete(ui.hrefLink("♻", "href-link regen-button") { + error?.clear() + draftImage(task, request, actor, path) + }) } + } - private fun draftResourceCode( - task: SessionTask, - request: Array, - actor: SimpleActor, - path: Path, - vararg languages: String = arrayOf(path.toString().split(".").last().lowercase()), - ) { - try { - var code = Discussable( - task = task, - userMessage = { "Drafting $path" }, - heading = "", - initialResponse = { - actor.respond( - listOf(request.joinToString("\n") { it.content?.joinToString() ?: "" }), - api, - *(request + ApiModel.ChatMessage(Role.user, "Draft $path".toContentList())) - .toList().toTypedArray() - ) - }, - outputFn = { design: String -> - var design = design - languages.forEach { language -> - if (design.contains("```$language")) { - design = design.substringAfter("```$language").substringBefore("```") - } - } - renderMarkdown("```${languages.first()}\n${design.let { it }}\n```", ui = ui) - }, - ui = ui, - reviseResponse = { userMessages: List> -> - actor.respond( - messages = (request.toList() + userMessages.map { - ApiModel.ChatMessage( - it.second, - it.first.toContentList() - ) - }) - .toTypedArray(), - input = listOf(element = (request.toList() + userMessages.map { - ApiModel.ChatMessage( - it.second, - it.first.toContentList() - ) - }) - .joinToString("\n") { it.content?.joinToString() ?: "" }), - api = api, - ) - }, - ).call() - code = extractCode(code) - task.complete( - "$path Updated" - ) - } catch (e: Throwable) { - val error = task.error(ui, e) - task.complete(ui.hrefLink("♻", "href-link regen-button") { - error?.clear() - draftResourceCode(task, request, actor, path, *languages) + private fun write( + code: ImageResponse, + path: Path + ): ByteArray { + val byteArrayOutputStream = ByteArrayOutputStream() + ImageIO.write( + code.image, + path.toString().split(".").last(), + byteArrayOutputStream + ) + val bytes = byteArrayOutputStream.toByteArray() + return bytes + } + + private fun draftResourceCode( + task: SessionTask, + request: Array, + actor: SimpleActor, + path: Path, + vararg languages: String = arrayOf(path.toString().split(".").last().lowercase()), + ) { + try { + var code = Discussable( + task = task, + userMessage = { "Drafting $path" }, + heading = "", + initialResponse = { + actor.respond( + listOf(request.joinToString("\n") { it.content?.joinToString() ?: "" }), + api, + *(request + ApiModel.ChatMessage(Role.user, "Draft $path".toContentList())) + .toList().toTypedArray() + ) + }, + outputFn = { design: String -> + var design = design + languages.forEach { language -> + if (design.contains("```$language")) { + design = design.substringAfter("```$language").substringBefore("```") + } + } + renderMarkdown("```${languages.first()}\n${design.let { it }}\n```", ui = ui) + }, + ui = ui, + reviseResponse = { userMessages: List> -> + actor.respond( + messages = (request.toList() + userMessages.map { + ApiModel.ChatMessage( + it.second, + it.first.toContentList() + ) }) - } + .toTypedArray(), + input = listOf(element = (request.toList() + userMessages.map { + ApiModel.ChatMessage( + it.second, + it.first.toContentList() + ) + }) + .joinToString("\n") { it.content?.joinToString() ?: "" }), + api = api, + ) + }, + ).call() + code = extractCode(code) + task.complete( + "$path Updated" + ) + } catch (e: Throwable) { + val error = task.error(ui, e) + task.complete(ui.hrefLink("♻", "href-link regen-button") { + error?.clear() + draftResourceCode(task, request, actor, path, *languages) + }) } + } - private fun extractCode(code: String): String { - var code = code - code = code.trim() - "(?s)```[^\\n]*\n(.*)\n```".toRegex().find(code)?.let { - code = it.groupValues[1] - } - return code + private fun extractCode(code: String): String { + var code = code + code = code.trim() + "(?s)```[^\\n]*\n(.*)\n```".toRegex().find(code)?.let { + code = it.groupValues[1] } + return code + } - companion object { - val log = org.slf4j.LoggerFactory.getLogger(WebDevAgent::class.java) - - data class ProjectSpec( - @Description("Files in the project design, including all local html, css, and js files.") - val files: List = emptyList() - ) : ValidatedObject { - override fun validate(): String? = when { - files.isEmpty() -> "Resources are required" - files.any { it.validate() != null } -> "Invalid resource" - else -> null - } - } + companion object { + val log = org.slf4j.LoggerFactory.getLogger(WebDevAgent::class.java) - data class ProjectFile( - @Description("The path to the file, relative to the project root.") - val name: String? = "", - @Description("A brief description of the file's purpose and contents.") - val description: String? = "" - ) : ValidatedObject { - override fun validate(): String? = when { - name.isNullOrBlank() -> "Path is required" - name.contains(" ") -> "Path cannot contain spaces" - !name.contains(".") -> "Path must contain a file extension" - else -> null - } - } + data class ProjectSpec( + @Description("Files in the project design, including all local html, css, and js files.") + val files: List = emptyList() + ) : ValidatedObject { + override fun validate(): String? = when { + files.isEmpty() -> "Resources are required" + files.any { it.validate() != null } -> "Invalid resource" + else -> null + } + } + data class ProjectFile( + @Description("The path to the file, relative to the project root.") + val name: String? = "", + @Description("A brief description of the file's purpose and contents.") + val description: String? = "" + ) : ValidatedObject { + override fun validate(): String? = when { + name.isNullOrBlank() -> "Path is required" + name.contains(" ") -> "Path cannot contain spaces" + !name.contains(".") -> "Path must contain a file extension" + else -> null + } } + + } } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ActorDesigner.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ActorDesigner.kt index f080762b..c81d0397 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ActorDesigner.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ActorDesigner.kt @@ -5,24 +5,24 @@ import com.simiacryptus.jopenai.models.OpenAIModels import com.simiacryptus.skyenet.core.actors.ParsedActor class ActorDesigner( - model: ChatModel, - temperature: Double + model: ChatModel, + temperature: Double ) : ParsedActor( - resultClass = AgentActorDesign::class.java, - exampleInstance = AgentActorDesign( - actors = listOf( - ActorDesign( - name = "Actor 1", - description = "Actor 1 description", - type = "Simple", - resultClass = "String", - ) - ) - ), - model = model, - temperature = temperature, - parsingModel = OpenAIModels.GPT4oMini, - prompt = """ + resultClass = AgentActorDesign::class.java, + exampleInstance = AgentActorDesign( + actors = listOf( + ActorDesign( + name = "Actor 1", + description = "Actor 1 description", + type = "Simple", + resultClass = "String", + ) + ) + ), + model = model, + temperature = temperature, + parsingModel = OpenAIModels.GPT4oMini, + prompt = """ You are an AI actor designer. Your task is to expand on a high-level design with requirements for each actor. diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/CodingActorDesigner.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/CodingActorDesigner.kt index 3311f4fc..327ef80e 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/CodingActorDesigner.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/CodingActorDesigner.kt @@ -7,14 +7,14 @@ import com.simiacryptus.skyenet.interpreter.Interpreter import kotlin.reflect.KClass class CodingActorDesigner( - interpreterClass: KClass, - symbols: Map, - model: ChatModel, - temperature: Double + interpreterClass: KClass, + symbols: Map, + model: ChatModel, + temperature: Double ) : CodingActor( - interpreterClass = interpreterClass, - symbols = symbols, - details = """ + interpreterClass = interpreterClass, + symbols = symbols, + details = """ | |Your task is to design a system that uses gpt "actors" to form a "community" of actors interacting to solve problems. |Your task is to implement a "script" or "coding" actor that takes part in a larger system. @@ -68,11 +68,11 @@ class CodingActorDesigner( |DO NOT subclass the CodingActor class. Use the constructor directly within the function. | """.trimMargin().trim(), - model = model, - temperature = temperature, + model = model, + temperature = temperature, ) { - init { - evalFormat = false - codeInterceptor = { fixups(it) } - } + init { + evalFormat = false + codeInterceptor = { fixups(it) } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/DetailDesigner.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/DetailDesigner.kt index e1042e32..443c76b6 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/DetailDesigner.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/DetailDesigner.kt @@ -5,43 +5,43 @@ import com.simiacryptus.jopenai.models.OpenAIModels import com.simiacryptus.skyenet.core.actors.ParsedActor class DetailDesigner( - model: ChatModel, - temperature: Double + model: ChatModel, + temperature: Double ) : ParsedActor( - resultClass = AgentFlowDesign::class.java, - exampleInstance = AgentFlowDesign( - name = "TextAnalyzer", - description = "Analyze input text for sentiment and key topics", - mainInput = DataInfo( + resultClass = AgentFlowDesign::class.java, + exampleInstance = AgentFlowDesign( + name = "TextAnalyzer", + description = "Analyze input text for sentiment and key topics", + mainInput = DataInfo( + type = "String", + description = "raw text" + ), + logicFlow = LogicFlow( + items = listOf( + LogicFlowItem( + name = "Preprocess text", + description = "Preprocess text (remove noise, normalize)", + actors = listOf( + "TextPreprocessor" + ), + inputs = listOf( + DataInfo( + type = "String", + description = "raw text" + ) + ), + output = DataInfo( type = "String", - description = "raw text" + description = "preprocessed text" + ) ), - logicFlow = LogicFlow( - items = listOf( - LogicFlowItem( - name = "Preprocess text", - description = "Preprocess text (remove noise, normalize)", - actors = listOf( - "TextPreprocessor" - ), - inputs = listOf( - DataInfo( - type = "String", - description = "raw text" - ) - ), - output = DataInfo( - type = "String", - description = "preprocessed text" - ) - ), - ) - ) - ), - model = model, - temperature = temperature, - parsingModel = OpenAIModels.GPT4o, - prompt = """ + ) + ) + ), + model = model, + temperature = temperature, + parsingModel = OpenAIModels.GPT4o, + prompt = """ You are an expert detailed software designer specializing in AI agent systems. Your task is to expand on the high-level architecture and design a detailed "agent" system that uses GPT "actors" to model a creative process. diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/FlowStepDesigner.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/FlowStepDesigner.kt index 8bc6fad8..319b7944 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/FlowStepDesigner.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/FlowStepDesigner.kt @@ -7,14 +7,14 @@ import org.slf4j.LoggerFactory import kotlin.reflect.KClass class FlowStepDesigner( - interpreterClass: KClass, - symbols: Map, - model: ChatModel, - temperature: Double + interpreterClass: KClass, + symbols: Map, + model: ChatModel, + temperature: Double ) : CodingActor( - interpreterClass = interpreterClass, - symbols = symbols, - details = """ + interpreterClass = interpreterClass, + symbols = symbols, + details = """ |You are a software implementor. | |Your task is to implement logic for an "agent" system that uses gpt "actors" to construct a model of a creative process. @@ -82,21 +82,21 @@ class FlowStepDesigner( |**IMPORTANT**: Do not redefine any symbol defined in the preceding code messages. | |""".trimMargin().trim(), - model = model, - temperature = temperature, - runtimeSymbols = mapOf( - "log" to log - ), + model = model, + temperature = temperature, + runtimeSymbols = mapOf( + "log" to log + ), ) { - init { - evalFormat = false - codeInterceptor = { fixups(it) } - } + init { + evalFormat = false + codeInterceptor = { fixups(it) } + } - companion object { - private val log = LoggerFactory.getLogger(FlowStepDesigner::class.java) - fun fixups(it: String) = it - .replace("ChatModels.GPT_3_5_TURBO", "OpenAIModels.GPT35Turbo") - .replace("OpenAIModels.DallE3", "ImageModels.DallE3") - } + companion object { + private val log = LoggerFactory.getLogger(FlowStepDesigner::class.java) + fun fixups(it: String) = it + .replace("ChatModels.GPT_3_5_TURBO", "OpenAIModels.GPT35Turbo") + .replace("OpenAIModels.DallE3", "ImageModels.DallE3") + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/HighLevelDesigner.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/HighLevelDesigner.kt index c5ff4a3a..60d5ab23 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/HighLevelDesigner.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/HighLevelDesigner.kt @@ -4,12 +4,12 @@ import com.simiacryptus.jopenai.models.ChatModel import com.simiacryptus.skyenet.core.actors.SimpleActor class HighLevelDesigner( - model: ChatModel, - temperature: Double + model: ChatModel, + temperature: Double ) : SimpleActor( - model = model, - temperature = temperature, - prompt = """ + model = model, + temperature = temperature, + prompt = """ You are an expert high-level software architect specializing in AI-based automated assistants. Your task is to gather requirements and create a detailed design based on the user's idea. Follow these steps: diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ImageActorDesigner.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ImageActorDesigner.kt index 040c621c..7815493d 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ImageActorDesigner.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ImageActorDesigner.kt @@ -7,14 +7,14 @@ import com.simiacryptus.skyenet.interpreter.Interpreter import kotlin.reflect.KClass class ImageActorDesigner( - interpreterClass: KClass, - symbols: Map, - model: ChatModel, - temperature: Double + interpreterClass: KClass, + symbols: Map, + model: ChatModel, + temperature: Double ) : CodingActor( - interpreterClass = interpreterClass, - symbols = symbols, - details = """ + interpreterClass = interpreterClass, + symbols = symbols, + details = """ | |You are a software implementation assistant. |Your task is to implement a "image" actor that takes part in a larger system. @@ -52,11 +52,11 @@ class ImageActorDesigner( |DO NOT subclass the ImageActor class. Use the constructor directly within the function. | """.trimMargin().trim(), - model = model, - temperature = temperature, + model = model, + temperature = temperature, ) { - init { - evalFormat = false - codeInterceptor = { fixups(it) } - } + init { + evalFormat = false + codeInterceptor = { fixups(it) } + } } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/MetaAgentApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/MetaAgentApp.kt index aa079761..15ed3d48 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/MetaAgentApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/MetaAgentApp.kt @@ -41,17 +41,17 @@ import java.util.* import kotlin.reflect.KClass open class MetaAgentApp( - applicationName: String = "Meta-Agent-Agent v1.1", + applicationName: String = "Meta-Agent-Agent v1.1", ) : ApplicationServer( - applicationName = applicationName, - path = "/meta_agent", + applicationName = applicationName, + path = "/meta_agent", ) { - override val description: String - @Language("Markdown") - get() = "
${ + override val description: String + @Language("Markdown") + get() = "
${ - renderMarkdown( - """ + renderMarkdown( + """ **It's agents all the way down!** Welcome to the MetaAgentAgent, an innovative tool designed to streamline the process of creating custom AI agents. This powerful system leverages the capabilities of OpenAI's language models to assist you in designing and implementing your very own AI agent tailored to your specific needs and preferences. @@ -64,105 +64,106 @@ open class MetaAgentApp( Get started with MetaAgentAgent today and bring your custom AI agent to life with ease! Whether you're looking to automate customer service, streamline data analysis, or create an interactive chatbot, MetaAgentAgent is here to help you make it happen. """.trimIndent() - ) - }
" + ) + }
" - data class Settings( - val model: ChatModel = OpenAIModels.GPT4o, - val validateCode: Boolean = true, - val temperature: Double = 0.2, - val budget: Double = 2.0, - ) + data class Settings( + val model: ChatModel = OpenAIModels.GPT4o, + val validateCode: Boolean = true, + val temperature: Double = 0.2, + val budget: Double = 2.0, + ) - override val settingsClass: Class<*> get() = Settings::class.java - @Suppress("UNCHECKED_CAST") - override fun initSettings(session: Session): T? = Settings() as T - - override fun userMessage( - session: Session, - user: User?, - userMessage: String, - ui: ApplicationInterface, - api: API - ) { - val task = ui.newTask() - task.add("User Message Processing") - try { - val settings = getSettings(session, user) - val agent = MetaAgentAgent( - user = user, - session = session, - dataStorage = dataStorage, - api = api, - ui = ui, - model = settings?.model ?: OpenAIModels.GPT4oMini, - autoEvaluate = settings?.validateCode ?: true, - temperature = settings?.temperature ?: 0.3, - ) - try { - agent.buildAgent(userMessage = userMessage) - } catch (e: SocketTimeoutException) { - log.error("Network timeout during agent building", e) - task.add("The operation timed out. Please check your network connection and try again.") - return - } catch (e: IOException) { - log.error("I/O error during agent building", e) - task.add("An I/O error occurred. Please try again later.") - return - } catch (e: Exception) { - log.error("Unexpected error during agent building", e) - task.add("An unexpected error occurred. Please try again later.") - return - } - task.complete() - } catch (e: Throwable) { - log.error("Error in userMessage", e) - task.error(ui, e) - when (e) { - is IllegalArgumentException -> task.add("Invalid input: ${e.message}") - is IllegalStateException -> task.add("Operation failed: ${e.message}") - else -> task.add("An unexpected error occurred: ${e.message}. Please try again later.") - } - } - } + override val settingsClass: Class<*> get() = Settings::class.java + + @Suppress("UNCHECKED_CAST") + override fun initSettings(session: Session): T? = Settings() as T - companion object { - private val log = LoggerFactory.getLogger(MetaAgentApp::class.java) + override fun userMessage( + session: Session, + user: User?, + userMessage: String, + ui: ApplicationInterface, + api: API + ) { + val task = ui.newTask() + task.add("User Message Processing") + try { + val settings = getSettings(session, user) + val agent = MetaAgentAgent( + user = user, + session = session, + dataStorage = dataStorage, + api = api, + ui = ui, + model = settings?.model ?: OpenAIModels.GPT4oMini, + autoEvaluate = settings?.validateCode ?: true, + temperature = settings?.temperature ?: 0.3, + ) + try { + agent.buildAgent(userMessage = userMessage) + } catch (e: SocketTimeoutException) { + log.error("Network timeout during agent building", e) + task.add("The operation timed out. Please check your network connection and try again.") + return + } catch (e: IOException) { + log.error("I/O error during agent building", e) + task.add("An I/O error occurred. Please try again later.") + return + } catch (e: Exception) { + log.error("Unexpected error during agent building", e) + task.add("An unexpected error occurred. Please try again later.") + return + } + task.complete() + } catch (e: Throwable) { + log.error("Error in userMessage", e) + task.error(ui, e) + when (e) { + is IllegalArgumentException -> task.add("Invalid input: ${e.message}") + is IllegalStateException -> task.add("Operation failed: ${e.message}") + else -> task.add("An unexpected error occurred: ${e.message}. Please try again later.") + } } + } + + companion object { + private val log = LoggerFactory.getLogger(MetaAgentApp::class.java) + } } open class MetaAgentAgent( - user: User?, - session: Session, - dataStorage: StorageInterface, - val ui: ApplicationInterface, - val api: API, - model: ChatModel = OpenAIModels.GPT4oMini, - var autoEvaluate: Boolean = true, - temperature: Double = 0.3, + user: User?, + session: Session, + dataStorage: StorageInterface, + val ui: ApplicationInterface, + val api: API, + model: ChatModel = OpenAIModels.GPT4oMini, + var autoEvaluate: Boolean = true, + temperature: Double = 0.3, ) : PoolSystem( - dataStorage, user, session + dataStorage, user, session ) { - private val highLevelDesigner by lazy { HighLevelDesigner(model, temperature) } - private val detailDesigner by lazy { DetailDesigner(model, temperature) } - private val interpreterClass: KClass = KotlinInterpreter::class - - val symbols = mapOf( - "ui" to ui, - "api" to api, - "pool" to ApplicationServices.clientManager.getPool(session, user), - ) - private val actorDesigner by lazy { ActorDesigner(model, temperature) } - private val simpleActorDesigner by lazy { SimpleActorDesigner(interpreterClass, symbols, model, temperature) } - private val imageActorDesigner by lazy { ImageActorDesigner(interpreterClass, symbols, model, temperature) } - private val parsedActorDesigner by lazy { ParsedActorDesigner(interpreterClass, symbols, model, temperature) } - private val codingActorDesigner by lazy { CodingActorDesigner(interpreterClass, symbols, model, temperature) } - private val flowStepDesigner by lazy { FlowStepDesigner(interpreterClass, symbols, model, temperature) } - - @Language("kotlin") - val standardImports = """ + private val highLevelDesigner by lazy { HighLevelDesigner(model, temperature) } + private val detailDesigner by lazy { DetailDesigner(model, temperature) } + private val interpreterClass: KClass = KotlinInterpreter::class + + val symbols = mapOf( + "ui" to ui, + "api" to api, + "pool" to ApplicationServices.clientManager.getPool(session, user), + ) + private val actorDesigner by lazy { ActorDesigner(model, temperature) } + private val simpleActorDesigner by lazy { SimpleActorDesigner(interpreterClass, symbols, model, temperature) } + private val imageActorDesigner by lazy { ImageActorDesigner(interpreterClass, symbols, model, temperature) } + private val parsedActorDesigner by lazy { ParsedActorDesigner(interpreterClass, symbols, model, temperature) } + private val codingActorDesigner by lazy { CodingActorDesigner(interpreterClass, symbols, model, temperature) } + private val flowStepDesigner by lazy { FlowStepDesigner(interpreterClass, symbols, model, temperature) } + + @Language("kotlin") + val standardImports = """ |import com.simiacryptus.jopenai.API |import com.simiacryptus.jopenai.models.ChatModels |import com.simiacryptus.skyenet.core.actors.BaseActor @@ -184,32 +185,32 @@ open class MetaAgentAgent( |import javax.imageio.ImageIO """.trimMargin() - fun buildAgent(userMessage: String) { - val design = initialDesign(userMessage) - val actImpls = implementActors(userMessage, design) - val flowImpl = getFlowStepCode(userMessage, design, actImpls) - val mainImpl = getMainFunction(userMessage, design, actImpls, flowImpl) - buildFinalCode(actImpls, flowImpl, mainImpl, design) - } + fun buildAgent(userMessage: String) { + val design = initialDesign(userMessage) + val actImpls = implementActors(userMessage, design) + val flowImpl = getFlowStepCode(userMessage, design, actImpls) + val mainImpl = getMainFunction(userMessage, design, actImpls, flowImpl) + buildFinalCode(actImpls, flowImpl, mainImpl, design) + } - private fun buildFinalCode( - actImpls: Map, - flowImpl: Map, - mainImpl: String, design: ParsedResponse - ) { - val task = ui.newTask() - task.add("Building Final Code") - try { - task.header("Final Code") + private fun buildFinalCode( + actImpls: Map, + flowImpl: Map, + mainImpl: String, design: ParsedResponse + ) { + val task = ui.newTask() + task.add("Building Final Code") + try { + task.header("Final Code") - val imports = (actImpls.values + flowImpl.values + listOf(mainImpl)).flatMap { it.imports() }.toSortedSet() - .joinToString("\n") + val imports = (actImpls.values + flowImpl.values + listOf(mainImpl)).flatMap { it.imports() }.toSortedSet() + .joinToString("\n") - val classBaseName = (design.obj.name?.pascalCase() ?: "MyAgent").replace("[^A-Za-z0-9]".toRegex(), "") + val classBaseName = (design.obj.name?.pascalCase() ?: "MyAgent").replace("[^A-Za-z0-9]".toRegex(), "") - val actorInits = design.obj.actors?.joinToString("\n") { actImpls[it.name] ?: "" } ?: "" + val actorInits = design.obj.actors?.joinToString("\n") { actImpls[it.name] ?: "" } ?: "" - @Language("kotlin") val appCode = """ + @Language("kotlin") val appCode = """ |$standardImports | |$imports @@ -259,7 +260,7 @@ open class MetaAgentAgent( |} """.trimMargin() - @Language("kotlin") val agentCode = """ + @Language("kotlin") val agentCode = """ |$standardImports | |open class ${classBaseName}Agent( @@ -285,464 +286,464 @@ open class MetaAgentAgent( |} """.trimMargin() - //language=MARKDOWN - val code = """ + //language=MARKDOWN + val code = """ |```kotlin |${ - """ + """ |$appCode | |$agentCode """.trimMargin().sortCode() - } + } |``` """.trimMargin() - //language=HTML - task.complete(renderMarkdown(code, ui = ui)) - } catch (e: IOException) { - task.complete("An I/O error occurred. Please try again later.") - } catch (e: Throwable) { - task.error(ui, e) - throw e - } + //language=HTML + task.complete(renderMarkdown(code, ui = ui)) + } catch (e: IOException) { + task.complete("An I/O error occurred. Please try again later.") + } catch (e: Throwable) { + task.error(ui, e) + throw e } - - private fun initialDesign(input: String): ParsedResponse { - val toInput = { it: String -> listOf(it) } - val highLevelDesign = Discussable( - task = ui.newTask(), - userMessage = { input }, - heading = renderMarkdown(input, ui = ui), - initialResponse = { it: String -> highLevelDesigner.answer(toInput(it), api = api) }, - outputFn = { design -> renderMarkdown(design.toString(), ui = ui) }, - ui = ui, - reviseResponse = { userMessages: List> -> - highLevelDesigner.respond( - messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } - .toTypedArray()), - input = toInput(input), - api = api - ) - }, - ).call() - val toInput1 = { it: String -> listOf(it) } - val flowDesign = Discussable( - task = ui.newTask(), - userMessage = { highLevelDesign }, - heading = "Flow Design", - initialResponse = { it: String -> detailDesigner.answer(toInput1(it), api = api) }, - outputFn = { design: ParsedResponse -> - try { - renderMarkdown( - """ + } + + private fun initialDesign(input: String): ParsedResponse { + val toInput = { it: String -> listOf(it) } + val highLevelDesign = Discussable( + task = ui.newTask(), + userMessage = { input }, + heading = renderMarkdown(input, ui = ui), + initialResponse = { it: String -> highLevelDesigner.answer(toInput(it), api = api) }, + outputFn = { design -> renderMarkdown(design.toString(), ui = ui) }, + ui = ui, + reviseResponse = { userMessages: List> -> + highLevelDesigner.respond( + messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } + .toTypedArray()), + input = toInput(input), + api = api + ) + }, + ).call() + val toInput1 = { it: String -> listOf(it) } + val flowDesign = Discussable( + task = ui.newTask(), + userMessage = { highLevelDesign }, + heading = "Flow Design", + initialResponse = { it: String -> detailDesigner.answer(toInput1(it), api = api) }, + outputFn = { design: ParsedResponse -> + try { + renderMarkdown( + """ |$design |```json |${JsonUtil.toJson(design.obj)} |``` """.trimMargin(), ui = ui - ) - } catch (e: Throwable) { - renderMarkdown(e.message ?: e.toString(), ui = ui) - } - }, - ui = ui, - reviseResponse = { userMessages: List> -> - detailDesigner.respond( - messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } - .toTypedArray()), - input = toInput1(highLevelDesign), - api = api - ) - }, - ).call() - val actorDesignParsedResponse: ParsedResponse = Discussable( - task = ui.newTask(), - userMessage = { flowDesign.text }, - heading = "Actor Design", - initialResponse = { it: String -> actorDesigner.answer(listOf(it), api = api) }, - outputFn = { design: ParsedResponse -> - try { - renderMarkdown( - """ + ) + } catch (e: Throwable) { + renderMarkdown(e.message ?: e.toString(), ui = ui) + } + }, + ui = ui, + reviseResponse = { userMessages: List> -> + detailDesigner.respond( + messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } + .toTypedArray()), + input = toInput1(highLevelDesign), + api = api + ) + }, + ).call() + val actorDesignParsedResponse: ParsedResponse = Discussable( + task = ui.newTask(), + userMessage = { flowDesign.text }, + heading = "Actor Design", + initialResponse = { it: String -> actorDesigner.answer(listOf(it), api = api) }, + outputFn = { design: ParsedResponse -> + try { + renderMarkdown( + """ |$design |```json |${JsonUtil.toJson(design.obj)} |``` """.trimMargin(), ui = ui - ) - } catch (e: Throwable) { - renderMarkdown(e.message ?: e.toString(), ui = ui) - } - }, - ui = ui, - reviseResponse = { userMessages: List> -> - actorDesigner.respond( - messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } - .toTypedArray()), input = listOf(flowDesign.text), - api = api - ) - }, - ).call() - return object : ParsedResponse(AgentDesign::class.java) { - override val text get() = flowDesign.text + "\n" + actorDesignParsedResponse.text - override val obj - get() = AgentDesign( - name = flowDesign.obj.name, - description = flowDesign.obj.description, - mainInput = flowDesign.obj.mainInput, - logicFlow = flowDesign.obj.logicFlow, - actors = actorDesignParsedResponse.obj.actors, - ) + ) + } catch (e: Throwable) { + renderMarkdown(e.message ?: e.toString(), ui = ui) } + }, + ui = ui, + reviseResponse = { userMessages: List> -> + actorDesigner.respond( + messages = (userMessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } + .toTypedArray()), input = listOf(flowDesign.text), + api = api + ) + }, + ).call() + return object : ParsedResponse(AgentDesign::class.java) { + override val text get() = flowDesign.text + "\n" + actorDesignParsedResponse.text + override val obj + get() = AgentDesign( + name = flowDesign.obj.name, + description = flowDesign.obj.description, + mainInput = flowDesign.obj.mainInput, + logicFlow = flowDesign.obj.logicFlow, + actors = actorDesignParsedResponse.obj.actors, + ) } - - private fun getMainFunction( - userMessage: String, design: ParsedResponse, - actorImpls: Map, - flowStepCode: Map - ): String { - val task = ui.newTask() - try { - task.header("Main Function") - val codeRequest = CodingActor.CodeRequest( - messages = listOf( - userMessage to Role.user, - design.text to Role.assistant, - "Implement `fun ${design.obj.name?.camelCase()}(${ - listOf(design.obj.mainInput!!) - .joinToString(", ") { (it.name ?: "") + " : " + (it.type ?: "") } - })`" to Role.user - ), codePrefix = ((standardImports.lines() + actorImpls.values + flowStepCode.values) - .joinToString("\n\n") { it.trimIndent() }).sortCode(), - autoEvaluate = autoEvaluate - ) - val mainFunction = execWrap { flowStepDesigner.answer(codeRequest, api = api).code } - task.verbose( - renderMarkdown( - """ + } + + private fun getMainFunction( + userMessage: String, design: ParsedResponse, + actorImpls: Map, + flowStepCode: Map + ): String { + val task = ui.newTask() + try { + task.header("Main Function") + val codeRequest = CodingActor.CodeRequest( + messages = listOf( + userMessage to Role.user, + design.text to Role.assistant, + "Implement `fun ${design.obj.name?.camelCase()}(${ + listOf(design.obj.mainInput!!) + .joinToString(", ") { (it.name ?: "") + " : " + (it.type ?: "") } + })`" to Role.user + ), codePrefix = ((standardImports.lines() + actorImpls.values + flowStepCode.values) + .joinToString("\n\n") { it.trimIndent() }).sortCode(), + autoEvaluate = autoEvaluate + ) + val mainFunction = execWrap { flowStepDesigner.answer(codeRequest, api = api).code } + task.verbose( + renderMarkdown( + """ |```kotlin |$mainFunction |``` """.trimMargin(), ui = ui - ), tag = "div" - ) - task.complete() - return mainFunction - } catch (e: CodingActor.FailedToImplementException) { - task.verbose(e.code ?: throw e) - task.error(ui, e) - return e.code ?: throw e - } catch (e: Throwable) { - task.error(ui, e) - throw e - } + ), tag = "div" + ) + task.complete() + return mainFunction + } catch (e: CodingActor.FailedToImplementException) { + task.verbose(e.code ?: throw e) + task.error(ui, e) + return e.code ?: throw e + } catch (e: Throwable) { + task.error(ui, e) + throw e } - - private fun implementActors( - userMessage: String, - design: ParsedResponse, - ) = design.obj.actors?.map { actorDesign -> - pool.submit> { - val task = ui.newTask() - try { - implementActor(task, actorDesign, userMessage, design) - } catch (e: Throwable) { - task.error(ui, e) - throw e - } - } - }?.toTypedArray()?.associate { it.get() } ?: mapOf() - - private fun implementActor( - task: SessionTask, actorDesign: ActorDesign, - userMessage: String, design: ParsedResponse - ): Pair { - val api = (api as ChatClient).getChildClient().apply { - val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - task.verbose("API log: $this") - } + } + + private fun implementActors( + userMessage: String, + design: ParsedResponse, + ) = design.obj.actors?.map { actorDesign -> + pool.submit> { + val task = ui.newTask() + try { + implementActor(task, actorDesign, userMessage, design) + } catch (e: Throwable) { + task.error(ui, e) + throw e + } + } + }?.toTypedArray()?.associate { it.get() } ?: mapOf() + + private fun implementActor( + task: SessionTask, actorDesign: ActorDesign, + userMessage: String, design: ParsedResponse + ): Pair { + val api = (api as ChatClient).getChildClient().apply { + val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + task.verbose("API log: $this") + } + } + //language=HTML + task.header("Actor: ${actorDesign.name}") + val type = actorDesign.type + val codeRequest = CodingActor.CodeRequest( + messages = listOf( + userMessage to Role.user, + design.text to Role.assistant, + "Implement `val ${(actorDesign.name).camelCase()} : ${ + when (type.lowercase()) { + "simple" -> "SimpleActor" + "parsed" -> "ParsedActor<${actorDesign.simpleClassName}>" + "coding" -> "CodingActor" + "image" -> "ImageActor" + "tts" -> "TextToSpeechActor" + else -> throw IllegalArgumentException("Unknown actor type: $type") + } + }`" to Role.user + ), autoEvaluate = autoEvaluate, codePrefix = standardImports.sortCode() + ) + var code = "" + val onComplete = java.util.concurrent.Semaphore(0) + Retryable(ui, task) { + try { + val response = execWrap { + when (type.lowercase()) { + "simple" -> simpleActorDesigner.answer(codeRequest, api = api) + "parsed" -> parsedActorDesigner.answer(codeRequest, api = api) + "coding" -> codingActorDesigner.answer(codeRequest, api = api) + "image" -> imageActorDesigner.answer(codeRequest, api = api) + else -> throw IllegalArgumentException("Unknown actor type: $type") + } } - //language=HTML - task.header("Actor: ${actorDesign.name}") - val type = actorDesign.type - val codeRequest = CodingActor.CodeRequest( - messages = listOf( - userMessage to Role.user, - design.text to Role.assistant, - "Implement `val ${(actorDesign.name).camelCase()} : ${ - when (type.lowercase()) { - "simple" -> "SimpleActor" - "parsed" -> "ParsedActor<${actorDesign.simpleClassName}>" - "coding" -> "CodingActor" - "image" -> "ImageActor" - "tts" -> "TextToSpeechActor" - else -> throw IllegalArgumentException("Unknown actor type: $type") - } - }`" to Role.user - ), autoEvaluate = autoEvaluate, codePrefix = standardImports.sortCode() - ) - var code = "" - val onComplete = java.util.concurrent.Semaphore(0) - Retryable(ui, task) { - try { - val response = execWrap { - when (type.lowercase()) { - "simple" -> simpleActorDesigner.answer(codeRequest, api = api) - "parsed" -> parsedActorDesigner.answer(codeRequest, api = api) - "coding" -> codingActorDesigner.answer(codeRequest, api = api) - "image" -> imageActorDesigner.answer(codeRequest, api = api) - else -> throw IllegalArgumentException("Unknown actor type: $type") - } - } - code = response.code - onComplete.release() - renderMarkdown( - """ + code = response.code + onComplete.release() + renderMarkdown( + """ |```kotlin |$code |``` """.trimMargin(), ui = ui - ) - } catch (e: CodingActor.FailedToImplementException) { - task.error(ui, e) - code = e.code ?: "" - renderMarkdown( - """ + ) + } catch (e: CodingActor.FailedToImplementException) { + task.error(ui, e) + code = e.code ?: "" + renderMarkdown( + """ |```kotlin |$code |``` |${ - ui.hrefLink("Accept", classname = "href-link cmd-button") { - autoEvaluate = false - onComplete.release() - } - } - """.trimMargin(), ui = ui - ) + ui.hrefLink("Accept", classname = "href-link cmd-button") { + autoEvaluate = false + onComplete.release() } - } - onComplete.acquire() - //language=HTML - task.complete() - return actorDesign.name to code + } + """.trimMargin(), ui = ui + ) + } } - - private fun execWrap(fn: () -> T): T { - val classLoader = Thread.currentThread().contextClassLoader - val prevCL = KotlinInterpreter.classLoader - KotlinInterpreter.classLoader = classLoader //req.javaClass.classLoader - return try { - WebAppClassLoader.runWithServerClassAccess { - require(null != classLoader.loadClass("org.eclipse.jetty.server.Response")) - require(null != classLoader.loadClass("org.eclipse.jetty.server.Request")) - fn() - } - } finally { - KotlinInterpreter.classLoader = prevCL - } + onComplete.acquire() + //language=HTML + task.complete() + return actorDesign.name to code + } + + private fun execWrap(fn: () -> T): T { + val classLoader = Thread.currentThread().contextClassLoader + val prevCL = KotlinInterpreter.classLoader + KotlinInterpreter.classLoader = classLoader //req.javaClass.classLoader + return try { + WebAppClassLoader.runWithServerClassAccess { + require(null != classLoader.loadClass("org.eclipse.jetty.server.Response")) + require(null != classLoader.loadClass("org.eclipse.jetty.server.Request")) + fn() + } + } finally { + KotlinInterpreter.classLoader = prevCL } - - private fun getFlowStepCode( - userMessage: String, - design: ParsedResponse, - actorImpls: Map, - ): Map { - val flowImpls = HashMap() - design.obj.logicFlow?.items?.forEach { logicFlowItem -> - val message = ui.newTask() - try { - val api = (api as ChatClient).getChildClient().apply { - val createFile = message.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - message.verbose("API log: $this") - } - } - message.header("Logic Flow: ${logicFlowItem.name}") - var code: String? = null - val onComplete = java.util.concurrent.Semaphore(0) - Retryable(ui, message) { - try { - code = execWrap { - flowStepDesigner.answer(CodingActor.CodeRequest(messages = listOf(userMessage to Role.user, - design.text to Role.assistant, - "Implement `fun ${(logicFlowItem.name!!).camelCase()}(${ - logicFlowItem.inputs?.joinToString(", ") { (it.name ?: "") + " : " + (it.type ?: "") } ?: "" - })`" to Role.user), - autoEvaluate = autoEvaluate, - codePrefix = (actorImpls.values + flowImpls.values).joinToString("\n\n") { it.trimIndent() } - .sortCode() - ), api = api - ).code - } - onComplete.release() - renderMarkdown( - """ + } + + private fun getFlowStepCode( + userMessage: String, + design: ParsedResponse, + actorImpls: Map, + ): Map { + val flowImpls = HashMap() + design.obj.logicFlow?.items?.forEach { logicFlowItem -> + val message = ui.newTask() + try { + val api = (api as ChatClient).getChildClient().apply { + val createFile = message.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + message.verbose("API log: $this") + } + } + message.header("Logic Flow: ${logicFlowItem.name}") + var code: String? = null + val onComplete = java.util.concurrent.Semaphore(0) + Retryable(ui, message) { + try { + code = execWrap { + flowStepDesigner.answer(CodingActor.CodeRequest(messages = listOf(userMessage to Role.user, + design.text to Role.assistant, + "Implement `fun ${(logicFlowItem.name!!).camelCase()}(${ + logicFlowItem.inputs?.joinToString(", ") { (it.name ?: "") + " : " + (it.type ?: "") } ?: "" + })`" to Role.user), + autoEvaluate = autoEvaluate, + codePrefix = (actorImpls.values + flowImpls.values).joinToString("\n\n") { it.trimIndent() } + .sortCode() + ), api = api + ).code + } + onComplete.release() + renderMarkdown( + """ |```kotlin |$code |``` """.trimMargin(), ui = ui - ) - } catch (e: CodingActor.FailedToImplementException) { - message.error(ui, e) - code = e.code ?: "" - renderMarkdown( - """ + ) + } catch (e: CodingActor.FailedToImplementException) { + message.error(ui, e) + code = e.code ?: "" + renderMarkdown( + """ |```kotlin |$code |``` |${ - ui.hrefLink("Accept", classname = "href-link cmd-button") { - autoEvaluate = false - onComplete.release() - } - } - """.trimMargin(), ui = ui - ) - } + ui.hrefLink("Accept", classname = "href-link cmd-button") { + autoEvaluate = false + onComplete.release() } - onComplete.acquire() - message.complete() - flowImpls[logicFlowItem.name!!] = code!! - } catch (e: Throwable) { - message.error(ui, e) - throw e - } + } + """.trimMargin(), ui = ui + ) + } } - return flowImpls + onComplete.acquire() + message.complete() + flowImpls[logicFlowItem.name!!] = code!! + } catch (e: Throwable) { + message.error(ui, e) + throw e + } } + return flowImpls + } - companion object + companion object } class MetaAgentActors( - val symbols: Map = mapOf(), - val model: ChatModel = OpenAIModels.GPT4o, - val temperature: Double = 0.3, + val symbols: Map = mapOf(), + val model: ChatModel = OpenAIModels.GPT4o, + val temperature: Double = 0.3, ) { - companion object { - val log = LoggerFactory.getLogger(MetaAgentActors::class.java) - fun T.notIn(vararg examples: T) = !examples.contains(this) - } + companion object { + val log = LoggerFactory.getLogger(MetaAgentActors::class.java) + fun T.notIn(vararg examples: T) = !examples.contains(this) + } } data class AgentFlowDesign( - val name: String? = null, - val description: String? = null, - val mainInput: DataInfo? = null, - val logicFlow: LogicFlow? = null, + val name: String? = null, + val description: String? = null, + val mainInput: DataInfo? = null, + val logicFlow: LogicFlow? = null, ) : ValidatedObject { - override fun validate(): String? = when { - null == logicFlow -> "logicFlow is required" - null != logicFlow.validate() -> logicFlow.validate() - else -> null - } + override fun validate(): String? = when { + null == logicFlow -> "logicFlow is required" + null != logicFlow.validate() -> logicFlow.validate() + else -> null + } } data class AgentDesign( - val name: String? = null, - val path: String? = null, - val description: String? = null, - val mainInput: DataInfo? = null, - val logicFlow: LogicFlow? = null, - val actors: List? = null, + val name: String? = null, + val path: String? = null, + val description: String? = null, + val mainInput: DataInfo? = null, + val logicFlow: LogicFlow? = null, + val actors: List? = null, ) : ValidatedObject { - override fun validate(): String? = when { - null == logicFlow -> "logicFlow is required" - null == actors -> "actors is required" - actors.isEmpty() -> "actors is required" - null != logicFlow.validate() -> logicFlow.validate() - !actors.all { null == it.validate() } -> actors.map { it.validate() }.filter { null != it }.joinToString("\n") - - else -> null - } + override fun validate(): String? = when { + null == logicFlow -> "logicFlow is required" + null == actors -> "actors is required" + actors.isEmpty() -> "actors is required" + null != logicFlow.validate() -> logicFlow.validate() + !actors.all { null == it.validate() } -> actors.map { it.validate() }.filter { null != it }.joinToString("\n") + + else -> null + } } data class AgentActorDesign( - val actors: List? = null, + val actors: List? = null, ) : ValidatedObject { - override fun validate(): String? = when { - null == actors -> "actors is required" - actors.isEmpty() -> "actors is required" - !actors.all { null == it.validate() } -> actors.map { it.validate() }.filter { null != it }.joinToString("\n") + override fun validate(): String? = when { + null == actors -> "actors is required" + actors.isEmpty() -> "actors is required" + !actors.all { null == it.validate() } -> actors.map { it.validate() }.filter { null != it }.joinToString("\n") - else -> null - } + else -> null + } } data class ActorDesign( - @Description("Java class name of the actor") val name: String = "", - val description: String? = null, - @Description("simple, parsed, image, tts, or coding") val type: String = "", - @Description("Simple actors: string; Image actors: image; Coding actors: code; Text-to-speech actors: mp3; Parsed actors: a simple java class name for the data structure") val resultClass: String = "", + @Description("Java class name of the actor") val name: String = "", + val description: String? = null, + @Description("simple, parsed, image, tts, or coding") val type: String = "", + @Description("Simple actors: string; Image actors: image; Coding actors: code; Text-to-speech actors: mp3; Parsed actors: a simple java class name for the data structure") val resultClass: String = "", ) : ValidatedObject { - val simpleClassName: String get() = resultClass.split(".").last() - override fun validate(): String? = when { - name.isEmpty() -> "name is required" - name.chars().anyMatch { !Character.isJavaIdentifierPart(it) } -> "name must be a valid java identifier" - type.isEmpty() -> "type is required" - type.lowercase().notIn( - "simple", "parsed", "coding", "image", "tts" - ) -> "type must be simple, parsed, coding, tts, or image" - - resultClass.isEmpty() -> "resultType is required" - resultClass.lowercase().notIn( - "string", "code", "image", "mp3" - ) && !validClassName(resultClass) -> "resultType must be string, code, image, mp3, or a valid class name" - - else -> null + val simpleClassName: String get() = resultClass.split(".").last() + override fun validate(): String? = when { + name.isEmpty() -> "name is required" + name.chars().anyMatch { !Character.isJavaIdentifierPart(it) } -> "name must be a valid java identifier" + type.isEmpty() -> "type is required" + type.lowercase().notIn( + "simple", "parsed", "coding", "image", "tts" + ) -> "type must be simple, parsed, coding, tts, or image" + + resultClass.isEmpty() -> "resultType is required" + resultClass.lowercase().notIn( + "string", "code", "image", "mp3" + ) && !validClassName(resultClass) -> "resultType must be string, code, image, mp3, or a valid class name" + + else -> null + } + + private fun validClassName(resultType: String): Boolean { + return when { + resultType.isEmpty() -> false + validClassNamePattern.matches(resultType) -> true + else -> false } + } - private fun validClassName(resultType: String): Boolean { - return when { - resultType.isEmpty() -> false - validClassNamePattern.matches(resultType) -> true - else -> false - } - } - - companion object { - val validClassNamePattern = "[A-Za-z][a-zA-Z0-9_<>.]{3,}".toRegex() - } + companion object { + val validClassNamePattern = "[A-Za-z][a-zA-Z0-9_<>.]{3,}".toRegex() + } } data class LogicFlow( - val items: List? = null, + val items: List? = null, ) : ValidatedObject { - override fun validate(): String? = items?.map { it.validate() }?.firstOrNull { !it.isNullOrBlank() } + override fun validate(): String? = items?.map { it.validate() }?.firstOrNull { !it.isNullOrBlank() } } data class LogicFlowItem( - val name: String? = null, - val description: String? = null, - val actors: List? = null, - @Description("symbol names of variables/values used as input to this step") val inputs: List? = null, - @Description("description of the output of this step") val output: DataInfo? = null, + val name: String? = null, + val description: String? = null, + val actors: List? = null, + @Description("symbol names of variables/values used as input to this step") val inputs: List? = null, + @Description("description of the output of this step") val output: DataInfo? = null, ) : ValidatedObject { - override fun validate(): String? = when { - null == name -> "name is required" - name.isEmpty() -> "name is required" - //inputs?.isEmpty() != false && inputs?.isEmpty() != false -> "inputs is required" - else -> null - } + override fun validate(): String? = when { + null == name -> "name is required" + name.isEmpty() -> "name is required" + //inputs?.isEmpty() != false && inputs?.isEmpty() != false -> "inputs is required" + else -> null + } } data class DataInfo( - val name: String? = null, - val description: String? = null, - val type: String? = null, + val name: String? = null, + val description: String? = null, + val type: String? = null, ) : ValidatedObject { - override fun validate(): String? = when { - null == name -> "name is required" - name.isEmpty() -> "name is required" - null == type -> "type is required" - type.isEmpty() -> "type is required" - else -> null - } + override fun validate(): String? = when { + null == name -> "name is required" + name.isEmpty() -> "name is required" + null == type -> "type is required" + type.isEmpty() -> "type is required" + else -> null + } } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ParsedActorDesigner.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ParsedActorDesigner.kt index 412376e9..edaef383 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ParsedActorDesigner.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/ParsedActorDesigner.kt @@ -9,14 +9,14 @@ import com.simiacryptus.skyenet.kotlin.KotlinInterpreter import kotlin.reflect.KClass class ParsedActorDesigner( - interpreterClass: KClass = KotlinInterpreter::class, - symbols: Map = mapOf(), - model: ChatModel = OpenAIModels.GPT4o, - temperature: Double = 0.3, + interpreterClass: KClass = KotlinInterpreter::class, + symbols: Map = mapOf(), + model: ChatModel = OpenAIModels.GPT4o, + temperature: Double = 0.3, ) : CodingActor( - interpreterClass = interpreterClass, - symbols = symbols, - details = """ + interpreterClass = interpreterClass, + symbols = symbols, + details = """ | |Your task is to design a system that uses gpt "actors" to form a "community" of actors interacting to solve problems. |Your task is to implement a "parsed" actor that takes part in a larger system. @@ -69,11 +69,11 @@ class ParsedActorDesigner( |DO NOT subclass the ParsedActor class. Use the constructor directly within the function. | """.trimMargin().trim(), - model = model, - temperature = temperature, + model = model, + temperature = temperature, ) { - init { - evalFormat = false - codeInterceptor = { fixups(it) } - } + init { + evalFormat = false + codeInterceptor = { fixups(it) } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/SimpleActorDesigner.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/SimpleActorDesigner.kt index d5b54aca..0cd5ced9 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/SimpleActorDesigner.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/meta/SimpleActorDesigner.kt @@ -7,16 +7,16 @@ import com.simiacryptus.skyenet.interpreter.Interpreter import kotlin.reflect.KClass class SimpleActorDesigner( - interpreterClass: KClass, - symbols: Map, - model: ChatModel, - temperature: Double + interpreterClass: KClass, + symbols: Map, + model: ChatModel, + temperature: Double ) : CodingActor( - interpreterClass = interpreterClass, - symbols = symbols, - model = model, - temperature = temperature, - details = """ + interpreterClass = interpreterClass, + symbols = symbols, + model = model, + temperature = temperature, + details = """ You are a software implementation assistant. Your task is to implement a "simple" actor that takes part in a larger system. "Simple" actors contain a system directive and can process a list of user messages into a response. @@ -53,8 +53,8 @@ class SimpleActorDesigner( DO NOT subclass the SimpleActor class. Use the constructor directly within the function. """.trimIndent() ) { - init { - evalFormat = false - codeInterceptor = { fixups(it) } - } + init { + evalFormat = false + codeInterceptor = { fixups(it) } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/CodeParsingModel.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/CodeParsingModel.kt index 34f59e1a..965ba263 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/CodeParsingModel.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/CodeParsingModel.kt @@ -6,44 +6,44 @@ import com.simiacryptus.jopenai.models.ChatModel import com.simiacryptus.skyenet.core.actors.ParsedActor open class CodeParsingModel( - private val parsingModel: ChatModel, - private val temperature: Double + private val parsingModel: ChatModel, + private val temperature: Double ) : ParsingModel { - override fun merge( - runningDocument: ParsingModel.DocumentData, - newData: ParsingModel.DocumentData - ): ParsingModel.DocumentData { - val runningDocument = runningDocument as CodeData - val newData = newData as CodeData - return CodeData( - id = newData.id ?: runningDocument.id, - content_list = mergeContent(runningDocument.content_list, newData.content_list).takeIf { it.isNotEmpty() }, - ) - } + override fun merge( + runningDocument: ParsingModel.DocumentData, + newData: ParsingModel.DocumentData + ): ParsingModel.DocumentData { + val runningDocument = runningDocument as CodeData + val newData = newData as CodeData + return CodeData( + id = newData.id ?: runningDocument.id, + content_list = mergeContent(runningDocument.content_list, newData.content_list).takeIf { it.isNotEmpty() }, + ) + } - protected open fun mergeContent( - existingContent: List?, - newContent: List? - ): List { - val mergedContent = (existingContent ?: emptyList()).toMutableList() - (newContent ?: emptyList()).forEach { newItem -> - val existingIndex = mergedContent.indexOfFirst { it.type == newItem.type && it.text?.trim() == newItem.text?.trim() } - if (existingIndex != -1) { - mergedContent[existingIndex] = mergeContentData(mergedContent[existingIndex], newItem) - } else { - mergedContent.add(newItem) - } - } - return mergedContent + protected open fun mergeContent( + existingContent: List?, + newContent: List? + ): List { + val mergedContent = (existingContent ?: emptyList()).toMutableList() + (newContent ?: emptyList()).forEach { newItem -> + val existingIndex = mergedContent.indexOfFirst { it.type == newItem.type && it.text?.trim() == newItem.text?.trim() } + if (existingIndex != -1) { + mergedContent[existingIndex] = mergeContentData(mergedContent[existingIndex], newItem) + } else { + mergedContent.add(newItem) + } } + return mergedContent + } - protected open fun mergeContentData(existing: CodeContent, new: CodeContent) = existing.copy( - content_list = mergeContent(existing.content_list, new.content_list).takeIf { it.isNotEmpty() }, - tags = ((existing.tags ?: emptyList()) + (new.tags ?: emptyList())).distinct().takeIf { it.isNotEmpty() } - ) + protected open fun mergeContentData(existing: CodeContent, new: CodeContent) = existing.copy( + content_list = mergeContent(existing.content_list, new.content_list).takeIf { it.isNotEmpty() }, + tags = ((existing.tags ?: emptyList()) + (new.tags ?: emptyList())).distinct().takeIf { it.isNotEmpty() } + ) - open val promptSuffix = """ + open val promptSuffix = """ Parse the code into a structured format that describes its components: 1. Separate the content into sections, paragraphs, statements, etc. 2. All source content should be included in the output, with paraphrasing, corrections, and context as needed @@ -51,36 +51,36 @@ Parse the code into a structured format that describes its components: 4. Assign relevant tags to each node to improve searchability and categorization. """.trimMargin() - open val exampleInstance = CodeData() + open val exampleInstance = CodeData() - override fun getParser(api: API): (String) -> CodeData { - val parser = ParsedActor( - resultClass = CodeData::class.java, - exampleInstance = exampleInstance, - prompt = "", - parsingModel = parsingModel, - temperature = temperature - ).getParser( - api, promptSuffix = promptSuffix - ) - return { text -> parser.apply(text) } - } + override fun getParser(api: API): (String) -> CodeData { + val parser = ParsedActor( + resultClass = CodeData::class.java, + exampleInstance = exampleInstance, + prompt = "", + parsingModel = parsingModel, + temperature = temperature + ).getParser( + api, promptSuffix = promptSuffix + ) + return { text -> parser.apply(text) } + } - override fun newDocument() = CodeData() + override fun newDocument() = CodeData() - data class CodeData( - @Description("Code identifier") override val id: String? = null, - @Description("Hierarchical structure and data") override val content_list: List? = null, - ) : ParsingModel.DocumentData + data class CodeData( + @Description("Code identifier") override val id: String? = null, + @Description("Hierarchical structure and data") override val content_list: List? = null, + ) : ParsingModel.DocumentData - data class CodeContent( - @Description("Content type, e.g. function, class, comment") override val type: String = "", - @Description("Brief, self-contained text either copied, paraphrased, or summarized") override val text: String? = null, - @Description("Sub-elements") override val content_list: List? = null, - @Description("Tags - related topics and non-entity indexing") override val tags: List? = null - ) : ParsingModel.ContentData + data class CodeContent( + @Description("Content type, e.g. function, class, comment") override val type: String = "", + @Description("Brief, self-contained text either copied, paraphrased, or summarized") override val text: String? = null, + @Description("Sub-elements") override val content_list: List? = null, + @Description("Tags - related topics and non-entity indexing") override val tags: List? = null + ) : ParsingModel.ContentData - companion object { - val log = org.slf4j.LoggerFactory.getLogger(CodeParsingModel::class.java) - } + companion object { + val log = org.slf4j.LoggerFactory.getLogger(CodeParsingModel::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentParserApp.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentParserApp.kt index 60b6c51f..ce6dcf91 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentParserApp.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentParserApp.kt @@ -22,140 +22,140 @@ import javax.imageio.ImageIO import kotlin.math.min open class DocumentParserApp( - applicationName: String = "Document Extractor", - path: String = "/pdfExtractor", - val api: API = ChatClient(), - val parsingModel: ParsingModel, - val reader: (File) -> DocumentReader = { - when { - it.name.endsWith(".pdf", ignoreCase = true) -> PDFReader(it) - else -> TextReader(it) - } - }, - val fileInputs: List? = null, + applicationName: String = "Document Extractor", + path: String = "/pdfExtractor", + val api: API = ChatClient(), + val parsingModel: ParsingModel, + val reader: (File) -> DocumentReader = { + when { + it.name.endsWith(".pdf", ignoreCase = true) -> PDFReader(it) + else -> TextReader(it) + } + }, + val fileInputs: List? = null, ) : ApplicationServer( - applicationName = applicationName, - path = path, - showMenubar = true + applicationName = applicationName, + path = path, + showMenubar = true ) { - override val singleInput: Boolean = true - override val stickyInput: Boolean = false + override val singleInput: Boolean = true + override val stickyInput: Boolean = false - override fun newSession(user: User?, session: Session): SocketManager { - val socketManager = super.newSession(user, session) - val ui = (socketManager as ApplicationSocketManager).applicationInterface - val settings = getSettings(session, user, Settings::class.java) ?: Settings() - val app = this - if (null == (fileInputs ?: settings.fileInputs)) { - log.info("No file input provided") - } else (fileInputs ?: settings.fileInputs).apply { - val progressBar = progressBar(ui.newTask()) - socketManager.pool.submit { - run( - mainTask = ui.newTask(), - ui = ui, - fileInputs = (app.fileInputs ?: settings.fileInputs?.map { File(it).toPath() } ?: error("File input not provided")), - maxPages = settings.maxPages.coerceAtMost(Int.MAX_VALUE), - settings = settings, - pagesPerBatch = settings.pagesPerBatch, - progressBar = progressBar, - ) - } - } - return socketManager + override fun newSession(user: User?, session: Session): SocketManager { + val socketManager = super.newSession(user, session) + val ui = (socketManager as ApplicationSocketManager).applicationInterface + val settings = getSettings(session, user, Settings::class.java) ?: Settings() + val app = this + if (null == (fileInputs ?: settings.fileInputs)) { + log.info("No file input provided") + } else (fileInputs ?: settings.fileInputs).apply { + val progressBar = progressBar(ui.newTask()) + socketManager.pool.submit { + run( + mainTask = ui.newTask(), + ui = ui, + fileInputs = (app.fileInputs ?: settings.fileInputs?.map { File(it).toPath() } ?: error("File input not provided")), + maxPages = settings.maxPages.coerceAtMost(Int.MAX_VALUE), + settings = settings, + pagesPerBatch = settings.pagesPerBatch, + progressBar = progressBar, + ) + } } + return socketManager + } - override fun userMessage(session: Session, user: User?, userMessage: String, ui: ApplicationInterface, api: API) { - val settings = getSettings(session, user, Settings::class.java) ?: Settings() - ui.socketManager!!.pool.submit { - run( - mainTask = ui.newTask(), - ui = ui, - fileInputs = (this.fileInputs ?: settings.fileInputs?.map { File(it).toPath() } ?: error("File input not provided")), - maxPages = settings.maxPages.coerceAtMost(Int.MAX_VALUE), - settings = settings, - pagesPerBatch = settings.pagesPerBatch, - ) - } + override fun userMessage(session: Session, user: User?, userMessage: String, ui: ApplicationInterface, api: API) { + val settings = getSettings(session, user, Settings::class.java) ?: Settings() + ui.socketManager!!.pool.submit { + run( + mainTask = ui.newTask(), + ui = ui, + fileInputs = (this.fileInputs ?: settings.fileInputs?.map { File(it).toPath() } ?: error("File input not provided")), + maxPages = settings.maxPages.coerceAtMost(Int.MAX_VALUE), + settings = settings, + pagesPerBatch = settings.pagesPerBatch, + ) } + } - private fun run( - mainTask: SessionTask, - ui: ApplicationInterface, - fileInputs: List, - maxPages: Int, - settings: Settings, - pagesPerBatch: Int, - progressBar: ProgressState? = null - ) { - try { - mainTask.header("PDF Extractor") - val api = (api as ChatClient).getChildClient().apply { - val createFile = mainTask.createFile(".logs/api-${UUID.randomUUID()}.log") + private fun run( + mainTask: SessionTask, + ui: ApplicationInterface, + fileInputs: List, + maxPages: Int, + settings: Settings, + pagesPerBatch: Int, + progressBar: ProgressState? = null + ) { + try { + mainTask.header("PDF Extractor") + val api = (api as ChatClient).getChildClient().apply { + val createFile = mainTask.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + mainTask.verbose("API log: $this") + } + } + val docTabs = TabbedDisplay(mainTask) + fileInputs.map { it.toFile() }.forEach { file -> + if (!file.exists()) { + mainTask.error(ui, IllegalArgumentException("File not found: $file")) + return + } + ui.socketManager?.pool?.submit { + val docTask = ui.newTask(false).apply { docTabs[file.toString()] = this.placeholder } + val pageTabs = TabbedDisplay(docTask) + val outputDir = root.resolve("output").apply { mkdirs() } + reader(file).use { reader -> + var previousPageText = "" // Keep this for context + val pageCount = minOf(reader.getPageCount(), maxPages) + val pageSets = 0 until pageCount step pagesPerBatch + progressBar?.add(0.0, pageCount.toDouble()) + val futures = pageSets.toList().mapNotNull { batchStart -> + val pageTask = ui.newTask(false) + val api = api.getChildClient().apply { + val createFile = pageTask.createFile(".logs/api-${UUID.randomUUID()}.log") createFile.second?.apply { - logStreams += this.outputStream().buffered() - mainTask.verbose("API log: $this") + logStreams += this.outputStream().buffered() + pageTask.verbose("API log: $this") } - } - val docTabs = TabbedDisplay(mainTask) - fileInputs.map { it.toFile() }.forEach { file -> - if (!file.exists()) { - mainTask.error(ui, IllegalArgumentException("File not found: $file")) - return + } + try { + val batchEnd = min(batchStart + pagesPerBatch, pageCount) + val text = reader.getText(batchStart, batchEnd) + val label = if ((batchStart + 1) != batchEnd) "Pages ${batchStart}-${batchEnd}" else "Page ${batchStart}" + val pageTabs = TabbedDisplay(pageTask.apply { pageTabs[label] = placeholder }) + if (settings.showImages) { + for (pageIndex in batchStart until batchEnd) { + val image = reader.renderImage(pageIndex, settings.dpi) + ui.newTask(false).apply { + pageTabs["Image ${1 + (pageIndex - batchStart)}"] = placeholder + image(image) + } + if (settings.saveImageFiles) { + val imageFile = + outputDir.resolve("page_${pageIndex}.${settings.outputFormat.lowercase(Locale.getDefault())}") + when (settings.outputFormat.uppercase(Locale.getDefault())) { + "PNG" -> ImageIO.write(image, "PNG", imageFile) + "JPEG", "JPG" -> ImageIO.write(image, "JPEG", imageFile) + "GIF" -> ImageIO.write(image, "GIF", imageFile) + "BMP" -> ImageIO.write(image, "BMP", imageFile) + else -> throw IllegalArgumentException("Unsupported output format: ${settings.outputFormat}") + } + } + } } - ui.socketManager?.pool?.submit { - val docTask = ui.newTask(false).apply { docTabs[file.toString()] = this.placeholder } - val pageTabs = TabbedDisplay(docTask) - val outputDir = root.resolve("output").apply { mkdirs() } - reader(file).use { reader -> - var previousPageText = "" // Keep this for context - val pageCount = minOf(reader.getPageCount(), maxPages) - val pageSets = 0 until pageCount step pagesPerBatch - progressBar?.add(0.0, pageCount.toDouble()) - val futures = pageSets.toList().mapNotNull { batchStart -> - val pageTask = ui.newTask(false) - val api = api.getChildClient().apply { - val createFile = pageTask.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - pageTask.verbose("API log: $this") - } - } - try { - val batchEnd = min(batchStart + pagesPerBatch, pageCount) - val text = reader.getText(batchStart, batchEnd) - val label = if ((batchStart + 1) != batchEnd) "Pages ${batchStart}-${batchEnd}" else "Page ${batchStart}" - val pageTabs = TabbedDisplay(pageTask.apply { pageTabs[label] = placeholder }) - if (settings.showImages) { - for (pageIndex in batchStart until batchEnd) { - val image = reader.renderImage(pageIndex, settings.dpi) - ui.newTask(false).apply { - pageTabs["Image ${1 + (pageIndex - batchStart)}"] = placeholder - image(image) - } - if (settings.saveImageFiles) { - val imageFile = - outputDir.resolve("page_${pageIndex}.${settings.outputFormat.lowercase(Locale.getDefault())}") - when (settings.outputFormat.uppercase(Locale.getDefault())) { - "PNG" -> ImageIO.write(image, "PNG", imageFile) - "JPEG", "JPG" -> ImageIO.write(image, "JPEG", imageFile) - "GIF" -> ImageIO.write(image, "GIF", imageFile) - "BMP" -> ImageIO.write(image, "BMP", imageFile) - else -> throw IllegalArgumentException("Unsupported output format: ${settings.outputFormat}") - } - } - } - } - if (text.isBlank()) { - pageTask.error(ui, IllegalArgumentException("No text extracted from pages $batchStart to $batchEnd")) - return@mapNotNull null - } - if (settings.saveTextFiles) { - outputDir.resolve("pages_${batchStart}_to_${batchEnd}_text.txt").writeText(text) - } - val promptList = mutableListOf() - promptList.add( - """ + if (text.isBlank()) { + pageTask.error(ui, IllegalArgumentException("No text extracted from pages $batchStart to $batchEnd")) + return@mapNotNull null + } + if (settings.saveTextFiles) { + outputDir.resolve("pages_${batchStart}_to_${batchEnd}_text.txt").writeText(text) + } + val promptList = mutableListOf() + promptList.add( + """ |# Prior Text | |FOR INFORMATIVE CONTEXT ONLY. DO NOT COPY TO OUTPUT. @@ -163,66 +163,70 @@ open class DocumentParserApp( |$previousPageText |``` |""".trimMargin() - ) - promptList.add( - """ + ) + promptList.add( + """ |# Current Page | |```text |$text |``` """.trimMargin() - ) - previousPageText = text - ui.socketManager.pool.submit { - try { - val jsonResult = parsingModel.getParser(api)(promptList.toList().joinToString("\n\n")) - if (settings.saveTextFiles) { - val jsonFile = outputDir.resolve("pages_${batchStart}_to_${batchEnd}_content.json") - jsonFile.writeText(JsonUtil.toJson(jsonResult)) - } - ui.newTask(false).apply { - pageTabs["Text"] = placeholder - add( - MarkdownUtil.renderMarkdown( - "\n```text\n${ - text - }\n```\n", ui = ui - ) - ) - } - ui.newTask(false).apply { - pageTabs["JSON"] = placeholder - add( - MarkdownUtil.renderMarkdown( - "\n```json\n${ - JsonUtil.toJson(jsonResult) - }\n```\n", ui = ui - ) - ) - } - jsonResult - } catch (e: Throwable) { - pageTask.error(ui, e) - null - } finally { - progressBar?.add(1.0, 0.0) - pageTask.complete() - } - } - } catch (e: Throwable) { - pageTask.error(ui, e) - null - } - }.toTypedArray() - val finalDocument = futures.mapNotNull { try { it.get() } catch (e : Throwable) { - mainTask.error(ui, e) - null - } }.fold(parsingModel.newDocument()) - { runningDocument, it -> parsingModel.merge(runningDocument, it) } - docTask.add( - MarkdownUtil.renderMarkdown( - """ + ) + previousPageText = text + ui.socketManager.pool.submit { + try { + val jsonResult = parsingModel.getParser(api)(promptList.toList().joinToString("\n\n")) + if (settings.saveTextFiles) { + val jsonFile = outputDir.resolve("pages_${batchStart}_to_${batchEnd}_content.json") + jsonFile.writeText(JsonUtil.toJson(jsonResult)) + } + ui.newTask(false).apply { + pageTabs["Text"] = placeholder + add( + MarkdownUtil.renderMarkdown( + "\n```text\n${ + text + }\n```\n", ui = ui + ) + ) + } + ui.newTask(false).apply { + pageTabs["JSON"] = placeholder + add( + MarkdownUtil.renderMarkdown( + "\n```json\n${ + JsonUtil.toJson(jsonResult) + }\n```\n", ui = ui + ) + ) + } + jsonResult + } catch (e: Throwable) { + pageTask.error(ui, e) + null + } finally { + progressBar?.add(1.0, 0.0) + pageTask.complete() + } + } + } catch (e: Throwable) { + pageTask.error(ui, e) + null + } + }.toTypedArray() + val finalDocument = futures.mapNotNull { + try { + it.get() + } catch (e: Throwable) { + mainTask.error(ui, e) + null + } + }.fold(parsingModel.newDocument()) + { runningDocument, it -> parsingModel.merge(runningDocument, it) } + docTask.add( + MarkdownUtil.renderMarkdown( + """ |## Document JSON | |```json @@ -231,51 +235,51 @@ open class DocumentParserApp( | |Extracted files are saved in: ${outputDir.absolutePath} """.trimMargin(), ui = ui - ) - ) - if (settings.saveFinalJson) { - val finalJsonFile = root.resolve(file.name.reversed().split(delimiters = arrayOf("."), false, 2)[1].reversed() + ".parsed.json") - finalJsonFile.writeText(JsonUtil.toJson(finalDocument)) - docTask.add( - MarkdownUtil.renderMarkdown( - "Final JSON saved to: ${finalJsonFile.absolutePath}", - ui = ui - ) - ) - } - } - } + ) + ) + if (settings.saveFinalJson) { + val finalJsonFile = file.parentFile.resolve(file.name.reversed().split(delimiters = arrayOf("."), false, 2).joinToString("_").reversed() + ".parsed.json") + finalJsonFile.writeText(JsonUtil.toJson(finalDocument)) + docTask.add( + MarkdownUtil.renderMarkdown( + "Final JSON saved to: ${finalJsonFile.absolutePath}", + ui = ui + ) + ) } - } catch (e: Throwable) { - mainTask.error(ui, e) + } } + } + } catch (e: Throwable) { + mainTask.error(ui, e) } + } - data class Settings( - val dpi: Float = 120f, - val maxPages: Int = Int.MAX_VALUE, - val outputFormat: String = "PNG", - val fileInputs: List? = null, - val showImages: Boolean = true, - val pagesPerBatch: Int = 1, - val saveImageFiles: Boolean = false, - val saveTextFiles: Boolean = false, - val saveFinalJson: Boolean = true - ) + data class Settings( + val dpi: Float = 120f, + val maxPages: Int = Int.MAX_VALUE, + val outputFormat: String = "PNG", + val fileInputs: List? = null, + val showImages: Boolean = true, + val pagesPerBatch: Int = 1, + val saveImageFiles: Boolean = false, + val saveTextFiles: Boolean = false, + val saveFinalJson: Boolean = true + ) - override val settingsClass: Class<*> get() = Settings::class.java + override val settingsClass: Class<*> get() = Settings::class.java - @Suppress("UNCHECKED_CAST") - override fun initSettings(session: Session): T = Settings() as T + @Suppress("UNCHECKED_CAST") + override fun initSettings(session: Session): T = Settings() as T - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(DocumentParserApp::class.java) - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(DocumentParserApp::class.java) + } - interface DocumentReader : AutoCloseable { - fun getPageCount(): Int - fun getText(startPage: Int, endPage: Int): String - fun renderImage(pageIndex: Int, dpi: Float): BufferedImage - } + interface DocumentReader : AutoCloseable { + fun getPageCount(): Int + fun getText(startPage: Int, endPage: Int): String + fun renderImage(pageIndex: Int, dpi: Float): BufferedImage + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentParsingModel.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentParsingModel.kt index a91c64fd..91368ecc 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentParsingModel.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentParsingModel.kt @@ -13,44 +13,44 @@ import java.util.concurrent.Future open class DocumentParsingModel( - private val parsingModel: ChatModel, - private val temperature: Double + private val parsingModel: ChatModel, + private val temperature: Double ) : ParsingModel { - override fun merge( - runningDocument: ParsingModel.DocumentData, - newData: ParsingModel.DocumentData - ): ParsingModel.DocumentData { - val runningDocument = runningDocument as DocumentData - val newData = newData as DocumentData - return DocumentData( - id = newData.id ?: runningDocument.id, - content_list = mergeContent(runningDocument.content_list, newData.content_list).takeIf { it.isNotEmpty() }, - ) - } - - protected open fun mergeContent( - existingContent: List?, - newContent: List? - ): List { - val mergedContent = (existingContent ?: emptyList()).toMutableList() - (newContent ?: emptyList()).forEach { newItem -> - val existingIndex = mergedContent.indexOfFirst { it.type == newItem.type && it.text?.trim() == newItem.text?.trim() } - if (existingIndex != -1) { - mergedContent[existingIndex] = mergeContentData(mergedContent[existingIndex], newItem) - } else { - mergedContent.add(newItem) - } - } - return mergedContent + override fun merge( + runningDocument: ParsingModel.DocumentData, + newData: ParsingModel.DocumentData + ): ParsingModel.DocumentData { + val runningDocument = runningDocument as DocumentData + val newData = newData as DocumentData + return DocumentData( + id = newData.id ?: runningDocument.id, + content_list = mergeContent(runningDocument.content_list, newData.content_list).takeIf { it.isNotEmpty() }, + ) + } + + protected open fun mergeContent( + existingContent: List?, + newContent: List? + ): List { + val mergedContent = (existingContent ?: emptyList()).toMutableList() + (newContent ?: emptyList()).forEach { newItem -> + val existingIndex = mergedContent.indexOfFirst { it.type == newItem.type && it.text?.trim() == newItem.text?.trim() } + if (existingIndex != -1) { + mergedContent[existingIndex] = mergeContentData(mergedContent[existingIndex], newItem) + } else { + mergedContent.add(newItem) + } } + return mergedContent + } - protected open fun mergeContentData(existing: ContentData, new: ContentData) = existing.copy( - content_list = mergeContent(existing.content_list, new.content_list).takeIf { it.isNotEmpty() }, - tags = ((existing.tags ?: emptyList()) + (new.tags ?: emptyList())).distinct().takeIf { it.isNotEmpty() } - ) + protected open fun mergeContentData(existing: ContentData, new: ContentData) = existing.copy( + content_list = mergeContent(existing.content_list, new.content_list).takeIf { it.isNotEmpty() }, + tags = ((existing.tags ?: emptyList()) + (new.tags ?: emptyList())).distinct().takeIf { it.isNotEmpty() } + ) - open val promptSuffix = """ + open val promptSuffix = """ |Parse the text into a hierarchical structure: |1. Separate the content into sections, paragraphs, statements, etc. |2. All source content should be included in the output, with paraphrasing, corrections, and context as needed @@ -58,79 +58,79 @@ open class DocumentParsingModel( |4. Assign relevant tags to each node to improve searchability and categorization. """.trimMargin() - open val exampleInstance = DocumentData() - - override fun getParser(api: API): (String) -> DocumentData { - val parser = ParsedActor( - resultClass = DocumentData::class.java, - exampleInstance = exampleInstance, - prompt = "", - parsingModel = parsingModel, - temperature = temperature - ).getParser( - api, promptSuffix = promptSuffix + open val exampleInstance = DocumentData() + + override fun getParser(api: API): (String) -> DocumentData { + val parser = ParsedActor( + resultClass = DocumentData::class.java, + exampleInstance = exampleInstance, + prompt = "", + parsingModel = parsingModel, + temperature = temperature + ).getParser( + api, promptSuffix = promptSuffix + ) + return { text -> parser.apply(text) } + } + + override fun newDocument() = DocumentData() + + data class DocumentData( + @Description("Document/Page identifier") override val id: String? = null, + @Description("Hierarchical structure and data") override val content_list: List? = null, + ) : ParsingModel.DocumentData + + data class ContentData( + @Description("Content type, e.g. heading, paragraph, statement, list") override val type: String = "", + @Description("Brief, self-contained text either copied, paraphrased, or summarized") override val text: String? = null, + @Description("Sub-elements") override val content_list: List? = null, + @Description("Tags - related topics and non-entity indexing") override val tags: List? = null + ) : ParsingModel.ContentData + + companion object { + val log = org.slf4j.LoggerFactory.getLogger(DocumentParsingModel::class.java) + + fun getRows( + inputPath: String, + progressState: ProgressState?, + futureList: MutableList>, + pool: ExecutorService, + openAIClient: OpenAIClient, + fileData: Map? + ): MutableList { + val records: MutableList = mutableListOf() + fun processContent(content: Map, path: String = "") { + val record = DocumentRecord( + text = content["text"] as? String, + metadata = JsonUtil.toJson(content.filter { it.key != "text" && it.key != "content" && it.key != "type" }), + sourcePath = inputPath, + jsonPath = path, + vector = null ) - return { text -> parser.apply(text) } - } - - override fun newDocument() = DocumentData() - - data class DocumentData( - @Description("Document/Page identifier") override val id: String? = null, - @Description("Hierarchical structure and data") override val content_list: List? = null, - ) : ParsingModel.DocumentData - - data class ContentData( - @Description("Content type, e.g. heading, paragraph, statement, list") override val type: String = "", - @Description("Brief, self-contained text either copied, paraphrased, or summarized") override val text: String? = null, - @Description("Sub-elements") override val content_list: List? = null, - @Description("Tags - related topics and non-entity indexing") override val tags: List? = null - ) : ParsingModel.ContentData - - companion object { - val log = org.slf4j.LoggerFactory.getLogger(DocumentParsingModel::class.java) - - fun getRows( - inputPath: String, - progressState: ProgressState?, - futureList: MutableList>, - pool: ExecutorService, - openAIClient: OpenAIClient, - fileData: Map? - ): MutableList { - val records: MutableList = mutableListOf() - fun processContent(content: Map, path: String = "") { - val record = DocumentRecord( - text = content["text"] as? String, - metadata = JsonUtil.toJson(content.filter { it.key != "text" && it.key != "content" && it.key != "type" }), - sourcePath = inputPath, - jsonPath = path, - vector = null - ) - records.add(record) - if (record.text != null) { - progressState?.add(0.0, 1.0) - futureList.add(pool.submit { - record.vector = openAIClient.createEmbedding( - ApiModel.EmbeddingRequest( - EmbeddingModels.Large.modelName, record.text - ) - ).data[0].embedding ?: DoubleArray(0) - progressState?.add(1.0, 0.0) - }) - } - (content["content_list"] as? List>)?.forEachIndexed> { index, childContent -> - processContent(childContent, "$path.content_list[$index]") - } - } - fileData?.get("content_list")?.let { contentList -> - (contentList as? List>)?.forEachIndexed> { index, content -> - processContent(content, "content_list[$index]") - } - } - return records + records.add(record) + if (record.text != null) { + progressState?.add(0.0, 1.0) + futureList.add(pool.submit { + record.vector = openAIClient.createEmbedding( + ApiModel.EmbeddingRequest( + EmbeddingModels.Large.modelName, record.text + ) + ).data[0].embedding ?: DoubleArray(0) + progressState?.add(1.0, 0.0) + }) } - + (content["content_list"] as? List>)?.forEachIndexed> { index, childContent -> + processContent(childContent, "$path.content_list[$index]") + } + } + fileData?.get("content_list")?.let { contentList -> + (contentList as? List>)?.forEachIndexed> { index, content -> + processContent(content, "content_list[$index]") + } + } + return records } + } + } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentRecord.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentRecord.kt index 3e5c6bf8..8fcf44a8 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentRecord.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/DocumentRecord.kt @@ -14,98 +14,97 @@ import java.util.concurrent.Future import java.util.concurrent.TimeUnit data class DocumentRecord( - val text: String?, - val metadata: String?, - val sourcePath: String, - val jsonPath: String, - var vector: DoubleArray?, + val text: String?, + val metadata: String?, + val sourcePath: String, + val jsonPath: String, + var vector: DoubleArray?, ) : Serializable { - @Throws(IOException::class) - fun writeObject(out: ObjectOutputStream) { - out.writeUTF(text ?: "") - out.writeUTF(metadata ?: "") - out.writeUTF(sourcePath) - out.writeUTF(jsonPath) - out.writeObject(vector) - } + @Throws(IOException::class) + fun writeObject(out: ObjectOutputStream) { + out.writeUTF(text ?: "") + out.writeUTF(metadata ?: "") + out.writeUTF(sourcePath) + out.writeUTF(jsonPath) + out.writeObject(vector) + } - @Throws(IOException::class, ClassNotFoundException::class) - fun readObject(input: ObjectInputStream): DocumentRecord { - val text = input.readUTF().let { if (it.isEmpty()) null else it } - val metadata = input.readUTF().let { if (it.isEmpty()) null else it } - val sourcePath = input.readUTF() - val jsonPath = input.readUTF() - val vector = input.readObject() as DoubleArray? - return DocumentRecord( - text, - metadata, - sourcePath, - jsonPath, - vector - ) - } + @Throws(IOException::class, ClassNotFoundException::class) + fun readObject(input: ObjectInputStream): DocumentRecord { + val text = input.readUTF().let { if (it.isEmpty()) null else it } + val metadata = input.readUTF().let { if (it.isEmpty()) null else it } + val sourcePath = input.readUTF() + val jsonPath = input.readUTF() + val vector = input.readObject() as DoubleArray? + return DocumentRecord( + text, + metadata, + sourcePath, + jsonPath, + vector + ) + } - companion object { - val log = org.slf4j.LoggerFactory.getLogger(DocumentRecord::class.java) + companion object { + val log = org.slf4j.LoggerFactory.getLogger(DocumentRecord::class.java) - fun saveAsBinary( - openAIClient: OpenAIClient, - pool: ExecutorService, - progressState: ProgressState? = null, - vararg inputPaths: String, - ) { - inputPaths.forEach { inputPath -> - val futureList = mutableListOf>() - val infile = File(inputPath) - val fileData = JsonUtil.fromJson(infile.readText(), Map::class.java) as T as? Map - val records = DocumentParsingModel.getRows(inputPath, progressState, futureList, pool, openAIClient, fileData) - val outputPath = infile.parentFile.resolve(infile.name.split("\\.".toRegex(), 2).first() + ".index.data").absolutePath - awaitAll(futureList.toTypedArray()) - writeBinary(outputPath, records) - } - } + fun saveAsBinary( + openAIClient: OpenAIClient, + pool: ExecutorService, + progressState: ProgressState? = null, + vararg inputPaths: String, + ) = inputPaths.map { inputPath -> + val futureList = mutableListOf>() + val infile = File(inputPath) + val fileData = JsonUtil.fromJson>(infile.readText(), Map::class.java) + val records = DocumentParsingModel.getRows(inputPath, progressState, futureList, pool, openAIClient, fileData) + val outputPath = infile.parentFile.resolve(infile.name.split("\\.".toRegex(), 2).first() + ".index.data").absolutePath + awaitAll(futureList.toTypedArray()) + writeBinary(outputPath, records) + outputPath + } - fun awaitAll(futureList: Array>) { - val start = System.currentTimeMillis() - for (future in futureList) { - try { - future.get( - TimeUnit.MINUTES.toMillis(5) - (System.currentTimeMillis() - start), - TimeUnit.MILLISECONDS - ) - } catch (e: Exception) { - log.error("Error processing entity", e) - } - } + fun awaitAll(futureList: Array>) { + val start = System.currentTimeMillis() + for (future in futureList) { + try { + future.get( + TimeUnit.MINUTES.toMillis(5) - (System.currentTimeMillis() - start), + TimeUnit.MILLISECONDS + ) + } catch (e: Exception) { + log.error("Error processing entity", e) } + } + } - private fun writeBinary(outputPath: String, records: List) { - log.info("Writing ${records.size} records to $outputPath") - ObjectOutputStream(FileOutputStream(outputPath)).use { out -> - out.writeInt(records.size) - records.forEach { it.writeObject(out) } - } - } + private fun writeBinary(outputPath: String, records: List) { + log.info("Writing ${records.size} records to $outputPath") + ObjectOutputStream(FileOutputStream(outputPath)).use { out -> + out.writeInt(records.size) + records.forEach { it.writeObject(out) } + } + } - fun readBinary(inputPath: String): List { - val records = mutableListOf() - ObjectInputStream(FileInputStream(inputPath)).use { input -> - val size = input.readInt() - repeat(size) { - records.add( - DocumentRecord( - text = null, - metadata = null, - sourcePath = "", - jsonPath = "", - vector = DoubleArray(0) - ).readObject(input) - ) - } - } - return records + fun readBinary(inputPath: String): List { + val records = mutableListOf() + ObjectInputStream(FileInputStream(inputPath)).use { input -> + val size = input.readInt() + repeat(size) { + records.add( + DocumentRecord( + text = null, + metadata = null, + sourcePath = "", + jsonPath = "", + vector = DoubleArray(0) + ).readObject(input) + ) } + } + return records } + } } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/PDFReader.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/PDFReader.kt index 6efd64dc..26fab490 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/PDFReader.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/PDFReader.kt @@ -7,23 +7,23 @@ import java.awt.image.BufferedImage import java.io.File class PDFReader(pdfFile: File) : DocumentParserApp.DocumentReader { - private val document: PDDocument = PDDocument.load(pdfFile) - private val renderer: PDFRenderer = PDFRenderer(document) + private val document: PDDocument = PDDocument.load(pdfFile) + private val renderer: PDFRenderer = PDFRenderer(document) - override fun getPageCount(): Int = document.numberOfPages + override fun getPageCount(): Int = document.numberOfPages - override fun getText(startPage: Int, endPage: Int): String { - val stripper = PDFTextStripper().apply { sortByPosition = true } // Not to be confused with the stripper from last night - stripper.startPage = startPage+1 - stripper.endPage = endPage+1 - return stripper.getText(document) - } + override fun getText(startPage: Int, endPage: Int): String { + val stripper = PDFTextStripper().apply { sortByPosition = true } // Not to be confused with the stripper from last night + stripper.startPage = startPage + 1 + stripper.endPage = endPage + 1 + return stripper.getText(document) + } - override fun renderImage(pageIndex: Int, dpi: Float): BufferedImage { - return renderer.renderImageWithDPI(pageIndex, dpi) - } + override fun renderImage(pageIndex: Int, dpi: Float): BufferedImage { + return renderer.renderImageWithDPI(pageIndex, dpi) + } - override fun close() { - document.close() - } + override fun close() { + document.close() + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/ParsingModel.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/ParsingModel.kt index f2c6c4e0..e382e031 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/ParsingModel.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/ParsingModel.kt @@ -3,25 +3,25 @@ package com.simiacryptus.skyenet.apps.parse import com.simiacryptus.jopenai.API interface ParsingModel { - fun merge(runningDocument: DocumentData, newData: DocumentData): DocumentData - fun getParser(api: API): (String) -> DocumentData - fun newDocument(): DocumentData + fun merge(runningDocument: DocumentData, newData: DocumentData): DocumentData + fun getParser(api: API): (String) -> DocumentData + fun newDocument(): DocumentData - interface DocumentMetadata - interface ContentData { - val type: String - val text: String? - val content_list: List? - val tags: List? - } + interface DocumentMetadata + interface ContentData { + val type: String + val text: String? + val content_list: List? + val tags: List? + } - interface DocumentData { - val id: String? - val content_list: List? + interface DocumentData { + val id: String? + val content_list: List? // val metadata: DocumentMetadata? - } + } - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(ParsingModel::class.java) - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(ParsingModel::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/TextReader.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/TextReader.kt index 5d4d75d2..c55fa0af 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/TextReader.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/parse/TextReader.kt @@ -4,38 +4,40 @@ import java.awt.image.BufferedImage import java.io.File class TextReader(private val textFile: File) : DocumentParserApp.DocumentReader { - private val pages: List = splitIntoPages(textFile.readLines().joinToString("\n")) - - override fun getPageCount(): Int = pages.size - - override fun getText(startPage: Int, endPage: Int): String { - return pages.subList(startPage, endPage.coerceAtMost(pages.size)).joinToString("\n") - } - - override fun renderImage(pageIndex: Int, dpi: Float): BufferedImage { - throw UnsupportedOperationException("Text files do not support image rendering") - } - - override fun close() { - // No resources to close for text files - } - - private fun splitIntoPages(text: String, maxChars: Int = 16000): List { - if (text.length <= maxChars) return listOf(text) - val lines = text.split("\n") - if (lines.size <= 1) return listOf(text) - val splitFitnesses = lines.indices.map { i -> - val leftSize = lines.subList(0, i).joinToString("\n").length - val rightSize = lines.subList(i, lines.size).joinToString("\n").length - if (leftSize <= 0) return@map i to Double.MAX_VALUE - if (rightSize <= 0) return@map i to Double.MAX_VALUE - val fitness = -((leftSize.toDouble() / text.length) * Math.log1p(rightSize.toDouble() / text.length) + - (rightSize.toDouble() / text.length) * Math.log1p(leftSize.toDouble() / text.length)) - i to fitness.toDouble() - }.toTypedArray() - var bestSplitIndex = splitFitnesses.minByOrNull { it.second }?.first ?: lines.size / 2 - val leftText = lines.subList(0, bestSplitIndex).joinToString("\n") - val rightText = lines.subList(bestSplitIndex, lines.size).joinToString("\n") - return splitIntoPages(leftText, maxChars) + splitIntoPages(rightText, maxChars) - } + private val pages: List = splitIntoPages(textFile.readLines().joinToString("\n")) + + override fun getPageCount(): Int = pages.size + + override fun getText(startPage: Int, endPage: Int): String { + return pages.subList(startPage, endPage.coerceAtMost(pages.size)).joinToString("\n") + } + + override fun renderImage(pageIndex: Int, dpi: Float): BufferedImage { + throw UnsupportedOperationException("Text files do not support image rendering") + } + + override fun close() { + // No resources to close for text files + } + + private fun splitIntoPages(text: String, maxChars: Int = 16000): List { + if (text.length <= maxChars) return listOf(text) + val lines = text.split("\n") + if (lines.size <= 1) return listOf(text) + val splitFitnesses = lines.indices.map { i -> + val leftSize = lines.subList(0, i).joinToString("\n").length + val rightSize = lines.subList(i, lines.size).joinToString("\n").length + if (leftSize <= 0) return@map i to Double.MAX_VALUE + if (rightSize <= 0) return@map i to Double.MAX_VALUE + var fitness = -((leftSize.toDouble() / text.length) * Math.log1p(rightSize.toDouble() / text.length) + + (rightSize.toDouble() / text.length) * Math.log1p(leftSize.toDouble() / text.length)) + if (lines[i].isEmpty()) fitness *= 2 + i to fitness.toDouble() + }.toTypedArray().toMutableList() + + var bestSplitIndex = splitFitnesses.minByOrNull { it.second }?.first ?: lines.size / 2 + val leftText = lines.subList(0, bestSplitIndex).joinToString("\n") + val rightText = lines.subList(bestSplitIndex, lines.size).joinToString("\n") + return splitIntoPages(leftText, maxChars) + splitIntoPages(rightText, maxChars) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/AbstractTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/AbstractTask.kt index 00f86f98..adfe1519 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/AbstractTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/AbstractTask.kt @@ -10,60 +10,60 @@ import java.nio.file.Path abstract class AbstractTask( - val planSettings: PlanSettings, - val planTask: T? + val planSettings: PlanSettings, + val planTask: T? ) { - var state: TaskState? = TaskState.Pending - protected val codeFiles = mutableMapOf() + var state: TaskState? = TaskState.Pending + protected val codeFiles = mutableMapOf() - protected open val root: Path - get() = planSettings.workingDir?.let { File(it).toPath() } - ?: throw IllegalStateException("Working directory not set") + protected open val root: Path + get() = planSettings.workingDir?.let { File(it).toPath() } + ?: throw IllegalStateException("Working directory not set") - enum class TaskState { - Pending, - InProgress, - Completed, - } + enum class TaskState { + Pending, + InProgress, + Completed, + } - open fun getPriorCode(planProcessingState: PlanProcessingState) = - planTask?.task_dependencies?.joinToString("\n\n\n") { dependency -> - """ + open fun getPriorCode(planProcessingState: PlanProcessingState) = + planTask?.task_dependencies?.joinToString("\n\n\n") { dependency -> + """ |# $dependency | |${planProcessingState.taskResult[dependency] ?: ""} """.trimMargin() - } ?: "" + } ?: "" - protected fun acceptButtonFooter(ui: ApplicationInterface, fn: () -> Unit): String { - val footerTask = ui.newTask(false) - lateinit var textHandle: StringBuilder - textHandle = footerTask.complete(ui.hrefLink("Accept", classname = "href-link cmd-button") { - try { - textHandle.set("""
Accepted
""") - footerTask.complete() - } catch (e: Throwable) { - log.warn("Error", e) - } - fn() - })!! - return footerTask.placeholder - } + protected fun acceptButtonFooter(ui: ApplicationInterface, fn: () -> Unit): String { + val footerTask = ui.newTask(false) + lateinit var textHandle: StringBuilder + textHandle = footerTask.complete(ui.hrefLink("Accept", classname = "href-link cmd-button") { + try { + textHandle.set("""
Accepted
""") + footerTask.complete() + } catch (e: Throwable) { + log.warn("Error", e) + } + fn() + })!! + return footerTask.placeholder + } - abstract fun promptSegment(): String + abstract fun promptSegment(): String - abstract fun run( - agent: PlanCoordinator, - messages: List = listOf(), - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings, - ) + abstract fun run( + agent: PlanCoordinator, + messages: List = listOf(), + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings, + ) - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(AbstractTask::class.java) - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(AbstractTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/CommandAutoFixTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/CommandAutoFixTask.kt index a1b821d9..9038f428 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/CommandAutoFixTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/CommandAutoFixTask.kt @@ -15,32 +15,32 @@ import java.util.concurrent.Semaphore import java.util.concurrent.atomic.AtomicBoolean class CommandAutoFixTask( - planSettings: PlanSettings, - planTask: CommandAutoFixTaskData? + planSettings: PlanSettings, + planTask: CommandAutoFixTaskData? ) : AbstractTask(planSettings, planTask) { - class CommandAutoFixTaskData( - @Description("The commands to be executed with their respective working directories") - val commands: List? = null, - task_description: String? = null, - task_dependencies: List? = null, - state: TaskState? = null - ) : PlanTaskBase( - task_type = TaskType.CommandAutoFix.name, - task_description = task_description, - task_dependencies = task_dependencies, - state = state - ) + class CommandAutoFixTaskData( + @Description("The commands to be executed with their respective working directories") + val commands: List? = null, + task_description: String? = null, + task_dependencies: List? = null, + state: TaskState? = null + ) : PlanTaskBase( + task_type = TaskType.CommandAutoFix.name, + task_description = task_description, + task_dependencies = task_dependencies, + state = state + ) - data class CommandWithWorkingDir( - @Description("The command to be executed") - val command: List, - @Description("The relative path of the working directory") - val workingDir: String? = null - ) + data class CommandWithWorkingDir( + @Description("The command to be executed") + val command: List, + @Description("The relative path of the working directory") + val workingDir: String? = null + ) - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ CommandAutoFix - Run a command and automatically fix any issues that arise ** Specify the commands to be executed along with their working directories ** Each command's working directory should be specified relative to the root directory @@ -49,102 +49,102 @@ CommandAutoFix - Run a command and automatically fix any issues that arise ** Available commands: ${planSettings.commandAutoFixCommands?.joinToString("\n") { " * ${File(it).name}" }} """.trim() - } + } - override fun run( - agent: PlanCoordinator, - messages: List, - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings - ) { - val semaphore = Semaphore(0) - val hasError = AtomicBoolean(false) - val onComplete = { semaphore.release() } - Retryable(agent.ui, task = task) { - val task = agent.ui.newTask(false).apply { it.append(placeholder) } - this.planTask?.commands?.forEachIndexed { index, commandWithDir -> - val alias = commandWithDir.command.firstOrNull() - val commandAutoFixCommands = agent.planSettings.commandAutoFixCommands - val cmds = commandAutoFixCommands - ?.map { File(it) }?.associateBy { it.name } - ?.filterKeys { it.startsWith(alias ?: "") } - ?: emptyMap() - var executable = cmds.entries.firstOrNull()?.value - executable = executable ?: alias?.let { root.resolve(it).toFile() } - if (executable == null) { - throw IllegalArgumentException("Command not found: $alias") - } - val workingDirectory = (commandWithDir.workingDir - ?.let { agent.root.toFile().resolve(it) } ?: agent.root.toFile()) - .apply { mkdirs() } - val outputResult = CmdPatchApp( - root = agent.root, - session = agent.session, - settings = PatchApp.Settings( - executable = executable, - arguments = commandWithDir.command.drop(1).joinToString(" "), - workingDirectory = workingDirectory, - exitCodeOption = "nonzero", - additionalInstructions = "", - autoFix = agent.planSettings.autoFix - ), - api = api as ChatClient, - files = agent.files, - model = agent.planSettings.getTaskSettings(TaskType.valueOf(planTask.task_type!!)).model - ?: agent.planSettings.defaultModel, - ).run( - ui = agent.ui, - task = task - ) - if (outputResult.exitCode != 0) { - hasError.set(true) - } - task.add(MarkdownUtil.renderMarkdown("## Command Auto Fix Result for Command ${index + 1}\n", ui = agent.ui, tabs = false)) - task.add( - if (outputResult.exitCode == 0) { - if (agent.planSettings.autoFix) { - MarkdownUtil.renderMarkdown("Auto-applied Command Auto Fix\n", ui = agent.ui, tabs = false) - } else { - MarkdownUtil.renderMarkdown( - "Command Auto Fix Result\n", - ui = agent.ui, tabs = false - ) - } - } else { - MarkdownUtil.renderMarkdown( - "Command Auto Fix Failed\n", - ui = agent.ui, tabs = false - ) - } - ) - } - resultFn("All Command Auto Fix tasks completed") - task.add(if (!hasError.get()) { - onComplete() - MarkdownUtil.renderMarkdown("## All Command Auto Fix tasks completed successfully\n", ui = agent.ui, tabs = false) - } else { - MarkdownUtil.renderMarkdown( - "## Some Command Auto Fix tasks failed\n", - ui = agent.ui - ) + acceptButtonFooter( - agent.ui - ) { - onComplete() - } - }) - task.placeholder + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + val semaphore = Semaphore(0) + val hasError = AtomicBoolean(false) + val onComplete = { semaphore.release() } + Retryable(agent.ui, task = task) { + val task = agent.ui.newTask(false).apply { it.append(placeholder) } + this.planTask?.commands?.forEachIndexed { index, commandWithDir -> + val alias = commandWithDir.command.firstOrNull() + val commandAutoFixCommands = agent.planSettings.commandAutoFixCommands + val cmds = commandAutoFixCommands + ?.map { File(it) }?.associateBy { it.name } + ?.filterKeys { it.startsWith(alias ?: "") } + ?: emptyMap() + var executable = cmds.entries.firstOrNull()?.value + executable = executable ?: alias?.let { root.resolve(it).toFile() } + if (executable == null) { + throw IllegalArgumentException("Command not found: $alias") + } + val workingDirectory = (commandWithDir.workingDir + ?.let { agent.root.toFile().resolve(it) } ?: agent.root.toFile()) + .apply { mkdirs() } + val outputResult = CmdPatchApp( + root = agent.root, + session = agent.session, + settings = PatchApp.Settings( + executable = executable, + arguments = commandWithDir.command.drop(1).joinToString(" "), + workingDirectory = workingDirectory, + exitCodeOption = "nonzero", + additionalInstructions = "", + autoFix = agent.planSettings.autoFix + ), + api = api as ChatClient, + files = agent.files, + model = agent.planSettings.getTaskSettings(TaskType.valueOf(planTask.task_type!!)).model + ?: agent.planSettings.defaultModel, + ).run( + ui = agent.ui, + task = task + ) + if (outputResult.exitCode != 0) { + hasError.set(true) } - try { - semaphore.acquire() - } catch (e: Throwable) { - log.warn("Error", e) + task.add(MarkdownUtil.renderMarkdown("## Command Auto Fix Result for Command ${index + 1}\n", ui = agent.ui, tabs = false)) + task.add( + if (outputResult.exitCode == 0) { + if (agent.planSettings.autoFix) { + MarkdownUtil.renderMarkdown("Auto-applied Command Auto Fix\n", ui = agent.ui, tabs = false) + } else { + MarkdownUtil.renderMarkdown( + "Command Auto Fix Result\n", + ui = agent.ui, tabs = false + ) + } + } else { + MarkdownUtil.renderMarkdown( + "Command Auto Fix Failed\n", + ui = agent.ui, tabs = false + ) + } + ) + } + resultFn("All Command Auto Fix tasks completed") + task.add(if (!hasError.get()) { + onComplete() + MarkdownUtil.renderMarkdown("## All Command Auto Fix tasks completed successfully\n", ui = agent.ui, tabs = false) + } else { + MarkdownUtil.renderMarkdown( + "## Some Command Auto Fix tasks failed\n", + ui = agent.ui + ) + acceptButtonFooter( + agent.ui + ) { + onComplete() } + }) + task.placeholder } - - companion object { - private val log = LoggerFactory.getLogger(CommandAutoFixTask::class.java) + try { + semaphore.acquire() + } catch (e: Throwable) { + log.warn("Error", e) } + } + + companion object { + private val log = LoggerFactory.getLogger(CommandAutoFixTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/ForeachTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/ForeachTask.kt index 2e31ef83..adb409d7 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/ForeachTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/ForeachTask.kt @@ -10,73 +10,73 @@ import com.simiacryptus.skyenet.webui.session.SessionTask import org.slf4j.LoggerFactory class ForeachTask( - planSettings: PlanSettings, - planTask: ForeachTaskData? + planSettings: PlanSettings, + planTask: ForeachTaskData? ) : AbstractTask(planSettings, planTask) { - class ForeachTaskData( - @Description("A list of items over which the ForEach task will iterate. (Only applicable for ForeachTask tasks) Can be used to process outputs from previous tasks.") - val foreach_items: List? = null, - @Description("A map of sub-task IDs to PlanTask objects to be executed for each item. (Only applicable for ForeachTask tasks) Allows for complex task dependencies and information flow within iterations.") - val foreach_subplan: Map? = null, - task_description: String? = null, - task_dependencies: List? = null, - state: TaskState? = null, - ) : PlanTaskBase( - task_type = TaskType.ForeachTask.name, - task_description = task_description, - task_dependencies = task_dependencies, - state = state - ) + class ForeachTaskData( + @Description("A list of items over which the ForEach task will iterate. (Only applicable for ForeachTask tasks) Can be used to process outputs from previous tasks.") + val foreach_items: List? = null, + @Description("A map of sub-task IDs to PlanTask objects to be executed for each item. (Only applicable for ForeachTask tasks) Allows for complex task dependencies and information flow within iterations.") + val foreach_subplan: Map? = null, + task_description: String? = null, + task_dependencies: List? = null, + state: TaskState? = null, + ) : PlanTaskBase( + task_type = TaskType.ForeachTask.name, + task_description = task_description, + task_dependencies = task_dependencies, + state = state + ) - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ ForeachTask - Execute a task for each item in a list ** Specify the list of items to iterate over ** Define the task to be executed for each item """.trimIndent() - } + } - override fun run( - agent: PlanCoordinator, - messages: List, - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings - ) { - val userMessage = messages.joinToString("\n") - val items = - planTask?.foreach_items ?: throw RuntimeException("No items specified for ForeachTask") - val subTasks = planTask.foreach_subplan ?: throw RuntimeException("No subTasks specified for ForeachTask") - val subPlanTask = agent.ui.newTask(false) - task.add(subPlanTask.placeholder) + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + val userMessage = messages.joinToString("\n") + val items = + planTask?.foreach_items ?: throw RuntimeException("No items specified for ForeachTask") + val subTasks = planTask.foreach_subplan ?: throw RuntimeException("No subTasks specified for ForeachTask") + val subPlanTask = agent.ui.newTask(false) + task.add(subPlanTask.placeholder) - items.forEachIndexed { index, item -> - val itemSubTasks = subTasks.mapValues { (_, subTaskPlan) -> - subTaskPlan.task_description = "${subTaskPlan.task_description} - Item $index: $item" - subTaskPlan - } - val itemPlanProcessingState = PlanProcessingState(itemSubTasks) - agent.executePlan( - task = subPlanTask, - diagramBuffer = subPlanTask.add(diagram(agent.ui, itemPlanProcessingState.subTasks)), - subTasks = itemSubTasks, - diagramTask = subPlanTask, - planProcessingState = itemPlanProcessingState, - taskIdProcessingQueue = executionOrder(itemSubTasks).toMutableList(), - pool = agent.pool, - userMessage = "$userMessage\nProcessing item $index: $item", - plan = itemSubTasks, - api = api, - api2 = api2, - ) - } - subPlanTask.complete("Completed ForeachTask for ${items.size} items") + items.forEachIndexed { index, item -> + val itemSubTasks = subTasks.mapValues { (_, subTaskPlan) -> + subTaskPlan.task_description = "${subTaskPlan.task_description} - Item $index: $item" + subTaskPlan + } + val itemPlanProcessingState = PlanProcessingState(itemSubTasks) + agent.executePlan( + task = subPlanTask, + diagramBuffer = subPlanTask.add(diagram(agent.ui, itemPlanProcessingState.subTasks)), + subTasks = itemSubTasks, + diagramTask = subPlanTask, + planProcessingState = itemPlanProcessingState, + taskIdProcessingQueue = executionOrder(itemSubTasks).toMutableList(), + pool = agent.pool, + userMessage = "$userMessage\nProcessing item $index: $item", + plan = itemSubTasks, + api = api, + api2 = api2, + ) } + subPlanTask.complete("Completed ForeachTask for ${items.size} items") + } - companion object { - private val log = LoggerFactory.getLogger(ForeachTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(ForeachTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/GoogleSearchTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/GoogleSearchTask.kt index 31237901..74a4aa66 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/GoogleSearchTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/GoogleSearchTask.kt @@ -1,19 +1,19 @@ package com.simiacryptus.skyenet.apps.plan +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.kotlin.readValue +import com.simiacryptus.jopenai.ChatClient +import com.simiacryptus.jopenai.OpenAIClient import com.simiacryptus.jopenai.describe.Description import com.simiacryptus.skyenet.util.MarkdownUtil import com.simiacryptus.skyenet.webui.session.SessionTask +import com.simiacryptus.util.JsonUtil import org.slf4j.LoggerFactory import java.net.URI import java.net.URLEncoder import java.net.http.HttpClient import java.net.http.HttpRequest import java.net.http.HttpResponse -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.module.kotlin.readValue -import com.simiacryptus.jopenai.ChatClient -import com.simiacryptus.jopenai.OpenAIClient -import com.simiacryptus.util.JsonUtil class GoogleSearchTask( planSettings: PlanSettings, @@ -58,7 +58,8 @@ GoogleSearch - Search Google for web results private fun performGoogleSearch(planSettings: PlanSettings): String { val client = HttpClient.newBuilder().build() val encodedQuery = URLEncoder.encode(planTask?.search_query, "UTF-8") - val uriBuilder = "https://www.googleapis.com/customsearch/v1?key=${planSettings.googleApiKey}&cx=${planSettings.googleSearchEngineId}&q=$encodedQuery&num=${planTask?.num_results}" + val uriBuilder = + "https://www.googleapis.com/customsearch/v1?key=${planSettings.googleApiKey}&cx=${planSettings.googleSearchEngineId}&q=$encodedQuery&num=${planTask?.num_results}" val request = HttpRequest.newBuilder().uri(URI.create(uriBuilder)).GET().build() val response = client.send(request, HttpResponse.BodyHandlers.ofString()) if (response.statusCode() != 200) { diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanCoordinator.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanCoordinator.kt index 4e432ca4..0cc5f131 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanCoordinator.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanCoordinator.kt @@ -28,48 +28,48 @@ import java.util.concurrent.ThreadPoolExecutor import java.util.concurrent.TimeUnit class PlanCoordinator( - val user: User?, - val session: Session, - val dataStorage: StorageInterface, - val ui: ApplicationInterface, - val planSettings: PlanSettings, - val root: Path + val user: User?, + val session: Session, + val dataStorage: StorageInterface, + val ui: ApplicationInterface, + val planSettings: PlanSettings, + val root: Path ) { - val pool: ThreadPoolExecutor by lazy { ApplicationServices.clientManager.getPool(session, user) } + val pool: ThreadPoolExecutor by lazy { ApplicationServices.clientManager.getPool(session, user) } - val files: Array by lazy { - FileValidationUtils.expandFileList(root.toFile()) - } + val files: Array by lazy { + FileValidationUtils.expandFileList(root.toFile()) + } - val codeFiles: Map - get() = files - .filter { it.exists() && it.isFile } - .filter { !it.name.startsWith(".") } - .associate { file -> - root.relativize(file.toPath()) to try { - file.inputStream().bufferedReader().use { it.readText() } - } catch (e: Exception) { - log.warn("Error reading file", e) - "" - } - } + val codeFiles: Map + get() = files + .filter { it.exists() && it.isFile } + .filter { !it.name.startsWith(".") } + .associate { file -> + root.relativize(file.toPath()) to try { + file.inputStream().bufferedReader().use { it.readText() } + } catch (e: Exception) { + log.warn("Error reading file", e) + "" + } + } - fun executeTaskBreakdownWithPrompt( - jsonInput: String, - api: API, - api2: OpenAIClient, - task: SessionTask - ) { - try { - lateinit var taskBreakdownWithPrompt: TaskBreakdownWithPrompt - val plan = filterPlan { - taskBreakdownWithPrompt = JsonUtil.fromJson(jsonInput, TaskBreakdownWithPrompt::class.java) - taskBreakdownWithPrompt.plan - } - task.add( - MarkdownUtil.renderMarkdown( - """ + fun executeTaskBreakdownWithPrompt( + jsonInput: String, + api: API, + api2: OpenAIClient, + task: SessionTask + ) { + try { + lateinit var taskBreakdownWithPrompt: TaskBreakdownWithPrompt + val plan = filterPlan { + taskBreakdownWithPrompt = JsonUtil.fromJson(jsonInput, TaskBreakdownWithPrompt::class.java) + taskBreakdownWithPrompt.plan + } + task.add( + MarkdownUtil.renderMarkdown( + """ |## Executing TaskBreakdownWithPrompt |Prompt: ${taskBreakdownWithPrompt.prompt} |Plan Text: @@ -77,158 +77,158 @@ class PlanCoordinator( |${taskBreakdownWithPrompt.planText} |``` """.trimMargin(), ui = ui - ) - ) - executePlan(plan ?: emptyMap(), task, taskBreakdownWithPrompt.prompt, api, api2) - } catch (e: Exception) { - task.error(ui, e) - } + ) + ) + executePlan(plan ?: emptyMap(), task, taskBreakdownWithPrompt.prompt, api, api2) + } catch (e: Exception) { + task.error(ui, e) } + } - fun executePlan( - plan: Map, - task: SessionTask, - userMessage: String, - api: API, - api2: OpenAIClient, - ): PlanProcessingState { - val api = (api as ChatClient).getChildClient().apply { - val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - task.verbose("API log: $this") - } - } - val planProcessingState = newState(plan) - try { - val diagramTask = ui.newTask(false).apply { task.add(placeholder) } - executePlan( - task = task, - diagramBuffer = diagramTask.add( - MarkdownUtil.renderMarkdown( - "## Task Dependency Graph\n${TRIPLE_TILDE}mermaid\n${buildMermaidGraph(planProcessingState.subTasks)}\n$TRIPLE_TILDE", - ui = ui - ) - ), - subTasks = planProcessingState.subTasks, - diagramTask = diagramTask, - planProcessingState = planProcessingState, - taskIdProcessingQueue = planProcessingState.taskIdProcessingQueue, - pool = pool, - userMessage = userMessage, - plan = plan, - api = api, - api2 = api2, - ) - } catch (e: Throwable) { - log.warn("Error during incremental code generation process", e) - task.error(ui, e) - } - return planProcessingState + fun executePlan( + plan: Map, + task: SessionTask, + userMessage: String, + api: API, + api2: OpenAIClient, + ): PlanProcessingState { + val api = (api as ChatClient).getChildClient().apply { + val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + task.verbose("API log: $this") + } + } + val planProcessingState = newState(plan) + try { + val diagramTask = ui.newTask(false).apply { task.add(placeholder) } + executePlan( + task = task, + diagramBuffer = diagramTask.add( + MarkdownUtil.renderMarkdown( + "## Task Dependency Graph\n${TRIPLE_TILDE}mermaid\n${buildMermaidGraph(planProcessingState.subTasks)}\n$TRIPLE_TILDE", + ui = ui + ) + ), + subTasks = planProcessingState.subTasks, + diagramTask = diagramTask, + planProcessingState = planProcessingState, + taskIdProcessingQueue = planProcessingState.taskIdProcessingQueue, + pool = pool, + userMessage = userMessage, + plan = plan, + api = api, + api2 = api2, + ) + } catch (e: Throwable) { + log.warn("Error during incremental code generation process", e) + task.error(ui, e) } + return planProcessingState + } - private fun newState(plan: Map) = - PlanProcessingState( - subTasks = (filterPlan { plan }?.entries?.toTypedArray>() - ?.associate { it.key to it.value } ?: mapOf()).toMutableMap() - ) + private fun newState(plan: Map) = + PlanProcessingState( + subTasks = (filterPlan { plan }?.entries?.toTypedArray>() + ?.associate { it.key to it.value } ?: mapOf()).toMutableMap() + ) - fun executePlan( - task: SessionTask, - diagramBuffer: StringBuilder?, - subTasks: Map, - diagramTask: SessionTask, - planProcessingState: PlanProcessingState, - taskIdProcessingQueue: MutableList, - pool: ThreadPoolExecutor, - userMessage: String, - plan: Map, - api: API, - api2: OpenAIClient, - ) { - val sessionTask = ui.newTask(false).apply { task.add(placeholder) } - val api = (api as ChatClient).getChildClient().apply { - val createFile = sessionTask.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - sessionTask.verbose("API log: $this") - } - } - val taskTabs = object : TabbedDisplay(sessionTask) { - override fun renderTabButtons(): String { - diagramBuffer?.set( - MarkdownUtil.renderMarkdown( - """ + fun executePlan( + task: SessionTask, + diagramBuffer: StringBuilder?, + subTasks: Map, + diagramTask: SessionTask, + planProcessingState: PlanProcessingState, + taskIdProcessingQueue: MutableList, + pool: ThreadPoolExecutor, + userMessage: String, + plan: Map, + api: API, + api2: OpenAIClient, + ) { + val sessionTask = ui.newTask(false).apply { task.add(placeholder) } + val api = (api as ChatClient).getChildClient().apply { + val createFile = sessionTask.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + sessionTask.verbose("API log: $this") + } + } + val taskTabs = object : TabbedDisplay(sessionTask) { + override fun renderTabButtons(): String { + diagramBuffer?.set( + MarkdownUtil.renderMarkdown( + """ |## Task Dependency Graph |${TRIPLE_TILDE}mermaid |${buildMermaidGraph(subTasks)} |$TRIPLE_TILDE """.trimMargin(), ui = ui - ) - ) - diagramTask.complete() - return buildString { - append("
\n") - super.tabs.withIndex().forEach { (idx, t) -> - val (taskId, taskV) = t - val subTask = planProcessingState.tasksByDescription[taskId] - if (null == subTask) { - log.warn("Task tab not found: $taskId") - } - val isChecked = if (taskId in taskIdProcessingQueue) "checked" else "" - val style = when (subTask?.state) { - AbstractTask.TaskState.Completed -> " style='text-decoration: line-through;'" - null -> " style='opacity: 20%;'" - AbstractTask.TaskState.Pending -> " style='opacity: 30%;'" - else -> "" - } - append("
\n") - } - append("
") - } + ) + ) + diagramTask.complete() + return buildString { + append("
\n") + super.tabs.withIndex().forEach { (idx, t) -> + val (taskId, taskV) = t + val subTask = planProcessingState.tasksByDescription[taskId] + if (null == subTask) { + log.warn("Task tab not found: $taskId") } + val isChecked = if (taskId in taskIdProcessingQueue) "checked" else "" + val style = when (subTask?.state) { + AbstractTask.TaskState.Completed -> " style='text-decoration: line-through;'" + null -> " style='opacity: 20%;'" + AbstractTask.TaskState.Pending -> " style='opacity: 30%;'" + else -> "" + } + append("
\n") + } + append("
") } - taskIdProcessingQueue.forEach { taskId -> - val newTask = ui.newTask(false) - planProcessingState.uitaskMap[taskId] = newTask - val subtask = planProcessingState.subTasks[taskId] - val description = subtask?.task_description - log.debug("Creating task tab: $taskId ${System.identityHashCode(subtask)} $description") - taskTabs[description ?: taskId] = newTask.placeholder + } + } + taskIdProcessingQueue.forEach { taskId -> + val newTask = ui.newTask(false) + planProcessingState.uitaskMap[taskId] = newTask + val subtask = planProcessingState.subTasks[taskId] + val description = subtask?.task_description + log.debug("Creating task tab: $taskId ${System.identityHashCode(subtask)} $description") + taskTabs[description ?: taskId] = newTask.placeholder + } + Thread.sleep(100) + while (taskIdProcessingQueue.isNotEmpty()) { + val taskId = taskIdProcessingQueue.removeAt(0) + val subTask = planProcessingState.subTasks[taskId] ?: throw RuntimeException("Task not found: $taskId") + planProcessingState.taskFutures[taskId] = pool.submit { + subTask.state = AbstractTask.TaskState.Pending + log.debug("Awaiting dependencies: ${subTask.task_dependencies?.joinToString(", ") ?: ""}") + subTask.task_dependencies + ?.associate { it to planProcessingState.taskFutures[it] } + ?.forEach { (id, future) -> + try { + future?.get() ?: log.warn("Dependency not found: $id") + } catch (e: Throwable) { + log.warn("Error", e) + } + } + subTask.state = AbstractTask.TaskState.InProgress + taskTabs.update() + log.debug("Running task: ${System.identityHashCode(subTask)} ${subTask.task_description}") + val task1 = planProcessingState.uitaskMap.get(taskId) ?: ui.newTask(false).apply { + taskTabs[taskId] = placeholder } - Thread.sleep(100) - while (taskIdProcessingQueue.isNotEmpty()) { - val taskId = taskIdProcessingQueue.removeAt(0) - val subTask = planProcessingState.subTasks[taskId] ?: throw RuntimeException("Task not found: $taskId") - planProcessingState.taskFutures[taskId] = pool.submit { - subTask.state = AbstractTask.TaskState.Pending - log.debug("Awaiting dependencies: ${subTask.task_dependencies?.joinToString(", ") ?: ""}") - subTask.task_dependencies - ?.associate { it to planProcessingState.taskFutures[it] } - ?.forEach { (id, future) -> - try { - future?.get() ?: log.warn("Dependency not found: $id") - } catch (e: Throwable) { - log.warn("Error", e) - } - } - subTask.state = AbstractTask.TaskState.InProgress - taskTabs.update() - log.debug("Running task: ${System.identityHashCode(subTask)} ${subTask.task_description}") - val task1 = planProcessingState.uitaskMap.get(taskId) ?: ui.newTask(false).apply { - taskTabs[taskId] = placeholder - } - try { - val dependencies = subTask.task_dependencies?.toMutableSet() ?: mutableSetOf() - dependencies += getAllDependencies( - subPlanTask = subTask, - subTasks = planProcessingState.subTasks, - visited = mutableSetOf() - ) + try { + val dependencies = subTask.task_dependencies?.toMutableSet() ?: mutableSetOf() + dependencies += getAllDependencies( + subPlanTask = subTask, + subTasks = planProcessingState.subTasks, + visited = mutableSetOf() + ) - task1.add( - MarkdownUtil.renderMarkdown( - """ + task1.add( + MarkdownUtil.renderMarkdown( + """ |## Task `${taskId}` |${subTask.task_description ?: ""} | @@ -240,70 +240,70 @@ class PlanCoordinator( |${dependencies.joinToString("\n") { "- $it" }} | """.trimMargin(), ui = ui - ) - ) - val api = api.getChildClient().apply { - val createFile = task1.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - task1.verbose("API log: $this") - } - } - val impl = getImpl(planSettings, subTask) - val messages = listOf( - userMessage, - JsonUtil.toJson(plan), - impl.getPriorCode(planProcessingState) - ) - impl.run( - agent = this, - messages = messages, - task = task1, - api = api, - api2 = api2, - resultFn = { planProcessingState.taskResult[taskId] = it }, - planSettings = planSettings - ) - } catch (e: Throwable) { - log.warn("Error during task execution", e) - task1.error(ui, e) - } finally { - planProcessingState.completedTasks.add(element = taskId) - subTask.state = AbstractTask.TaskState.Completed - log.debug("Completed task: $taskId ${System.identityHashCode(subTask)}") - taskTabs.update() - } + ) + ) + val api = api.getChildClient().apply { + val createFile = task1.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + task1.verbose("API log: $this") } + } + val impl = getImpl(planSettings, subTask) + val messages = listOf( + userMessage, + JsonUtil.toJson(plan), + impl.getPriorCode(planProcessingState) + ) + impl.run( + agent = this, + messages = messages, + task = task1, + api = api, + api2 = api2, + resultFn = { planProcessingState.taskResult[taskId] = it }, + planSettings = planSettings + ) + } catch (e: Throwable) { + log.warn("Error during task execution", e) + task1.error(ui, e) + } finally { + planProcessingState.completedTasks.add(element = taskId) + subTask.state = AbstractTask.TaskState.Completed + log.debug("Completed task: $taskId ${System.identityHashCode(subTask)}") + taskTabs.update() } - await(planProcessingState.taskFutures) + } } + await(planProcessingState.taskFutures) + } - fun await(futures: MutableMap>) { - val start = System.currentTimeMillis() - while (futures.values.count { it.isDone } < futures.size && (System.currentTimeMillis() - start) < TimeUnit.MINUTES.toMillis(2)) { - Thread.sleep(1000) - } + fun await(futures: MutableMap>) { + val start = System.currentTimeMillis() + while (futures.values.count { it.isDone } < futures.size && (System.currentTimeMillis() - start) < TimeUnit.MINUTES.toMillis(2)) { + Thread.sleep(1000) } + } - fun copy( - user: User? = this.user, - session: Session = this.session, - dataStorage: StorageInterface = this.dataStorage, - ui: ApplicationInterface = this.ui, - planSettings: PlanSettings = this.planSettings, - root: Path = this.root - ) = PlanCoordinator( - user = user, - session = session, - dataStorage = dataStorage, - ui = ui, - planSettings = planSettings, - root = root - ) + fun copy( + user: User? = this.user, + session: Session = this.session, + dataStorage: StorageInterface = this.dataStorage, + ui: ApplicationInterface = this.ui, + planSettings: PlanSettings = this.planSettings, + root: Path = this.root + ) = PlanCoordinator( + user = user, + session = session, + dataStorage = dataStorage, + ui = ui, + planSettings = planSettings, + root = root + ) - companion object : Planner() { - private val log = LoggerFactory.getLogger(PlanCoordinator::class.java) - } + companion object : Planner() { + private val log = LoggerFactory.getLogger(PlanCoordinator::class.java) + } } const val TRIPLE_TILDE = "```" \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanProcessingState.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanProcessingState.kt index 2888b31d..700cd3f7 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanProcessingState.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanProcessingState.kt @@ -4,12 +4,12 @@ import com.simiacryptus.skyenet.webui.session.SessionTask import java.util.concurrent.Future data class PlanProcessingState( - val subTasks: Map, - val tasksByDescription: MutableMap = subTasks.entries.toTypedArray() - .associate { it.value.task_description to it.value }.toMutableMap(), - val taskIdProcessingQueue: MutableList = PlanUtil.executionOrder(subTasks).toMutableList(), - val taskResult: MutableMap = mutableMapOf(), - val completedTasks: MutableList = mutableListOf(), - val taskFutures: MutableMap> = mutableMapOf(), - val uitaskMap: MutableMap = mutableMapOf() + val subTasks: Map, + val tasksByDescription: MutableMap = subTasks.entries.toTypedArray() + .associate { it.value.task_description to it.value }.toMutableMap(), + val taskIdProcessingQueue: MutableList = PlanUtil.executionOrder(subTasks).toMutableList(), + val taskResult: MutableMap = mutableMapOf(), + val completedTasks: MutableList = mutableListOf(), + val taskFutures: MutableMap> = mutableMapOf(), + val uitaskMap: MutableMap = mutableMapOf() ) \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanSettings.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanSettings.kt index 5c457d73..525c1e85 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanSettings.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanSettings.kt @@ -2,7 +2,6 @@ package com.simiacryptus.skyenet.apps.plan import com.simiacryptus.jopenai.describe.AbbrevWhitelistYamlDescriber import com.simiacryptus.jopenai.models.ChatModel -import com.simiacryptus.jopenai.models.TextModel import com.simiacryptus.skyenet.apps.plan.CommandAutoFixTask.CommandAutoFixTaskData import com.simiacryptus.skyenet.apps.plan.PlanUtil.isWindows import com.simiacryptus.skyenet.apps.plan.PlanningTask.PlanningTaskData @@ -14,76 +13,76 @@ import com.simiacryptus.skyenet.core.actors.ParsedActor data class TaskSettings( - var enabled: Boolean = false, - var model: ChatModel? = null + var enabled: Boolean = false, + var model: ChatModel? = null ) open class PlanSettings( - var defaultModel: ChatModel, - var parsingModel: ChatModel, - val command: List = listOf(if (isWindows) "powershell" else "bash"), - var temperature: Double = 0.2, - val budget: Double = 2.0, - val taskSettings: MutableMap = TaskType.values().associateWith { taskType -> - TaskSettings( - when (taskType) { - TaskType.FileModification, TaskType.Inquiry -> true - else -> false - } - ) - }.mapKeys { it.key.name }.toMutableMap(), - var autoFix: Boolean = false, - var allowBlocking: Boolean = true, - var commandAutoFixCommands: List? = listOf(), - val env: Map? = mapOf(), - val workingDir: String? = ".", - val language: String? = if (isWindows) "powershell" else "bash", - var githubToken: String? = null, - var googleApiKey: String? = null, - var googleSearchEngineId: String? = null, + var defaultModel: ChatModel, + var parsingModel: ChatModel, + val command: List = listOf(if (isWindows) "powershell" else "bash"), + var temperature: Double = 0.2, + val budget: Double = 2.0, + val taskSettings: MutableMap = TaskType.values().associateWith { taskType -> + TaskSettings( + when (taskType) { + TaskType.FileModification, TaskType.Inquiry -> true + else -> false + } + ) + }.mapKeys { it.key.name }.toMutableMap(), + var autoFix: Boolean = false, + var allowBlocking: Boolean = true, + var commandAutoFixCommands: List? = listOf(), + val env: Map? = mapOf(), + val workingDir: String? = ".", + val language: String? = if (isWindows) "powershell" else "bash", + var githubToken: String? = null, + var googleApiKey: String? = null, + var googleSearchEngineId: String? = null, ) { - fun getTaskSettings(taskType: TaskType<*>): TaskSettings = - taskSettings[taskType.name] ?: TaskSettings() + fun getTaskSettings(taskType: TaskType<*>): TaskSettings = + taskSettings[taskType.name] ?: TaskSettings() - fun setTaskSettings(taskType: TaskType<*>, settings: TaskSettings) { - taskSettings[taskType.name] = settings - } + fun setTaskSettings(taskType: TaskType<*>, settings: TaskSettings) { + taskSettings[taskType.name] = settings + } - fun copy( - model: ChatModel = this.defaultModel, - parsingModel: ChatModel = this.parsingModel, - command: List = this.command, - temperature: Double = this.temperature, - budget: Double = this.budget, - taskSettings: MutableMap = this.taskSettings, - autoFix: Boolean = this.autoFix, - allowBlocking: Boolean = this.allowBlocking, - commandAutoFixCommands: List? = this.commandAutoFixCommands, - env: Map? = this.env, - workingDir: String? = this.workingDir, - language: String? = this.language, - ) = PlanSettings( - defaultModel = model, - parsingModel = parsingModel, - command = command, - temperature = temperature, - budget = budget, - taskSettings = taskSettings, - autoFix = autoFix, - allowBlocking = allowBlocking, - commandAutoFixCommands = commandAutoFixCommands, - env = env, - workingDir = workingDir, - language = language, - githubToken = this.githubToken, - googleApiKey = this.googleApiKey, - googleSearchEngineId = this.googleSearchEngineId, - ) + fun copy( + model: ChatModel = this.defaultModel, + parsingModel: ChatModel = this.parsingModel, + command: List = this.command, + temperature: Double = this.temperature, + budget: Double = this.budget, + taskSettings: MutableMap = this.taskSettings, + autoFix: Boolean = this.autoFix, + allowBlocking: Boolean = this.allowBlocking, + commandAutoFixCommands: List? = this.commandAutoFixCommands, + env: Map? = this.env, + workingDir: String? = this.workingDir, + language: String? = this.language, + ) = PlanSettings( + defaultModel = model, + parsingModel = parsingModel, + command = command, + temperature = temperature, + budget = budget, + taskSettings = taskSettings, + autoFix = autoFix, + allowBlocking = allowBlocking, + commandAutoFixCommands = commandAutoFixCommands, + env = env, + workingDir = workingDir, + language = language, + githubToken = this.githubToken, + googleApiKey = this.googleApiKey, + googleSearchEngineId = this.googleSearchEngineId, + ) - fun planningActor(): ParsedActor { - val planTaskSettings = this.getTaskSettings(TaskType.TaskPlanning) - val prompt = """ + fun planningActor(): ParsedActor { + val planTaskSettings = this.getTaskSettings(TaskType.TaskPlanning) + val prompt = """ |Given a user request, identify and list smaller, actionable tasks that can be directly implemented in code. | |For each task: @@ -94,78 +93,78 @@ open class PlanSettings( | |Tasks can be of the following types: |${ - getAvailableTaskTypes(this).joinToString("\n") { taskType -> - "* ${getImpl(this, taskType).promptSegment()}" - } - } + getAvailableTaskTypes(this).joinToString("\n") { taskType -> + "* ${getImpl(this, taskType).promptSegment()}" + } + } | |Creating directories and initializing source control are out of scope. |${if (planTaskSettings.enabled) "Do not start your plan with a plan to plan!\n" else ""} """.trimMargin() - val describer = describer() - val parserPrompt = """ + val describer = describer() + val parserPrompt = """ Task Subtype Schema: ${ - getAvailableTaskTypes(this).joinToString("\n\n") { taskType -> - """ + getAvailableTaskTypes(this).joinToString("\n\n") { taskType -> + """ ${taskType.name}: ${describer.describe(taskType.taskDataClass).replace("\n", "\n ")} """.trim() - } - } - """.trimIndent() - return ParsedActor( - name = "TaskBreakdown", - resultClass = TaskBreakdownResult::class.java, - exampleInstance = exampleInstance, - prompt = prompt, - model = planTaskSettings.model ?: this.defaultModel, - parsingModel = this.parsingModel, - temperature = this.temperature, - describer = describer, - parserPrompt = parserPrompt - ) + } } + """.trimIndent() + return ParsedActor( + name = "TaskBreakdown", + resultClass = TaskBreakdownResult::class.java, + exampleInstance = exampleInstance, + prompt = prompt, + model = planTaskSettings.model ?: this.defaultModel, + parsingModel = this.parsingModel, + temperature = this.temperature, + describer = describer, + parserPrompt = parserPrompt + ) + } - open fun describer() = object : AbbrevWhitelistYamlDescriber( - "com.simiacryptus", "com.github.simiacryptus" - ) { - override val includeMethods: Boolean get() = false + open fun describer() = object : AbbrevWhitelistYamlDescriber( + "com.simiacryptus", "com.github.simiacryptus" + ) { + override val includeMethods: Boolean get() = false - override fun getEnumValues(clazz: Class<*>): List { - return if (clazz == TaskType::class.java) { - taskSettings.filter { it.value.enabled }.map { it.key.toString() } - } else { - super.getEnumValues(clazz) - } - } + override fun getEnumValues(clazz: Class<*>): List { + return if (clazz == TaskType::class.java) { + taskSettings.filter { it.value.enabled }.map { it.key.toString() } + } else { + super.getEnumValues(clazz) + } } + } - companion object { - var exampleInstance = TaskBreakdownResult( - tasksByID = mapOf( - "1" to CommandAutoFixTaskData( - task_description = "Task 1", - task_dependencies = listOf(), - commands = listOf( - CommandAutoFixTask.CommandWithWorkingDir( - command = listOf("echo", "Hello, World!"), - workingDir = "." - ) - ) - ), - "2" to FileModificationTaskData( - task_description = "Task 2", - task_dependencies = listOf("1"), - input_files = listOf("input2.txt"), - output_files = listOf("output2.txt"), - ), - "3" to PlanningTaskData( - task_description = "Task 3", - task_dependencies = listOf("2"), - ) - ), + companion object { + var exampleInstance = TaskBreakdownResult( + tasksByID = mapOf( + "1" to CommandAutoFixTaskData( + task_description = "Task 1", + task_dependencies = listOf(), + commands = listOf( + CommandAutoFixTask.CommandWithWorkingDir( + command = listOf("echo", "Hello, World!"), + workingDir = "." + ) + ) + ), + "2" to FileModificationTaskData( + task_description = "Task 2", + task_dependencies = listOf("1"), + input_files = listOf("input2.txt"), + output_files = listOf("output2.txt"), + ), + "3" to PlanningTaskData( + task_description = "Task 3", + task_dependencies = listOf("2"), ) - } + ), + ) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanUtil.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanUtil.kt index 761a9292..4cfcbaeb 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanUtil.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanUtil.kt @@ -11,171 +11,171 @@ import java.util.concurrent.ConcurrentHashMap object PlanUtil { - fun diagram( - ui: ApplicationInterface, - taskMap: Map - ) = MarkdownUtil.renderMarkdown( - """ + fun diagram( + ui: ApplicationInterface, + taskMap: Map + ) = MarkdownUtil.renderMarkdown( + """ |## Sub-Plan Task Dependency Graph |${TRIPLE_TILDE}mermaid |${buildMermaidGraph(taskMap)} |$TRIPLE_TILDE """.trimMargin(), - ui = ui - ) + ui = ui + ) - fun render( - withPrompt: TaskBreakdownWithPrompt, - ui: ApplicationInterface - ) = AgentPatterns.displayMapInTabs( - mapOf( - "Text" to MarkdownUtil.renderMarkdown(withPrompt.planText, ui = ui), - "JSON" to MarkdownUtil.renderMarkdown( - "${TRIPLE_TILDE}json\n${JsonUtil.toJson(withPrompt)}\n$TRIPLE_TILDE", - ui = ui - ), - "Diagram" to MarkdownUtil.renderMarkdown( - "```mermaid\n" + buildMermaidGraph( - (filterPlan { - withPrompt.plan - } ?: emptyMap()).toMutableMap() - ) + "\n```\n", ui = ui - ) - ) + fun render( + withPrompt: TaskBreakdownWithPrompt, + ui: ApplicationInterface + ) = AgentPatterns.displayMapInTabs( + mapOf( + "Text" to MarkdownUtil.renderMarkdown(withPrompt.planText, ui = ui), + "JSON" to MarkdownUtil.renderMarkdown( + "${TRIPLE_TILDE}json\n${JsonUtil.toJson(withPrompt)}\n$TRIPLE_TILDE", + ui = ui + ), + "Diagram" to MarkdownUtil.renderMarkdown( + "```mermaid\n" + buildMermaidGraph( + (filterPlan { + withPrompt.plan + } ?: emptyMap()).toMutableMap() + ) + "\n```\n", ui = ui + ) ) + ) - fun executionOrder(tasks: Map): List { - val taskIds: MutableList = mutableListOf() - val taskMap = tasks.toMutableMap() - while (taskMap.isNotEmpty()) { - val nextTasks = - taskMap.filter { (_, task) -> - task.task_dependencies?.filter { entry -> - entry in tasks.keys - }?.all { taskIds.contains(it) } ?: true - } - if (nextTasks.isEmpty()) { - throw RuntimeException("Circular dependency detected in task breakdown") - } - taskIds.addAll(nextTasks.keys) - nextTasks.keys.forEach { taskMap.remove(it) } + fun executionOrder(tasks: Map): List { + val taskIds: MutableList = mutableListOf() + val taskMap = tasks.toMutableMap() + while (taskMap.isNotEmpty()) { + val nextTasks = + taskMap.filter { (_, task) -> + task.task_dependencies?.filter { entry -> + entry in tasks.keys + }?.all { taskIds.contains(it) } ?: true } - return taskIds + if (nextTasks.isEmpty()) { + throw RuntimeException("Circular dependency detected in task breakdown") + } + taskIds.addAll(nextTasks.keys) + nextTasks.keys.forEach { taskMap.remove(it) } } + return taskIds + } - val isWindows = System.getProperty("os.name").lowercase(Locale.getDefault()).contains("windows") - private fun sanitizeForMermaid(input: String) = input - .replace(" ", "_") - .replace("\"", "\\\"") - .replace("[", "\\[") - .replace("]", "\\]") - .replace("(", "\\(") - .replace(")", "\\)") - .let { "`$it`" } + val isWindows = System.getProperty("os.name").lowercase(Locale.getDefault()).contains("windows") + private fun sanitizeForMermaid(input: String) = input + .replace(" ", "_") + .replace("\"", "\\\"") + .replace("[", "\\[") + .replace("]", "\\]") + .replace("(", "\\(") + .replace(")", "\\)") + .let { "`$it`" } - private fun escapeMermaidCharacters(input: String) = input - .replace("\"", "\\\"") - .let { '"' + it + '"' } + private fun escapeMermaidCharacters(input: String) = input + .replace("\"", "\\\"") + .let { '"' + it + '"' } - // Cache for memoizing buildMermaidGraph results - private val mermaidGraphCache = ConcurrentHashMap() - private val mermaidExceptionCache = ConcurrentHashMap() + // Cache for memoizing buildMermaidGraph results + private val mermaidGraphCache = ConcurrentHashMap() + private val mermaidExceptionCache = ConcurrentHashMap() - fun buildMermaidGraph(subTasks: Map): String { - // Generate a unique key based on the subTasks map - val cacheKey = JsonUtil.toJson(subTasks) - // Return cached result if available - mermaidGraphCache[cacheKey]?.let { return it } - mermaidExceptionCache[cacheKey]?.let { throw it } - try { - val graphBuilder = StringBuilder("graph TD;\n") - subTasks.forEach { (taskId, task) -> - val sanitizedTaskId = sanitizeForMermaid(taskId) - val taskType = task.task_type ?: "Unknown" - val escapedDescription = escapeMermaidCharacters(task.task_description ?: "") - val style = when (task.state) { - TaskState.Completed -> ":::completed" - TaskState.InProgress -> ":::inProgress" - else -> ":::$taskType" - } - graphBuilder.append(" ${sanitizedTaskId}[$escapedDescription]$style;\n") - task.task_dependencies?.forEach { dependency -> - val sanitizedDependency = sanitizeForMermaid(dependency) - graphBuilder.append(" $sanitizedDependency --> ${sanitizedTaskId};\n") - } - } - graphBuilder.append(" classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;\n") - graphBuilder.append(" classDef NewFile fill:lightblue,stroke:#333,stroke-width:2px;\n") - graphBuilder.append(" classDef EditFile fill:lightgreen,stroke:#333,stroke-width:2px;\n") - graphBuilder.append(" classDef Documentation fill:lightyellow,stroke:#333,stroke-width:2px;\n") - graphBuilder.append(" classDef Inquiry fill:orange,stroke:#333,stroke-width:2px;\n") - graphBuilder.append(" classDef TaskPlanning fill:lightgrey,stroke:#333,stroke-width:2px;\n") - graphBuilder.append(" classDef completed fill:#90EE90,stroke:#333,stroke-width:2px;\n") - graphBuilder.append(" classDef inProgress fill:#FFA500,stroke:#333,stroke-width:2px;\n") - val graph = graphBuilder.toString() - mermaidGraphCache[cacheKey] = graph - return graph - } catch (e: Exception) { - mermaidExceptionCache[cacheKey] = e - throw e + fun buildMermaidGraph(subTasks: Map): String { + // Generate a unique key based on the subTasks map + val cacheKey = JsonUtil.toJson(subTasks) + // Return cached result if available + mermaidGraphCache[cacheKey]?.let { return it } + mermaidExceptionCache[cacheKey]?.let { throw it } + try { + val graphBuilder = StringBuilder("graph TD;\n") + subTasks.forEach { (taskId, task) -> + val sanitizedTaskId = sanitizeForMermaid(taskId) + val taskType = task.task_type ?: "Unknown" + val escapedDescription = escapeMermaidCharacters(task.task_description ?: "") + val style = when (task.state) { + TaskState.Completed -> ":::completed" + TaskState.InProgress -> ":::inProgress" + else -> ":::$taskType" } + graphBuilder.append(" ${sanitizedTaskId}[$escapedDescription]$style;\n") + task.task_dependencies?.forEach { dependency -> + val sanitizedDependency = sanitizeForMermaid(dependency) + graphBuilder.append(" $sanitizedDependency --> ${sanitizedTaskId};\n") + } + } + graphBuilder.append(" classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;\n") + graphBuilder.append(" classDef NewFile fill:lightblue,stroke:#333,stroke-width:2px;\n") + graphBuilder.append(" classDef EditFile fill:lightgreen,stroke:#333,stroke-width:2px;\n") + graphBuilder.append(" classDef Documentation fill:lightyellow,stroke:#333,stroke-width:2px;\n") + graphBuilder.append(" classDef Inquiry fill:orange,stroke:#333,stroke-width:2px;\n") + graphBuilder.append(" classDef TaskPlanning fill:lightgrey,stroke:#333,stroke-width:2px;\n") + graphBuilder.append(" classDef completed fill:#90EE90,stroke:#333,stroke-width:2px;\n") + graphBuilder.append(" classDef inProgress fill:#FFA500,stroke:#333,stroke-width:2px;\n") + val graph = graphBuilder.toString() + mermaidGraphCache[cacheKey] = graph + return graph + } catch (e: Exception) { + mermaidExceptionCache[cacheKey] = e + throw e } + } - fun filterPlan(retries: Int = 3, fn: () -> Map?): Map? { - val obj = fn() ?: emptyMap() - val tasksByID = obj.filter { (k, v) -> - when { - v.task_type == TaskType.TaskPlanning.name && v.task_dependencies.isNullOrEmpty() -> - if (retries <= 0) { - log.warn("TaskPlanning task $k has no dependencies: " + JsonUtil.toJson(obj)) - true - } else { - log.info("TaskPlanning task $k has no dependencies") - return filterPlan(retries - 1, fn) - } + fun filterPlan(retries: Int = 3, fn: () -> Map?): Map? { + val obj = fn() ?: emptyMap() + val tasksByID = obj.filter { (k, v) -> + when { + v.task_type == TaskType.TaskPlanning.name && v.task_dependencies.isNullOrEmpty() -> + if (retries <= 0) { + log.warn("TaskPlanning task $k has no dependencies: " + JsonUtil.toJson(obj)) + true + } else { + log.info("TaskPlanning task $k has no dependencies") + return filterPlan(retries - 1, fn) + } - else -> true - } - } - tasksByID.forEach { - it.value.task_dependencies = it.value.task_dependencies?.filter { it in tasksByID.keys } - it.value.state = TaskState.Pending - } - try { - executionOrder(tasksByID) - } catch (e: RuntimeException) { - if (retries <= 0) { - log.warn("Error filtering plan: " + JsonUtil.toJson(obj), e) - throw e - } else { - log.info("Circular dependency detected in task breakdown") - return filterPlan(retries - 1, fn) - } - } - return if (tasksByID.size == obj.size) { - obj - } else filterPlan { - tasksByID - } + else -> true + } + } + tasksByID.forEach { + it.value.task_dependencies = it.value.task_dependencies?.filter { it in tasksByID.keys } + it.value.state = TaskState.Pending } + try { + executionOrder(tasksByID) + } catch (e: RuntimeException) { + if (retries <= 0) { + log.warn("Error filtering plan: " + JsonUtil.toJson(obj), e) + throw e + } else { + log.info("Circular dependency detected in task breakdown") + return filterPlan(retries - 1, fn) + } + } + return if (tasksByID.size == obj.size) { + obj + } else filterPlan { + tasksByID + } + } - fun getAllDependencies( - subPlanTask: PlanTaskBase, - subTasks: Map, - visited: MutableSet - ): List { - val dependencies = subPlanTask.task_dependencies?.toMutableList() ?: mutableListOf() - subPlanTask.task_dependencies?.forEach { dep -> - if (dep in visited) return@forEach - val subTask = subTasks[dep] - if (subTask != null) { - visited.add(dep) - dependencies.addAll(getAllDependencies(subTask, subTasks, visited)) - } - } - return dependencies + fun getAllDependencies( + subPlanTask: PlanTaskBase, + subTasks: Map, + visited: MutableSet + ): List { + val dependencies = subPlanTask.task_dependencies?.toMutableList() ?: mutableListOf() + subPlanTask.task_dependencies?.forEach { dep -> + if (dep in visited) return@forEach + val subTask = subTasks[dep] + if (subTask != null) { + visited.add(dep) + dependencies.addAll(getAllDependencies(subTask, subTasks, visited)) + } } + return dependencies + } - val log = org.slf4j.LoggerFactory.getLogger(PlanUtil::class.java) + val log = org.slf4j.LoggerFactory.getLogger(PlanUtil::class.java) } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/Planner.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/Planner.kt index 106c966d..0ae84425 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/Planner.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/Planner.kt @@ -15,117 +15,117 @@ import java.util.UUID open class Planner { - open fun initialPlan( - codeFiles: Map, - files: Array, - root: Path, - task: SessionTask, - userMessage: String, - ui: ApplicationInterface, - planSettings: PlanSettings, - api: API - ): TaskBreakdownWithPrompt { - val api = (api as ChatClient).getChildClient().apply { - val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - task.verbose("API log: $this") - } - } - val toInput = inputFn(codeFiles, files, root) - return if (planSettings.allowBlocking) - Discussable( - task = task, - heading = MarkdownUtil.renderMarkdown(userMessage, ui = ui), - userMessage = { userMessage }, - initialResponse = { - newPlan( - api, - planSettings, - toInput(userMessage) - ) - }, - outputFn = { - try { - PlanUtil.render( - withPrompt = TaskBreakdownWithPrompt( - prompt = userMessage, - plan = it.obj, - planText = it.text - ), - ui = ui - ) - } catch (e: Throwable) { - log.warn("Error rendering task breakdown", e) - task.error(ui, e) - e.message ?: e.javaClass.simpleName - } - }, - ui = ui, - reviseResponse = { userMessages: List> -> - newPlan( - api, - planSettings, - userMessages.map { it.first }) - }, - ).call().let { - TaskBreakdownWithPrompt( - prompt = userMessage, - plan = PlanUtil.filterPlan { it.obj } ?: emptyMap(), - planText = it.text - ) - } - else newPlan( + open fun initialPlan( + codeFiles: Map, + files: Array, + root: Path, + task: SessionTask, + userMessage: String, + ui: ApplicationInterface, + planSettings: PlanSettings, + api: API + ): TaskBreakdownWithPrompt { + val api = (api as ChatClient).getChildClient().apply { + val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + task.verbose("API log: $this") + } + } + val toInput = inputFn(codeFiles, files, root) + return if (planSettings.allowBlocking) + Discussable( + task = task, + heading = MarkdownUtil.renderMarkdown(userMessage, ui = ui), + userMessage = { userMessage }, + initialResponse = { + newPlan( api, planSettings, toInput(userMessage) - ).let { - TaskBreakdownWithPrompt( + ) + }, + outputFn = { + try { + PlanUtil.render( + withPrompt = TaskBreakdownWithPrompt( prompt = userMessage, - plan = PlanUtil.filterPlan { it.obj } ?: emptyMap(), + plan = it.obj, planText = it.text + ), + ui = ui ) - } + } catch (e: Throwable) { + log.warn("Error rendering task breakdown", e) + task.error(ui, e) + e.message ?: e.javaClass.simpleName + } + }, + ui = ui, + reviseResponse = { userMessages: List> -> + newPlan( + api, + planSettings, + userMessages.map { it.first }) + }, + ).call().let { + TaskBreakdownWithPrompt( + prompt = userMessage, + plan = PlanUtil.filterPlan { it.obj } ?: emptyMap(), + planText = it.text + ) + } + else newPlan( + api, + planSettings, + toInput(userMessage) + ).let { + TaskBreakdownWithPrompt( + prompt = userMessage, + plan = PlanUtil.filterPlan { it.obj } ?: emptyMap(), + planText = it.text + ) } + } - open fun newPlan( - api: API, - planSettings: PlanSettings, - inStrings: List - ): ParsedResponse> { - val planningActor = planSettings.planningActor() - return planningActor.respond( - messages = planningActor.chatMessages(inStrings), - input = inStrings, - api = api - ).map(Map::class.java) { it.tasksByID ?: emptyMap() } as ParsedResponse> - } + open fun newPlan( + api: API, + planSettings: PlanSettings, + inStrings: List + ): ParsedResponse> { + val planningActor = planSettings.planningActor() + return planningActor.respond( + messages = planningActor.chatMessages(inStrings), + input = inStrings, + api = api + ).map(Map::class.java) { it.tasksByID ?: emptyMap() } as ParsedResponse> + } - open fun inputFn( - codeFiles: Map, - files: Array, - root: Path - ) = { str: String -> - listOf( - if (!codeFiles.all { it.key.toFile().isFile } || codeFiles.size > 2) """ + open fun inputFn( + codeFiles: Map, + files: Array, + root: Path + ) = { str: String -> + listOf( + if (!codeFiles.all { it.key.toFile().isFile } || codeFiles.size > 2) """ |Files: |${codeFiles.keys.joinToString("\n") { "* $it" }} """.trimMargin() else { - files.joinToString("\n\n") { - val path = root.relativize(it.toPath()) - """ + files.joinToString("\n\n") { + val path = root.relativize(it.toPath()) + """ |## $path | |${(codeFiles[path] ?: "").let { "$TRIPLE_TILDE\n${it/*.indent(" ")*/}\n$TRIPLE_TILDE" }} """.trimMargin() - } - }, - str - ) - } + } + }, + str + ) + } - companion object { - private val log = LoggerFactory.getLogger(Planner::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(Planner::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanningTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanningTask.kt index efc25936..5333bc06 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanningTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/PlanningTask.kt @@ -20,28 +20,28 @@ import com.simiacryptus.util.JsonUtil import org.slf4j.LoggerFactory class PlanningTask( - planSettings: PlanSettings, - planTask: PlanningTaskData? + planSettings: PlanSettings, + planTask: PlanningTaskData? ) : AbstractTask(planSettings, planTask) { - class PlanningTaskData( - task_description: String? = null, - task_dependencies: List? = null, - state: TaskState? = TaskState.Pending, - ) : PlanTaskBase( - task_type = TaskType.TaskPlanning.name, - task_description = task_description, - task_dependencies = task_dependencies, - state = state - ) + class PlanningTaskData( + task_description: String? = null, + task_dependencies: List? = null, + state: TaskState? = TaskState.Pending, + ) : PlanTaskBase( + task_type = TaskType.TaskPlanning.name, + task_description = task_description, + task_dependencies = task_dependencies, + state = state + ) - data class TaskBreakdownResult( - @Description("A map where each task ID is associated with its corresponding PlanTask object. Crucial for defining task relationships and information flow.") - val tasksByID: Map? = null, - ) + data class TaskBreakdownResult( + @Description("A map where each task ID is associated with its corresponding PlanTask object. Crucial for defining task relationships and information flow.") + val tasksByID: Map? = null, + ) - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ |Task Planning: | * Perform high-level planning and organization of tasks. | * Decompose the overall goal into smaller, actionable tasks based on current information, ensuring proper information flow between tasks. @@ -52,108 +52,108 @@ class PlanningTask( | * **Note**: A planning task should refine the plan based on new information, optimizing task relationships and dependencies, and should not initiate execution. | * Ensure that each task utilizes the outputs or side effects of its upstream tasks, and provides outputs or side effects for its downstream tasks. """.trimMargin() - } + } - override fun run( - agent: PlanCoordinator, - messages: List, - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings - ) { - val userMessage = messages.joinToString("\n") - val newTask = agent.ui.newTask(false).apply { add(placeholder) } - fun toInput(s: String) = (messages + listOf(s)).filter { it.isNotBlank() } + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + val userMessage = messages.joinToString("\n") + val newTask = agent.ui.newTask(false).apply { add(placeholder) } + fun toInput(s: String) = (messages + listOf(s)).filter { it.isNotBlank() } - val subPlan = if (planSettings.allowBlocking && !planSettings.autoFix) { - createSubPlanDiscussable(newTask, userMessage, ::toInput, api, agent.ui, planSettings).call().obj - } else { - val design = planSettings.planningActor().answer( - toInput("Expand ${planTask?.task_description ?: ""}"), - api = api - ) - render( - withPrompt = TaskBreakdownWithPrompt( - plan = filterPlan { design.obj.tasksByID } ?: emptyMap(), - planText = design.text, - prompt = userMessage - ), - ui = agent.ui - ) - design.obj - } - executeSubTasks(agent, userMessage, filterPlan { subPlan.tasksByID } ?: emptyMap(), task, api, api2) + val subPlan = if (planSettings.allowBlocking && !planSettings.autoFix) { + createSubPlanDiscussable(newTask, userMessage, ::toInput, api, agent.ui, planSettings).call().obj + } else { + val design = planSettings.planningActor().answer( + toInput("Expand ${planTask?.task_description ?: ""}"), + api = api + ) + render( + withPrompt = TaskBreakdownWithPrompt( + plan = filterPlan { design.obj.tasksByID } ?: emptyMap(), + planText = design.text, + prompt = userMessage + ), + ui = agent.ui + ) + design.obj } + executeSubTasks(agent, userMessage, filterPlan { subPlan.tasksByID } ?: emptyMap(), task, api, api2) + } - private fun createSubPlanDiscussable( - task: SessionTask, - userMessage: String, - toInput: (String) -> List, - api: API, - ui: ApplicationInterface, - planSettings: PlanSettings - ) = Discussable( - task = task, - userMessage = { "Expand ${planTask?.task_description ?: ""}" }, - heading = "", - initialResponse = { it: String -> planSettings.planningActor().answer(toInput(it), api = api) }, - outputFn = { design: ParsedResponse -> - render( - withPrompt = TaskBreakdownWithPrompt( - plan = filterPlan { design.obj.tasksByID } ?: emptyMap(), - planText = design.text, - prompt = userMessage - ), - ui = ui - ) - }, - ui = ui, - reviseResponse = { usermessages: List> -> - planSettings.planningActor().respond( - messages = usermessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } - .toTypedArray(), - input = toInput("Expand ${planTask?.task_description ?: ""}\n${JsonUtil.toJson(this)}"), - api = api - ) - }, - ) + private fun createSubPlanDiscussable( + task: SessionTask, + userMessage: String, + toInput: (String) -> List, + api: API, + ui: ApplicationInterface, + planSettings: PlanSettings + ) = Discussable( + task = task, + userMessage = { "Expand ${planTask?.task_description ?: ""}" }, + heading = "", + initialResponse = { it: String -> planSettings.planningActor().answer(toInput(it), api = api) }, + outputFn = { design: ParsedResponse -> + render( + withPrompt = TaskBreakdownWithPrompt( + plan = filterPlan { design.obj.tasksByID } ?: emptyMap(), + planText = design.text, + prompt = userMessage + ), + ui = ui + ) + }, + ui = ui, + reviseResponse = { usermessages: List> -> + planSettings.planningActor().respond( + messages = usermessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } + .toTypedArray(), + input = toInput("Expand ${planTask?.task_description ?: ""}\n${JsonUtil.toJson(this)}"), + api = api + ) + }, + ) - private fun executeSubTasks( - coordinator: PlanCoordinator, - userMessage: String, - subPlan: Map, - parentTask: SessionTask, - api: API, - api2: OpenAIClient, - ) { - val subPlanTask = coordinator.ui.newTask(false) - parentTask.add(subPlanTask.placeholder) - val planProcessingState = PlanProcessingState(subPlan.toMutableMap()) - coordinator.copy( - planSettings = coordinator.planSettings.copy( - taskSettings = coordinator.planSettings.taskSettings.toList().toTypedArray().toMap().toMutableMap().apply { - this["TaskPlanning"] = TaskSettings(enabled = false, model = null) - } - ) - ).executePlan( - task = subPlanTask, - diagramBuffer = subPlanTask.add(diagram(coordinator.ui, planProcessingState.subTasks)), - subTasks = subPlan, - diagramTask = subPlanTask, - planProcessingState = planProcessingState, - taskIdProcessingQueue = executionOrder(subPlan).toMutableList(), - pool = coordinator.pool, - userMessage = userMessage, - plan = subPlan, - api = api, - api2 = api2, - ) - subPlanTask.complete() - } + private fun executeSubTasks( + coordinator: PlanCoordinator, + userMessage: String, + subPlan: Map, + parentTask: SessionTask, + api: API, + api2: OpenAIClient, + ) { + val subPlanTask = coordinator.ui.newTask(false) + parentTask.add(subPlanTask.placeholder) + val planProcessingState = PlanProcessingState(subPlan.toMutableMap()) + coordinator.copy( + planSettings = coordinator.planSettings.copy( + taskSettings = coordinator.planSettings.taskSettings.toList().toTypedArray().toMap().toMutableMap().apply { + this["TaskPlanning"] = TaskSettings(enabled = false, model = null) + } + ) + ).executePlan( + task = subPlanTask, + diagramBuffer = subPlanTask.add(diagram(coordinator.ui, planProcessingState.subTasks)), + subTasks = subPlan, + diagramTask = subPlanTask, + planProcessingState = planProcessingState, + taskIdProcessingQueue = executionOrder(subPlan).toMutableList(), + pool = coordinator.pool, + userMessage = userMessage, + plan = subPlan, + api = api, + api2 = api2, + ) + subPlanTask.complete() + } - companion object { - private val log = LoggerFactory.getLogger(PlanningTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(PlanningTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/RunShellCommandTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/RunShellCommandTask.kt index 60d014f0..15ae187c 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/RunShellCommandTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/RunShellCommandTask.kt @@ -15,106 +15,106 @@ import java.util.concurrent.Semaphore import kotlin.reflect.KClass class RunShellCommandTask( - planSettings: PlanSettings, - planTask: RunShellCommandTaskData? + planSettings: PlanSettings, + planTask: RunShellCommandTaskData? ) : AbstractTask(planSettings, planTask) { - class RunShellCommandTaskData( - @Description("The shell command to be executed") - val command: String? = null, - @Description("The relative file path of the working directory") - val workingDir: String? = null, - task_description: String? = null, - task_dependencies: List? = null, - state: TaskState? = null - ) : PlanTaskBase( - task_type = TaskType.RunShellCommand.name, - task_description = task_description, - task_dependencies = task_dependencies, - state = state - ) + class RunShellCommandTaskData( + @Description("The shell command to be executed") + val command: String? = null, + @Description("The relative file path of the working directory") + val workingDir: String? = null, + task_description: String? = null, + task_dependencies: List? = null, + state: TaskState? = null + ) : PlanTaskBase( + task_type = TaskType.RunShellCommand.name, + task_description = task_description, + task_dependencies = task_dependencies, + state = state + ) - val shellCommandActor by lazy { - CodingActor( - name = "RunShellCommand", - interpreterClass = ProcessInterpreter::class, - details = """ + val shellCommandActor by lazy { + CodingActor( + name = "RunShellCommand", + interpreterClass = ProcessInterpreter::class, + details = """ Execute the following shell command(s) and provide the output. Ensure to handle any errors or exceptions gracefully. Note: This task is for running simple and safe commands. Avoid executing commands that can cause harm to the system or compromise security. """.trimMargin(), - symbols = mapOf( - "env" to (planSettings.env ?: emptyMap()), - "workingDir" to (planTask?.workingDir?.let { File(it).absolutePath } ?: File( - planSettings.workingDir - ).absolutePath), - "language" to (planSettings.language ?: "bash"), - "command" to (planTask?.command ?: planSettings.command), - ), - model = planSettings.getTaskSettings(TaskType.valueOf(planTask?.task_type!!)).model - ?: planSettings.defaultModel, - temperature = planSettings.temperature, - ) - } + symbols = mapOf( + "env" to (planSettings.env ?: emptyMap()), + "workingDir" to (planTask?.workingDir?.let { File(it).absolutePath } ?: File( + planSettings.workingDir + ).absolutePath), + "language" to (planSettings.language ?: "bash"), + "command" to (planTask?.command ?: planSettings.command), + ), + model = planSettings.getTaskSettings(TaskType.valueOf(planTask?.task_type!!)).model + ?: planSettings.defaultModel, + temperature = planSettings.temperature, + ) + } - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ |RunShellCommand - Execute shell commands and provide the output | ** Specify the command to be executed, or describe the task to be performed | ** List input files/tasks to be examined when writing the command | ** Optionally specify a working directory for the command execution """.trimMargin() - } + } - override fun run( - agent: PlanCoordinator, - messages: List, - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + val semaphore = Semaphore(0) + object : CodingAgent( + api = api, + dataStorage = agent.dataStorage, + session = agent.session, + user = agent.user, + ui = agent.ui, + interpreter = shellCommandActor.interpreterClass as KClass, + symbols = shellCommandActor.symbols, + temperature = shellCommandActor.temperature, + details = shellCommandActor.details, + model = shellCommandActor.model, + mainTask = task, ) { - val semaphore = Semaphore(0) - object : CodingAgent( - api = api, - dataStorage = agent.dataStorage, - session = agent.session, - user = agent.user, - ui = agent.ui, - interpreter = shellCommandActor.interpreterClass as KClass, - symbols = shellCommandActor.symbols, - temperature = shellCommandActor.temperature, - details = shellCommandActor.details, - model = shellCommandActor.model, - mainTask = task, - ) { - override fun displayFeedback( - task: SessionTask, - request: CodingActor.CodeRequest, - response: CodingActor.CodeResult - ) { - val formText = StringBuilder() - var formHandle: StringBuilder? = null - formHandle = task.add( - """ + override fun displayFeedback( + task: SessionTask, + request: CodingActor.CodeRequest, + response: CodingActor.CodeResult + ) { + val formText = StringBuilder() + var formHandle: StringBuilder? = null + formHandle = task.add( + """
${if (!super.canPlay) "" else super.playButton(task, request, response, formText) { formHandle!! }} ${acceptButton(response)}
${super.reviseMsg(task, request, response, formText) { formHandle!! }} """.trimMargin(), className = "reply-message" - ) - formText.append(formHandle.toString()) - formHandle.toString() - task.complete() - } + ) + formText.append(formHandle.toString()) + formHandle.toString() + task.complete() + } - fun acceptButton( - response: CodingActor.CodeResult - ): String { - return ui.hrefLink("Accept", "href-link play-button") { - response.let { - """ + fun acceptButton( + response: CodingActor.CodeResult + ): String { + return ui.hrefLink("Accept", "href-link play-button") { + response.let { + """ |## Shell Command Output | |$TRIPLE_TILDE @@ -126,25 +126,25 @@ Note: This task is for running simple and safe commands. Avoid executing command |${TRIPLE_TILDE} | """.trimMargin() - }.apply { resultFn(this) } - semaphore.release() - } - } - }.apply> { - start( - codeRequest( - messages.map { it to ApiModel.Role.user } - ) - ) - } - try { - semaphore.acquire() - } catch (e: Throwable) { - log.warn("Error", e) + }.apply { resultFn(this) } + semaphore.release() } + } + }.apply> { + start( + codeRequest( + messages.map { it to ApiModel.Role.user } + ) + ) } - - companion object { - private val log = LoggerFactory.getLogger(RunShellCommandTask::class.java) + try { + semaphore.acquire() + } catch (e: Throwable) { + log.warn("Error", e) } + } + + companion object { + private val log = LoggerFactory.getLogger(RunShellCommandTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/SearchTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/SearchTask.kt new file mode 100644 index 00000000..2f079d9d --- /dev/null +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/SearchTask.kt @@ -0,0 +1,133 @@ +package com.simiacryptus.skyenet.apps.plan + +import com.simiacryptus.diff.FileValidationUtils +import com.simiacryptus.jopenai.ChatClient +import com.simiacryptus.jopenai.OpenAIClient +import com.simiacryptus.jopenai.describe.Description +import com.simiacryptus.skyenet.util.MarkdownUtil +import com.simiacryptus.skyenet.webui.session.SessionTask +import org.slf4j.LoggerFactory +import java.nio.file.FileSystems +import java.nio.file.Files +import java.util.regex.Pattern +import kotlin.streams.asSequence + +class SearchTask( + planSettings: PlanSettings, + planTask: SearchTaskData? +) : AbstractTask(planSettings, planTask) { + class SearchTaskData( + @Description("The search pattern (substring or regex) to look for in the files") + val search_pattern: String, + @Description("Whether the search pattern is a regex (true) or a substring (false)") + val is_regex: Boolean = false, + @Description("The number of context lines to include before and after each match") + val context_lines: Int = 2, + @Description("The specific files (or file patterns) to be searched") + val input_files: List? = null, + task_description: String? = null, + task_dependencies: List? = null, + state: TaskState? = null, + ) : PlanTaskBase( + task_type = TaskType.Search.name, + task_description = task_description, + task_dependencies = task_dependencies, + state = state + ) + + override fun promptSegment() = """ +Search - Search for patterns in files and provide results with context + ** Specify the search pattern (substring or regex) + ** Specify whether the pattern is a regex or a substring + ** Specify the number of context lines to include + ** List input files or file patterns to be searched + """.trimMargin() + + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + val searchResults = performSearch() + val formattedResults = formatSearchResults(searchResults) + task.add(MarkdownUtil.renderMarkdown(formattedResults, ui = agent.ui)) + resultFn(formattedResults) + } + + private fun performSearch(): List { + val pattern = if (planTask?.is_regex == true) { + Pattern.compile(planTask.search_pattern) + } else { + Pattern.compile(Pattern.quote(planTask?.search_pattern)) + } + + return (planTask?.input_files ?: listOf()) + .flatMap { filePattern -> + val matcher = FileSystems.getDefault().getPathMatcher("glob:$filePattern") + Files.walk(root).asSequence() + .filter { path -> + matcher.matches(root.relativize(path)) && + FileValidationUtils.isLLMIncludableFile(path.toFile()) + } + .flatMap { path -> + val relativePath = root.relativize(path).toString() + val lines = Files.readAllLines(path) + lines.mapIndexed { index, line -> + if (pattern.matcher(line).find()) { + SearchResult( + file = relativePath, + lineNumber = index + 1, + matchedLine = line, + context = getContext(lines, index, planTask?.context_lines ?: 2) + ) + } else null + }.filterNotNull() + } + .toList() + } + } + + private fun getContext(lines: List, matchIndex: Int, contextLines: Int): List { + val start = (matchIndex - contextLines).coerceAtLeast(0) + val end = (matchIndex + contextLines + 1).coerceAtMost(lines.size) + return lines.subList(start, end) + } + + private fun formatSearchResults(results: List): String { + return buildString { + appendLine("# Search Results") + appendLine() + results.groupBy { it.file }.forEach { (file, fileResults) -> + appendLine("## $file") + appendLine() + fileResults.forEach { result -> + appendLine("### Line ${result.lineNumber}") + appendLine() + appendLine("```") + result.context.forEachIndexed { index, line -> + val lineNumber = result.lineNumber - (result.context.size / 2) + index + val prefix = if (lineNumber == result.lineNumber) ">" else " " + appendLine("$prefix $lineNumber: $line") + } + appendLine("```") + appendLine() + } + } + } + } + + data class SearchResult( + val file: String, + val lineNumber: Int, + val matchedLine: String, + val context: List + ) + + companion object { + private val log = LoggerFactory.getLogger(SearchTask::class.java) + } +} \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/TaskBreakdownWithPrompt.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/TaskBreakdownWithPrompt.kt index 4d29fd45..48e0dd62 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/TaskBreakdownWithPrompt.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/TaskBreakdownWithPrompt.kt @@ -1,7 +1,7 @@ package com.simiacryptus.skyenet.apps.plan data class TaskBreakdownWithPrompt( - val prompt: String, - val plan: Map, - val planText: String + val prompt: String, + val plan: Map, + val planText: String ) \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/TaskType.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/TaskType.kt index d5bfe8ec..89aaccd0 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/TaskType.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/TaskType.kt @@ -10,6 +10,7 @@ import com.simiacryptus.jopenai.describe.Description import com.simiacryptus.skyenet.apps.plan.AbstractTask.TaskState import com.simiacryptus.skyenet.apps.plan.CommandAutoFixTask.CommandAutoFixTaskData import com.simiacryptus.skyenet.apps.plan.ForeachTask.ForeachTaskData +import com.simiacryptus.skyenet.apps.plan.GoogleSearchTask.GoogleSearchTaskData import com.simiacryptus.skyenet.apps.plan.PlanningTask.PlanningTaskData import com.simiacryptus.skyenet.apps.plan.RunShellCommandTask.RunShellCommandTaskData import com.simiacryptus.skyenet.apps.plan.file.* @@ -17,13 +18,14 @@ import com.simiacryptus.skyenet.apps.plan.file.CodeOptimizationTask.CodeOptimiza import com.simiacryptus.skyenet.apps.plan.file.CodeReviewTask.CodeReviewTaskData import com.simiacryptus.skyenet.apps.plan.file.DocumentationTask.DocumentationTaskData import com.simiacryptus.skyenet.apps.plan.file.FileModificationTask.FileModificationTaskData -import com.simiacryptus.skyenet.apps.plan.GoogleSearchTask.GoogleSearchTaskData import com.simiacryptus.skyenet.apps.plan.file.InquiryTask.InquiryTaskData import com.simiacryptus.skyenet.apps.plan.file.PerformanceAnalysisTask.PerformanceAnalysisTaskData import com.simiacryptus.skyenet.apps.plan.file.RefactorTask.RefactorTaskData import com.simiacryptus.skyenet.apps.plan.file.SecurityAuditTask.SecurityAuditTaskData import com.simiacryptus.skyenet.apps.plan.file.TestGenerationTask.TestGenerationTaskData import com.simiacryptus.skyenet.apps.plan.knowledge.EmbeddingSearchTask +import com.simiacryptus.skyenet.apps.plan.knowledge.KnowledgeIndexingTask +import com.simiacryptus.skyenet.apps.plan.knowledge.WebSearchAndIndexTask import com.simiacryptus.util.DynamicEnum import com.simiacryptus.util.DynamicEnumDeserializer import com.simiacryptus.util.DynamicEnumSerializer @@ -31,129 +33,133 @@ import com.simiacryptus.util.DynamicEnumSerializer @JsonDeserialize(using = TaskTypeDeserializer::class) @JsonSerialize(using = TaskTypeSerializer::class) class TaskType( - name: String, - val taskDataClass: Class + name: String, + val taskDataClass: Class ) : DynamicEnum>(name) { - companion object { - - private val taskConstructors = - mutableMapOf, (PlanSettings, PlanTaskBase?) -> AbstractTask>() - - val TaskPlanning = TaskType("TaskPlanning", PlanningTaskData::class.java) - val Inquiry = TaskType("Inquiry", InquiryTaskData::class.java) - val Search = TaskType("Search", SearchTask.SearchTaskData::class.java) - val EmbeddingSearch = TaskType("EmbeddingSearch", EmbeddingSearchTask.EmbeddingSearchTaskData::class.java) - val FileModification = TaskType("FileModification", FileModificationTaskData::class.java) - val Documentation = TaskType("Documentation", DocumentationTaskData::class.java) - val CodeReview = TaskType("CodeReview", CodeReviewTaskData::class.java) - val TestGeneration = TaskType("TestGeneration", TestGenerationTaskData::class.java) - val Optimization = TaskType("Optimization", CodeOptimizationTaskData::class.java) - val SecurityAudit = TaskType("SecurityAudit", SecurityAuditTaskData::class.java) - val PerformanceAnalysis = TaskType("PerformanceAnalysis", PerformanceAnalysisTaskData::class.java) - val RefactorTask = TaskType("RefactorTask", RefactorTaskData::class.java) - val RunShellCommand = TaskType("RunShellCommand", RunShellCommandTaskData::class.java) - val CommandAutoFix = TaskType("CommandAutoFix", CommandAutoFixTaskData::class.java) - val ForeachTask = TaskType("ForeachTask", ForeachTaskData::class.java) - val GitHubSearch = TaskType("GitHubSearch", GitHubSearchTask.GitHubSearchTaskData::class.java) - val GoogleSearch = TaskType("GoogleSearch", GoogleSearchTaskData::class.java) - val WebFetchAndTransform = TaskType("WebFetchAndTransform", WebFetchAndTransformTask.WebFetchAndTransformTaskData::class.java) - - init { - registerConstructor(CommandAutoFix) { settings, task -> CommandAutoFixTask(settings, task) } - registerConstructor(Inquiry) { settings, task -> InquiryTask(settings, task) } - registerConstructor(Search) { settings, task -> SearchTask(settings, task) } - registerConstructor(EmbeddingSearch) { settings, task -> EmbeddingSearchTask(settings, task) } - registerConstructor(FileModification) { settings, task -> FileModificationTask(settings, task) } - registerConstructor(Documentation) { settings, task -> DocumentationTask(settings, task) } - registerConstructor(RunShellCommand) { settings, task -> RunShellCommandTask(settings, task) } - registerConstructor(CodeReview) { settings, task -> CodeReviewTask(settings, task) } - registerConstructor(TestGeneration) { settings, task -> TestGenerationTask(settings, task) } - registerConstructor(Optimization) { settings, task -> CodeOptimizationTask(settings, task) } - registerConstructor(SecurityAudit) { settings, task -> SecurityAuditTask(settings, task) } - registerConstructor(PerformanceAnalysis) { settings, task -> PerformanceAnalysisTask(settings, task) } - registerConstructor(RefactorTask) { settings, task -> RefactorTask(settings, task) } - registerConstructor(ForeachTask) { settings, task -> ForeachTask(settings, task) } - registerConstructor(TaskPlanning) { settings, task -> PlanningTask(settings, task) } - registerConstructor(GitHubSearch) { settings, task -> GitHubSearchTask(settings, task) } - registerConstructor(GoogleSearch) { settings, task -> GoogleSearchTask(settings, task) } - registerConstructor(WebFetchAndTransform) { settings, task -> WebFetchAndTransformTask(settings, task) } - } - - private fun registerConstructor( - taskType: TaskType, - constructor: (PlanSettings, T?) -> AbstractTask - ) { - taskConstructors[taskType] = { settings: PlanSettings, task: PlanTaskBase? -> - constructor(settings, task as T?) - } - register(taskType) - } - - fun values() = values(TaskType::class.java) - fun getImpl( - planSettings: PlanSettings, - planTask: PlanTaskBase? - ) = getImpl(planSettings, planTask?.task_type?.let { valueOf(it) } ?: throw RuntimeException("Task type not specified"), planTask) - - fun getImpl( - planSettings: PlanSettings, - taskType: TaskType<*>, - planTask: PlanTaskBase? = null - ): AbstractTask { - if (!planSettings.getTaskSettings(taskType).enabled) { - throw DisabledTaskException(taskType) - } - val constructor = - taskConstructors[taskType] ?: throw RuntimeException("Unknown task type: ${taskType.name}") - return constructor(planSettings, planTask) - } - - fun getAvailableTaskTypes(planSettings: PlanSettings) = values().filter { - planSettings.getTaskSettings(it).enabled - } - - fun valueOf(name: String): TaskType<*> = valueOf(TaskType::class.java, name) - private fun register(taskType: TaskType<*>) = register(TaskType::class.java, taskType) + companion object { + + private val taskConstructors = + mutableMapOf, (PlanSettings, PlanTaskBase?) -> AbstractTask>() + + val TaskPlanning = TaskType("TaskPlanning", PlanningTaskData::class.java) + val Inquiry = TaskType("Inquiry", InquiryTaskData::class.java) + val Search = TaskType("Search", SearchTask.SearchTaskData::class.java) + val EmbeddingSearch = TaskType("EmbeddingSearch", EmbeddingSearchTask.EmbeddingSearchTaskData::class.java) + val FileModification = TaskType("FileModification", FileModificationTaskData::class.java) + val Documentation = TaskType("Documentation", DocumentationTaskData::class.java) + val CodeReview = TaskType("CodeReview", CodeReviewTaskData::class.java) + val TestGeneration = TaskType("TestGeneration", TestGenerationTaskData::class.java) + val Optimization = TaskType("Optimization", CodeOptimizationTaskData::class.java) + val SecurityAudit = TaskType("SecurityAudit", SecurityAuditTaskData::class.java) + val PerformanceAnalysis = TaskType("PerformanceAnalysis", PerformanceAnalysisTaskData::class.java) + val RefactorTask = TaskType("RefactorTask", RefactorTaskData::class.java) + val RunShellCommand = TaskType("RunShellCommand", RunShellCommandTaskData::class.java) + val CommandAutoFix = TaskType("CommandAutoFix", CommandAutoFixTaskData::class.java) + val ForeachTask = TaskType("ForeachTask", ForeachTaskData::class.java) + val GitHubSearch = TaskType("GitHubSearch", GitHubSearchTask.GitHubSearchTaskData::class.java) + val GoogleSearch = TaskType("GoogleSearch", GoogleSearchTaskData::class.java) + val WebFetchAndTransform = TaskType("WebFetchAndTransform", WebFetchAndTransformTask.WebFetchAndTransformTaskData::class.java) + val KnowledgeIndexing = TaskType("KnowledgeIndexing", KnowledgeIndexingTask.KnowledgeIndexingTaskData::class.java) + val WebSearchAndIndex = TaskType("WebSearchAndIndex", WebSearchAndIndexTask.WebSearchAndIndexTaskData::class.java) + + init { + registerConstructor(CommandAutoFix) { settings, task -> CommandAutoFixTask(settings, task) } + registerConstructor(Inquiry) { settings, task -> InquiryTask(settings, task) } + registerConstructor(Search) { settings, task -> SearchTask(settings, task) } + registerConstructor(EmbeddingSearch) { settings, task -> EmbeddingSearchTask(settings, task) } + registerConstructor(FileModification) { settings, task -> FileModificationTask(settings, task) } + registerConstructor(Documentation) { settings, task -> DocumentationTask(settings, task) } + registerConstructor(RunShellCommand) { settings, task -> RunShellCommandTask(settings, task) } + registerConstructor(CodeReview) { settings, task -> CodeReviewTask(settings, task) } + registerConstructor(TestGeneration) { settings, task -> TestGenerationTask(settings, task) } + registerConstructor(Optimization) { settings, task -> CodeOptimizationTask(settings, task) } + registerConstructor(SecurityAudit) { settings, task -> SecurityAuditTask(settings, task) } + registerConstructor(PerformanceAnalysis) { settings, task -> PerformanceAnalysisTask(settings, task) } + registerConstructor(RefactorTask) { settings, task -> RefactorTask(settings, task) } + registerConstructor(ForeachTask) { settings, task -> ForeachTask(settings, task) } + registerConstructor(TaskPlanning) { settings, task -> PlanningTask(settings, task) } + registerConstructor(GitHubSearch) { settings, task -> GitHubSearchTask(settings, task) } + registerConstructor(GoogleSearch) { settings, task -> GoogleSearchTask(settings, task) } + registerConstructor(WebFetchAndTransform) { settings, task -> WebFetchAndTransformTask(settings, task) } + registerConstructor(KnowledgeIndexing) { settings, task -> KnowledgeIndexingTask(settings, task) } + registerConstructor(WebSearchAndIndex) { settings, task -> WebSearchAndIndexTask(settings, task) } + } + + private fun registerConstructor( + taskType: TaskType, + constructor: (PlanSettings, T?) -> AbstractTask + ) { + taskConstructors[taskType] = { settings: PlanSettings, task: PlanTaskBase? -> + constructor(settings, task as T?) + } + register(taskType) } + + fun values() = values(TaskType::class.java) + fun getImpl( + planSettings: PlanSettings, + planTask: PlanTaskBase? + ) = getImpl(planSettings, planTask?.task_type?.let { valueOf(it) } ?: throw RuntimeException("Task type not specified"), planTask) + + fun getImpl( + planSettings: PlanSettings, + taskType: TaskType<*>, + planTask: PlanTaskBase? = null + ): AbstractTask { + if (!planSettings.getTaskSettings(taskType).enabled) { + throw DisabledTaskException(taskType) + } + val constructor = + taskConstructors[taskType] ?: throw RuntimeException("Unknown task type: ${taskType.name}") + return constructor(planSettings, planTask) + } + + fun getAvailableTaskTypes(planSettings: PlanSettings) = values().filter { + planSettings.getTaskSettings(it).enabled + } + + fun valueOf(name: String): TaskType<*> = valueOf(TaskType::class.java, name) + private fun register(taskType: TaskType<*>) = register(TaskType::class.java, taskType) + } } @JsonTypeIdResolver(PlanTaskTypeIdResolver::class) @JsonTypeInfo(use = JsonTypeInfo.Id.CUSTOM, property = "task_type") abstract class PlanTaskBase( - @Description("An enumeration indicating the type of task to be executed. Must be a single value from the TaskType enum.") - val task_type: String? = null, - @Description("A detailed description of the specific task to be performed, including its role in the overall plan and its dependencies on other tasks.") - var task_description: String? = null, - @Description("A list of IDs of tasks that must be completed before this task can be executed. This defines upstream dependencies ensuring proper task order and information flow.") - var task_dependencies: List? = null, - @Description("The current execution state of the task. Important for coordinating task execution and managing dependencies.") - var state: TaskState? = null + @Description("An enumeration indicating the type of task to be executed. Must be a single value from the TaskType enum.") + val task_type: String? = null, + @Description("A detailed description of the specific task to be performed, including its role in the overall plan and its dependencies on other tasks.") + var task_description: String? = null, + @Description("A list of IDs of tasks that must be completed before this task can be executed. This defines upstream dependencies ensuring proper task order and information flow.") + var task_dependencies: List? = null, + @Description("The current execution state of the task. Important for coordinating task execution and managing dependencies.") + var state: TaskState? = null ) class PlanTaskTypeIdResolver : TypeIdResolverBase() { - override fun idFromValue(value: Any) = when (value) { - is PlanTaskBase -> if (value.task_type != null) { - value.task_type - } else { - throw IllegalArgumentException("Unknown task type") - } - - else -> throw IllegalArgumentException("Unexpected value type: ${value.javaClass}") + override fun idFromValue(value: Any) = when (value) { + is PlanTaskBase -> if (value.task_type != null) { + value.task_type + } else { + throw IllegalArgumentException("Unknown task type") } - override fun idFromValueAndType(value: Any, suggestedType: Class<*>): String { - return idFromValue(value) - } + else -> throw IllegalArgumentException("Unexpected value type: ${value.javaClass}") + } - override fun typeFromId(context: DatabindContext, id: String): JavaType { - val taskType = TaskType.valueOf(id.replace(" ", "")) - val subType = context.constructType(taskType.taskDataClass) - return subType - } + override fun idFromValueAndType(value: Any, suggestedType: Class<*>): String { + return idFromValue(value) + } - override fun getMechanism(): JsonTypeInfo.Id { - return JsonTypeInfo.Id.CUSTOM - } + override fun typeFromId(context: DatabindContext, id: String): JavaType { + val taskType = TaskType.valueOf(id.replace(" ", "")) + val subType = context.constructType(taskType.taskDataClass) + return subType + } + + override fun getMechanism(): JsonTypeInfo.Id { + return JsonTypeInfo.Id.CUSTOM + } } class TaskTypeSerializer : DynamicEnumSerializer>(TaskType::class.java) diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/AbstractAnalysisTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/AbstractAnalysisTask.kt index 1469304c..ae970291 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/AbstractAnalysisTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/AbstractAnalysisTask.kt @@ -12,72 +12,72 @@ import org.slf4j.LoggerFactory import java.io.File abstract class AbstractAnalysisTask( - planSettings: PlanSettings, - planTask: T? + planSettings: PlanSettings, + planTask: T? ) : AbstractFileTask(planSettings, planTask) { - abstract val actorName: String - abstract val actorPrompt: String + abstract val actorName: String + abstract val actorPrompt: String - protected val analysisActor by lazy { - SimpleActor( - name = actorName, - prompt = actorPrompt, - model = planSettings.getTaskSettings(TaskType.valueOf(planTask?.task_type!!)).model - ?: planSettings.defaultModel, - temperature = planSettings.temperature, - ) - } + protected val analysisActor by lazy { + SimpleActor( + name = actorName, + prompt = actorPrompt, + model = planSettings.getTaskSettings(TaskType.valueOf(planTask?.task_type!!)).model + ?: planSettings.defaultModel, + temperature = planSettings.temperature, + ) + } - override fun run( - agent: PlanCoordinator, - messages: List, - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings - ) { - val analysisResult = analysisActor.answer( - messages + listOf( - getInputFileCode(), - "${getAnalysisInstruction()}:\n${getInputFileCode()}", - ).filter { it.isNotBlank() }, api = api - ) - resultFn(analysisResult) - applyChanges(agent, task, analysisResult, api) - } + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + val analysisResult = analysisActor.answer( + messages + listOf( + getInputFileCode(), + "${getAnalysisInstruction()}:\n${getInputFileCode()}", + ).filter { it.isNotBlank() }, api = api + ) + resultFn(analysisResult) + applyChanges(agent, task, analysisResult, api) + } - abstract fun getAnalysisInstruction(): String + abstract fun getAnalysisInstruction(): String - private fun applyChanges(agent: PlanCoordinator, task: SessionTask, analysisResult: String, api: API) { - val outputResult = CommandPatchApp( - root = agent.root.toFile(), - session = agent.session, - settings = PatchApp.Settings( - executable = File("dummy"), - workingDirectory = agent.root.toFile(), - exitCodeOption = "nonzero", - additionalInstructions = "", - autoFix = agent.planSettings.autoFix - ), - api = api as ChatClient, - model = agent.planSettings.getTaskSettings(TaskType.valueOf(planTask?.task_type!!)).model - ?: agent.planSettings.defaultModel, - files = agent.files, - command = analysisResult - ).run( - ui = agent.ui, - task = task - ) - if (outputResult.exitCode == 0) { - task.add("${actorName} completed and suggestions have been applied successfully.") - } else { - task.add("${actorName} completed, but failed to apply suggestions. Exit code: ${outputResult.exitCode}") - } + private fun applyChanges(agent: PlanCoordinator, task: SessionTask, analysisResult: String, api: API) { + val outputResult = CommandPatchApp( + root = agent.root.toFile(), + session = agent.session, + settings = PatchApp.Settings( + executable = File("dummy"), + workingDirectory = agent.root.toFile(), + exitCodeOption = "nonzero", + additionalInstructions = "", + autoFix = agent.planSettings.autoFix + ), + api = api as ChatClient, + model = agent.planSettings.getTaskSettings(TaskType.valueOf(planTask?.task_type!!)).model + ?: agent.planSettings.defaultModel, + files = agent.files, + command = analysisResult + ).run( + ui = agent.ui, + task = task + ) + if (outputResult.exitCode == 0) { + task.add("${actorName} completed and suggestions have been applied successfully.") + } else { + task.add("${actorName} completed, but failed to apply suggestions. Exit code: ${outputResult.exitCode}") } + } - companion object { - private val log = LoggerFactory.getLogger(AbstractAnalysisTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(AbstractAnalysisTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/AbstractFileTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/AbstractFileTask.kt index fda5699a..50d69bf7 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/AbstractFileTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/AbstractFileTask.kt @@ -11,61 +11,61 @@ import java.nio.file.Files import kotlin.streams.asSequence abstract class AbstractFileTask( - planSettings: PlanSettings, - planTask: T? + planSettings: PlanSettings, + planTask: T? ) : AbstractTask(planSettings, planTask) { - open class FileTaskBase( - task_type: String, - task_description: String? = null, - task_dependencies: List? = null, - @Description("The relative file paths to be used as input for the task") - val input_files: List? = null, - @Description("The relative file paths to be generated as output for the task") - val output_files: List? = null, - state: TaskState? = TaskState.Pending, - ) : PlanTaskBase( - task_type = task_type, - task_description = task_description, - task_dependencies = task_dependencies, - state = state - ) + open class FileTaskBase( + task_type: String, + task_description: String? = null, + task_dependencies: List? = null, + @Description("The relative file paths to be used as input for the task") + val input_files: List? = null, + @Description("The relative file paths to be generated as output for the task") + val output_files: List? = null, + state: TaskState? = TaskState.Pending, + ) : PlanTaskBase( + task_type = task_type, + task_description = task_description, + task_dependencies = task_dependencies, + state = state + ) - protected fun getInputFileCode(): String = - ((planTask?.input_files ?: listOf()) + (planTask?.output_files ?: listOf())) - .flatMap { pattern: String -> - val matcher = FileSystems.getDefault().getPathMatcher("glob:$pattern") - Files.walk(root).asSequence() - .filter { path -> - matcher.matches(root.relativize(path)) && - FileValidationUtils.isLLMIncludableFile(path.toFile()) - } - .map { path -> - root.relativize(path).toString() - } - .toList() - } - .distinct() - .sortedBy { it } - .joinToString("\n\n") { relativePath -> - val file = root.resolve(relativePath).toFile() - try { - """ + protected fun getInputFileCode(): String = + ((planTask?.input_files ?: listOf()) + (planTask?.output_files ?: listOf())) + .flatMap { pattern: String -> + val matcher = FileSystems.getDefault().getPathMatcher("glob:$pattern") + Files.walk(root).asSequence() + .filter { path -> + matcher.matches(root.relativize(path)) && + FileValidationUtils.isLLMIncludableFile(path.toFile()) + } + .map { path -> + root.relativize(path).toString() + } + .toList() + } + .distinct() + .sortedBy { it } + .joinToString("\n\n") { relativePath -> + val file = root.resolve(relativePath).toFile() + try { + """ |# $relativePath | |$TRIPLE_TILDE |${codeFiles[file.toPath()] ?: file.readText()} |$TRIPLE_TILDE """.trimMargin() - } catch (e: Throwable) { - log.warn("Error reading file: $relativePath", e) - "" - } - } + } catch (e: Throwable) { + log.warn("Error reading file: $relativePath", e) + "" + } + } - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(AbstractFileTask::class.java) - const val TRIPLE_TILDE = "```" + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(AbstractFileTask::class.java) + const val TRIPLE_TILDE = "```" - } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/CodeOptimizationTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/CodeOptimizationTask.kt index b3547cc0..0930858a 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/CodeOptimizationTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/CodeOptimizationTask.kt @@ -7,31 +7,31 @@ import com.simiacryptus.skyenet.apps.plan.file.CodeOptimizationTask.CodeOptimiza import org.slf4j.LoggerFactory class CodeOptimizationTask( - planSettings: PlanSettings, - planTask: CodeOptimizationTaskData? + planSettings: PlanSettings, + planTask: CodeOptimizationTaskData? ) : AbstractAnalysisTask(planSettings, planTask) { - class CodeOptimizationTaskData( - @Description("Files to be optimized") - val filesToOptimize: List? = null, - @Description("Specific areas of focus for the optimization") - val optimizationFocus: List? = null, - task_description: String? = null, - task_dependencies: List? = null, - input_files: List? = null, - output_files: List? = null, - state: TaskState? = null - ) : FileTaskBase( - task_type = TaskType.Optimization.name, - task_description = task_description, - task_dependencies = task_dependencies, - input_files = input_files, - output_files = output_files, - state = state - ) + class CodeOptimizationTaskData( + @Description("Files to be optimized") + val filesToOptimize: List? = null, + @Description("Specific areas of focus for the optimization") + val optimizationFocus: List? = null, + task_description: String? = null, + task_dependencies: List? = null, + input_files: List? = null, + output_files: List? = null, + state: TaskState? = null + ) : FileTaskBase( + task_type = TaskType.Optimization.name, + task_description = task_description, + task_dependencies = task_dependencies, + input_files = input_files, + output_files = output_files, + state = state + ) - override val actorName = "CodeOptimization" - override val actorPrompt = """ + override val actorName = "CodeOptimization" + override val actorPrompt = """ Analyze the provided code and suggest optimizations to improve code quality. Focus exclusively on: |1. Code structure and organization |2. Readability improvements @@ -48,20 +48,20 @@ class CodeOptimizationTask( |Use diff format to show the proposed changes clearly. """.trimMargin() - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ |CodeOptimization - Analyze and optimize existing code for better readability, maintainability, and adherence to best practices | * Specify the files to be optimized | * Optionally provide specific areas of focus for the optimization (e.g., code structure, readability, design patterns) """.trimMargin() - } + } - override fun getAnalysisInstruction(): String { - return "Optimize the following code for better readability and maintainability" - } + override fun getAnalysisInstruction(): String { + return "Optimize the following code for better readability and maintainability" + } - companion object { - private val log = LoggerFactory.getLogger(CodeOptimizationTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(CodeOptimizationTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/CodeReviewTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/CodeReviewTask.kt index 2a293e9f..42c66a74 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/CodeReviewTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/CodeReviewTask.kt @@ -7,30 +7,30 @@ import com.simiacryptus.skyenet.apps.plan.file.CodeReviewTask.CodeReviewTaskData import org.slf4j.LoggerFactory class CodeReviewTask( - planSettings: PlanSettings, - planTask: CodeReviewTaskData? + planSettings: PlanSettings, + planTask: CodeReviewTaskData? ) : AbstractAnalysisTask(planSettings, planTask) { - class CodeReviewTaskData( - @Description("List of files to be reviewed") - val filesToReview: List? = null, - @Description("Specific areas of focus for the review (optional)") - val focusAreas: List? = null, - task_description: String? = null, - task_dependencies: List? = null, - input_files: List? = null, - output_files: List? = null, - state: TaskState? = null - ) : FileTaskBase( - task_type = TaskType.CodeReview.name, - task_description = task_description, - task_dependencies = task_dependencies, - input_files = input_files, - output_files = output_files, - state = state - ) + class CodeReviewTaskData( + @Description("List of files to be reviewed") + val filesToReview: List? = null, + @Description("Specific areas of focus for the review (optional)") + val focusAreas: List? = null, + task_description: String? = null, + task_dependencies: List? = null, + input_files: List? = null, + output_files: List? = null, + state: TaskState? = null + ) : FileTaskBase( + task_type = TaskType.CodeReview.name, + task_description = task_description, + task_dependencies = task_dependencies, + input_files = input_files, + output_files = output_files, + state = state + ) - override val actorName: String = "CodeReview" - override val actorPrompt: String = """ + override val actorName: String = "CodeReview" + override val actorPrompt: String = """ Perform a comprehensive code review for the provided code files. Analyze the code for: 1. Code quality and readability 2. Potential bugs or errors @@ -43,23 +43,23 @@ class CodeReviewTask( Format the response as a markdown document with appropriate headings and code snippets. """.trimIndent() - override fun getAnalysisInstruction(): String { - val filesToReview = planTask?.filesToReview?.joinToString(", ") ?: "all provided files" - val focusAreas = planTask?.focusAreas?.joinToString(", ") - return "Review the following code files: $filesToReview" + - if (focusAreas != null) ". Focus on these areas: $focusAreas" else "" - } + override fun getAnalysisInstruction(): String { + val filesToReview = planTask?.filesToReview?.joinToString(", ") ?: "all provided files" + val focusAreas = planTask?.focusAreas?.joinToString(", ") + return "Review the following code files: $filesToReview" + + if (focusAreas != null) ". Focus on these areas: $focusAreas" else "" + } - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ CodeReview - Perform an automated code review and provide suggestions for improvements ** Specify the files to be reviewed ** Optionally provide specific areas of focus for the review """.trimMargin() - } + } - companion object { - private val log = LoggerFactory.getLogger(CodeReviewTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(CodeReviewTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/DocumentationTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/DocumentationTask.kt index 70af1f4a..5546391b 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/DocumentationTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/DocumentationTask.kt @@ -14,39 +14,39 @@ import org.slf4j.LoggerFactory import java.util.concurrent.Semaphore class DocumentationTask( - planSettings: PlanSettings, - planTask: DocumentationTaskData? + planSettings: PlanSettings, + planTask: DocumentationTaskData? ) : AbstractFileTask(planSettings, planTask) { - class DocumentationTaskData( - @Description("List topics to document") - val topics: List? = null, - task_description: String? = null, - task_dependencies: List? = null, - input_files: List? = null, - output_files: List? = null, - state: TaskState? = null - ) : FileTaskBase( - task_type = TaskType.Documentation.name, - task_description = task_description, - task_dependencies = task_dependencies, - input_files = input_files, - output_files = output_files, - state = state - ) + class DocumentationTaskData( + @Description("List topics to document") + val topics: List? = null, + task_description: String? = null, + task_dependencies: List? = null, + input_files: List? = null, + output_files: List? = null, + state: TaskState? = null + ) : FileTaskBase( + task_type = TaskType.Documentation.name, + task_description = task_description, + task_dependencies = task_dependencies, + input_files = input_files, + output_files = output_files, + state = state + ) - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ Documentation - Generate documentation ** List input file names and tasks to be examined ** List topics to document ** List output files to be modified or created with documentation """.trimMargin() - } + } - val documentationGeneratorActor by lazy { - SimpleActor( - name = "DocumentationGenerator", - prompt = """ + val documentationGeneratorActor by lazy { + SimpleActor( + name = "DocumentationGenerator", + prompt = """ Create detailed and clear documentation for the provided code, covering its purpose, functionality, inputs, outputs, and any assumptions or limitations. Use a structured and consistent format that facilitates easy understanding and navigation. Include code examples where applicable, and explain the rationale behind key design decisions and algorithm choices. @@ -60,84 +60,84 @@ class DocumentationTask( Include 2 lines of context before and after every change in diffs. Separate code blocks with a single blank line. """.trimMargin(), + model = planSettings.getTaskSettings(TaskType.Documentation).model ?: planSettings.defaultModel, + temperature = planSettings.temperature, + ) + } + + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + if (((planTask?.input_files ?: listOf()) + (planTask?.output_files ?: listOf())).isEmpty()) { + task.complete("No input or output files specified") + return + } + val semaphore = Semaphore(0) + val onComplete = { + semaphore.release() + } + val process = { sb: StringBuilder -> + val itemsToDocument = planTask?.topics ?: emptyList() + val docResult = documentationGeneratorActor.answer( + messages + listOf( + getInputFileCode(), + "Items to document: ${itemsToDocument.joinToString(", ")}", + "Output files: ${planTask?.output_files?.joinToString(", ") ?: ""}" + ).filter { it.isNotBlank() }, api + ) + resultFn(docResult) + if (agent.planSettings.autoFix) { + val diffLinks = agent.ui.socketManager!!.addApplyFileDiffLinks( + root = agent.root, + response = docResult, + handle = { newCodeMap -> + newCodeMap.forEach { (path, newCode) -> + task.complete("$path Updated") + } + }, + ui = agent.ui, + api = api, + shouldAutoApply = { agent.planSettings.autoFix }, + model = planSettings.getTaskSettings(TaskType.Documentation).model ?: planSettings.defaultModel, + ) + task.complete() + onComplete() + MarkdownUtil.renderMarkdown(diffLinks + "\n\n## Auto-applied documentation changes", ui = agent.ui) + } else { + MarkdownUtil.renderMarkdown( + agent.ui.socketManager!!.addApplyFileDiffLinks( + root = agent.root, + response = docResult, + handle = { newCodeMap -> + newCodeMap.forEach { (path, newCode) -> + task.complete("$path Updated") + } + }, + ui = agent.ui, + api = api, model = planSettings.getTaskSettings(TaskType.Documentation).model ?: planSettings.defaultModel, - temperature = planSettings.temperature, + ) + acceptButtonFooter(agent.ui) { + task.complete() + onComplete() + }, ui = agent.ui ) + } } - - override fun run( - agent: PlanCoordinator, - messages: List, - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings - ) { - if (((planTask?.input_files ?: listOf()) + (planTask?.output_files ?: listOf())).isEmpty()) { - task.complete("No input or output files specified") - return - } - val semaphore = Semaphore(0) - val onComplete = { - semaphore.release() - } - val process = { sb: StringBuilder -> - val itemsToDocument = planTask?.topics ?: emptyList() - val docResult = documentationGeneratorActor.answer( - messages + listOf( - getInputFileCode(), - "Items to document: ${itemsToDocument.joinToString(", ")}", - "Output files: ${planTask?.output_files?.joinToString(", ") ?: ""}" - ).filter { it.isNotBlank() }, api - ) - resultFn(docResult) - if (agent.planSettings.autoFix) { - val diffLinks = agent.ui.socketManager!!.addApplyFileDiffLinks( - root = agent.root, - response = docResult, - handle = { newCodeMap -> - newCodeMap.forEach { (path, newCode) -> - task.complete("$path Updated") - } - }, - ui = agent.ui, - api = api, - shouldAutoApply = { agent.planSettings.autoFix }, - model = planSettings.getTaskSettings(TaskType.Documentation).model ?: planSettings.defaultModel, - ) - task.complete() - onComplete() - MarkdownUtil.renderMarkdown(diffLinks + "\n\n## Auto-applied documentation changes", ui = agent.ui) - } else { - MarkdownUtil.renderMarkdown( - agent.ui.socketManager!!.addApplyFileDiffLinks( - root = agent.root, - response = docResult, - handle = { newCodeMap -> - newCodeMap.forEach { (path, newCode) -> - task.complete("$path Updated") - } - }, - ui = agent.ui, - api = api, - model = planSettings.getTaskSettings(TaskType.Documentation).model ?: planSettings.defaultModel, - ) + acceptButtonFooter(agent.ui) { - task.complete() - onComplete() - }, ui = agent.ui - ) - } - } - Retryable(agent.ui, task = task, process = process) - try { - semaphore.acquire() - } catch (e: Throwable) { - log.warn("Error", e) - } + Retryable(agent.ui, task = task, process = process) + try { + semaphore.acquire() + } catch (e: Throwable) { + log.warn("Error", e) } + } - companion object { - private val log = LoggerFactory.getLogger(DocumentationTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(DocumentationTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/FileModificationTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/FileModificationTask.kt index 392be5ab..a4514e59 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/FileModificationTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/FileModificationTask.kt @@ -14,30 +14,30 @@ import org.slf4j.LoggerFactory import java.util.concurrent.Semaphore class FileModificationTask( - planSettings: PlanSettings, - planTask: FileModificationTaskData? + planSettings: PlanSettings, + planTask: FileModificationTaskData? ) : AbstractFileTask(planSettings, planTask) { - class FileModificationTaskData( - input_files: List? = null, - output_files: List? = null, - @Description("Specific modifications to be made to the files") - val modifications: Any? = null, - task_description: String? = null, - task_dependencies: List? = null, - state: TaskState? = null - ) : FileTaskBase( - task_type = TaskType.FileModification.name, - task_description = task_description, - task_dependencies = task_dependencies, - input_files = input_files, - output_files = output_files, - state = state - ) + class FileModificationTaskData( + input_files: List? = null, + output_files: List? = null, + @Description("Specific modifications to be made to the files") + val modifications: Any? = null, + task_description: String? = null, + task_dependencies: List? = null, + state: TaskState? = null + ) : FileTaskBase( + task_type = TaskType.FileModification.name, + task_description = task_description, + task_dependencies = task_dependencies, + input_files = input_files, + output_files = output_files, + state = state + ) - val fileModificationActor by lazy { - SimpleActor( - name = "FileModification", - prompt = """ + val fileModificationActor by lazy { + SimpleActor( + name = "FileModification", + prompt = """ Generate patches for existing files or create new files based on the given requirements and context. For existing files: Ensure modifications are efficient, maintain readability, and adhere to coding standards. @@ -81,89 +81,89 @@ class FileModificationTask( |} $TRIPLE_TILDE """.trimMargin(), - model = planSettings.getTaskSettings(TaskType.FileModification).model ?: planSettings.defaultModel, - temperature = planSettings.temperature, - ) - } + model = planSettings.getTaskSettings(TaskType.FileModification).model ?: planSettings.defaultModel, + temperature = planSettings.temperature, + ) + } - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ FileModification - Modify existing files or create new files ** For each file, specify the relative file path and the goal of the modification or creation ** List input files/tasks to be examined when designing the modifications or new files """.trimMargin() - } + } - override fun run( - agent: PlanCoordinator, - messages: List, - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings - ) { - if (((planTask?.input_files ?: listOf()) + (planTask?.output_files ?: listOf())).isEmpty()) { - task.complete("CONFIGURATION ERROR: No input files specified") - resultFn("CONFIGURATION ERROR: No input files specified") - return - } - val semaphore = Semaphore(0) - val onComplete = { semaphore.release() } - val process = { sb: StringBuilder -> - val codeResult = fileModificationActor.answer( - (messages + listOf( - getInputFileCode(), - this.planTask?.task_description ?: "", - )).filter { it.isNotBlank() }, api - ) - resultFn(codeResult) - if (agent.planSettings.autoFix) { - val diffLinks = agent.ui.socketManager!!.addApplyFileDiffLinks( - root = agent.root, - response = codeResult, - handle = { newCodeMap -> - newCodeMap.forEach { (path, newCode) -> - task.complete("$path Updated") - } - }, - ui = agent.ui, - api = api, - shouldAutoApply = { agent.planSettings.autoFix }, - model = planSettings.getTaskSettings(TaskType.FileModification).model ?: planSettings.defaultModel, - ) - task.complete() - onComplete() - renderMarkdown(diffLinks + "\n\n## Auto-applied changes", ui = agent.ui) - } else { - renderMarkdown( - agent.ui.socketManager!!.addApplyFileDiffLinks( - root = agent.root, - response = codeResult, - handle = { newCodeMap -> - newCodeMap.forEach { (path, newCode) -> - task.complete("$path Updated") - } - }, - ui = agent.ui, - api = api, - model = planSettings.getTaskSettings(TaskType.FileModification).model ?: planSettings.defaultModel, - ) + acceptButtonFooter(agent.ui) { - task.complete() - onComplete() - }, ui = agent.ui - ) + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + if (((planTask?.input_files ?: listOf()) + (planTask?.output_files ?: listOf())).isEmpty()) { + task.complete("CONFIGURATION ERROR: No input files specified") + resultFn("CONFIGURATION ERROR: No input files specified") + return + } + val semaphore = Semaphore(0) + val onComplete = { semaphore.release() } + val process = { sb: StringBuilder -> + val codeResult = fileModificationActor.answer( + (messages + listOf( + getInputFileCode(), + this.planTask?.task_description ?: "", + )).filter { it.isNotBlank() }, api + ) + resultFn(codeResult) + if (agent.planSettings.autoFix) { + val diffLinks = agent.ui.socketManager!!.addApplyFileDiffLinks( + root = agent.root, + response = codeResult, + handle = { newCodeMap -> + newCodeMap.forEach { (path, newCode) -> + task.complete("$path Updated") } - } - Retryable(agent.ui, task = task, process = process) - try { - semaphore.acquire() - } catch (e: Throwable) { - log.warn("Error", e) - } + }, + ui = agent.ui, + api = api, + shouldAutoApply = { agent.planSettings.autoFix }, + model = planSettings.getTaskSettings(TaskType.FileModification).model ?: planSettings.defaultModel, + ) + task.complete() + onComplete() + renderMarkdown(diffLinks + "\n\n## Auto-applied changes", ui = agent.ui) + } else { + renderMarkdown( + agent.ui.socketManager!!.addApplyFileDiffLinks( + root = agent.root, + response = codeResult, + handle = { newCodeMap -> + newCodeMap.forEach { (path, newCode) -> + task.complete("$path Updated") + } + }, + ui = agent.ui, + api = api, + model = planSettings.getTaskSettings(TaskType.FileModification).model ?: planSettings.defaultModel, + ) + acceptButtonFooter(agent.ui) { + task.complete() + onComplete() + }, ui = agent.ui + ) + } } - - companion object { - private val log = LoggerFactory.getLogger(FileModificationTask::class.java) + Retryable(agent.ui, task = task, process = process) + try { + semaphore.acquire() + } catch (e: Throwable) { + log.warn("Error", e) } + } + + companion object { + private val log = LoggerFactory.getLogger(FileModificationTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/InquiryTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/InquiryTask.kt index be448660..af882ddb 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/InquiryTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/InquiryTask.kt @@ -21,27 +21,27 @@ import java.util.concurrent.atomic.AtomicReference import kotlin.streams.asSequence class InquiryTask( - planSettings: PlanSettings, - planTask: InquiryTaskData? + planSettings: PlanSettings, + planTask: InquiryTaskData? ) : AbstractTask(planSettings, planTask) { - class InquiryTaskData( - @Description("The specific questions or topics to be addressed in the inquiry") - val inquiry_questions: List? = null, - @Description("The goal or purpose of the inquiry") - val inquiry_goal: String? = null, - @Description("The specific files (or file patterns) to be used as input for the task") - val input_files: List? = null, - task_description: String? = null, - task_dependencies: List? = null, - state: TaskState? = null, - ) : PlanTaskBase( - task_type = TaskType.Inquiry.name, - task_description = task_description, - task_dependencies = task_dependencies, - state = state - ) + class InquiryTaskData( + @Description("The specific questions or topics to be addressed in the inquiry") + val inquiry_questions: List? = null, + @Description("The goal or purpose of the inquiry") + val inquiry_goal: String? = null, + @Description("The specific files (or file patterns) to be used as input for the task") + val input_files: List? = null, + task_description: String? = null, + task_dependencies: List? = null, + state: TaskState? = null, + ) : PlanTaskBase( + task_type = TaskType.Inquiry.name, + task_description = task_description, + task_dependencies = task_dependencies, + state = state + ) - override fun promptSegment() = if (planSettings.allowBlocking) """ + override fun promptSegment() = if (planSettings.allowBlocking) """ |Inquiry - Answer questions by reading in files and providing a summary that can be discussed with and approved by the user | ** Specify the questions and the goal of the inquiry | ** List input files to be examined when answering the questions @@ -51,10 +51,10 @@ class InquiryTask( | ** List input files to be examined when answering the questions """.trimMargin() - private val inquiryActor by lazy { - SimpleActor( - name = "Inquiry", - prompt = """ + private val inquiryActor by lazy { + SimpleActor( + name = "Inquiry", + prompt = """ Create code for a new file that fulfills the specified requirements and context. Given a detailed user request, break it down into smaller, actionable tasks suitable for software development. Compile comprehensive information and insights on the specified topic. @@ -63,109 +63,109 @@ class InquiryTask( When generating insights, consider the existing project context and focus on information that is directly relevant and applicable. Focus on generating insights and information that support the task types available in the system (${ - planSettings.taskSettings.filter { it.value.enabled }.keys.joinToString(", ") - }). + planSettings.taskSettings.filter { it.value.enabled }.keys.joinToString(", ") + }). This will ensure that the inquiries are tailored to assist in the planning and execution of tasks within the system's framework. """.trimMargin(), - model = planSettings.getTaskSettings(TaskType.valueOf(planTask?.task_type!!)).model ?: planSettings.defaultModel, - temperature = planSettings.temperature, - ) - } + model = planSettings.getTaskSettings(TaskType.valueOf(planTask?.task_type!!)).model ?: planSettings.defaultModel, + temperature = planSettings.temperature, + ) + } - override fun run( - agent: PlanCoordinator, - messages: List, - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings - ) { + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { - val toInput = { it: String -> - messages + listOf( - getInputFileCode(), - it, - ).filter { it.isNotBlank() } - } + val toInput = { it: String -> + messages + listOf( + getInputFileCode(), + it, + ).filter { it.isNotBlank() } + } - val inquiryResult = if (planSettings.allowBlocking) Discussable( - task = task, - userMessage = { - "Expand ${this.planTask?.task_description ?: ""}\nQuestions: ${ - planTask?.inquiry_questions?.joinToString( - "\n" - ) - }\nGoal: ${planTask?.inquiry_goal}\n${JsonUtil.toJson(data = this)}" - }, - heading = "", - initialResponse = { it: String -> inquiryActor.answer(toInput(it), api = api) }, - outputFn = { design: String -> - MarkdownUtil.renderMarkdown(design, ui = agent.ui) - }, - ui = agent.ui, - reviseResponse = { usermessages: List> -> - val inStr = "Expand ${this.planTask?.task_description ?: ""}\nQuestions: ${ - planTask?.inquiry_questions?.joinToString("\n") - }\nGoal: ${planTask?.inquiry_goal}\n${JsonUtil.toJson(data = this)}" - val messages = usermessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } - .toTypedArray() - inquiryActor.respond( - messages = messages, - input = toInput(inStr), - api = api - ) - }, - atomicRef = AtomicReference(), - semaphore = Semaphore(0), - ).call() else inquiryActor.answer( - toInput( - "Expand ${this.planTask?.task_description ?: ""}\nQuestions: ${ - planTask?.inquiry_questions?.joinToString( - "\n" - ) - }\nGoal: ${planTask?.inquiry_goal}\n${JsonUtil.toJson(data = this)}" - ), - api = api - ).apply { - task.add(MarkdownUtil.renderMarkdown(this, ui = agent.ui)) - } - resultFn(inquiryResult) + val inquiryResult = if (planSettings.allowBlocking) Discussable( + task = task, + userMessage = { + "Expand ${this.planTask?.task_description ?: ""}\nQuestions: ${ + planTask?.inquiry_questions?.joinToString( + "\n" + ) + }\nGoal: ${planTask?.inquiry_goal}\n${JsonUtil.toJson(data = this)}" + }, + heading = "", + initialResponse = { it: String -> inquiryActor.answer(toInput(it), api = api) }, + outputFn = { design: String -> + MarkdownUtil.renderMarkdown(design, ui = agent.ui) + }, + ui = agent.ui, + reviseResponse = { usermessages: List> -> + val inStr = "Expand ${this.planTask?.task_description ?: ""}\nQuestions: ${ + planTask?.inquiry_questions?.joinToString("\n") + }\nGoal: ${planTask?.inquiry_goal}\n${JsonUtil.toJson(data = this)}" + val messages = usermessages.map { ApiModel.ChatMessage(it.second, it.first.toContentList()) } + .toTypedArray() + inquiryActor.respond( + messages = messages, + input = toInput(inStr), + api = api + ) + }, + atomicRef = AtomicReference(), + semaphore = Semaphore(0), + ).call() else inquiryActor.answer( + toInput( + "Expand ${this.planTask?.task_description ?: ""}\nQuestions: ${ + planTask?.inquiry_questions?.joinToString( + "\n" + ) + }\nGoal: ${planTask?.inquiry_goal}\n${JsonUtil.toJson(data = this)}" + ), + api = api + ).apply { + task.add(MarkdownUtil.renderMarkdown(this, ui = agent.ui)) } + resultFn(inquiryResult) + } - private fun getInputFileCode(): String = - ((planTask?.input_files ?: listOf())) - .flatMap { pattern: String -> - val matcher = FileSystems.getDefault().getPathMatcher("glob:$pattern") - Files.walk(root).asSequence() - .filter { path -> - matcher.matches(root.relativize(path)) && - FileValidationUtils.isLLMIncludableFile(path.toFile()) - } - .map { path -> - root.relativize(path).toString() - } - .toList() - } - .distinct() - .sortedBy { it } - .joinToString("\n\n") { relativePath -> - val file = root.resolve(relativePath).toFile() - try { - """ + private fun getInputFileCode(): String = + ((planTask?.input_files ?: listOf())) + .flatMap { pattern: String -> + val matcher = FileSystems.getDefault().getPathMatcher("glob:$pattern") + Files.walk(root).asSequence() + .filter { path -> + matcher.matches(root.relativize(path)) && + FileValidationUtils.isLLMIncludableFile(path.toFile()) + } + .map { path -> + root.relativize(path).toString() + } + .toList() + } + .distinct() + .sortedBy { it } + .joinToString("\n\n") { relativePath -> + val file = root.resolve(relativePath).toFile() + try { + """ |# $relativePath | |${AbstractFileTask.TRIPLE_TILDE} |${codeFiles[file.toPath()] ?: file.readText()} |${AbstractFileTask.TRIPLE_TILDE} """.trimMargin() - } catch (e: Throwable) { - log.warn("Error reading file: $relativePath", e) - "" - } - } + } catch (e: Throwable) { + log.warn("Error reading file: $relativePath", e) + "" + } + } - companion object { - private val log = LoggerFactory.getLogger(InquiryTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(InquiryTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/PerformanceAnalysisTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/PerformanceAnalysisTask.kt index 138caee2..932419d2 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/PerformanceAnalysisTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/PerformanceAnalysisTask.kt @@ -7,30 +7,30 @@ import com.simiacryptus.skyenet.apps.plan.file.PerformanceAnalysisTask.Performan import org.slf4j.LoggerFactory class PerformanceAnalysisTask( - planSettings: PlanSettings, - planTask: PerformanceAnalysisTaskData? + planSettings: PlanSettings, + planTask: PerformanceAnalysisTaskData? ) : AbstractAnalysisTask(planSettings, planTask) { - class PerformanceAnalysisTaskData( - @Description("Files to be analyzed for performance issues") - val files_to_analyze: List? = null, - @Description("Specific areas of focus for the analysis (e.g., time complexity, memory usage, I/O operations)") - val analysis_focus: List? = null, - task_description: String? = null, - task_dependencies: List? = null, - input_files: List? = null, - output_files: List? = null, - state: TaskState? = null, - ) : FileTaskBase( - task_type = TaskType.PerformanceAnalysis.name, - task_description = task_description, - task_dependencies = task_dependencies, - input_files = input_files, - output_files = output_files, - state = state - ) + class PerformanceAnalysisTaskData( + @Description("Files to be analyzed for performance issues") + val files_to_analyze: List? = null, + @Description("Specific areas of focus for the analysis (e.g., time complexity, memory usage, I/O operations)") + val analysis_focus: List? = null, + task_description: String? = null, + task_dependencies: List? = null, + input_files: List? = null, + output_files: List? = null, + state: TaskState? = null, + ) : FileTaskBase( + task_type = TaskType.PerformanceAnalysis.name, + task_description = task_description, + task_dependencies = task_dependencies, + input_files = input_files, + output_files = output_files, + state = state + ) - override val actorName = "PerformanceAnalysis" - override val actorPrompt = """ + override val actorName = "PerformanceAnalysis" + override val actorPrompt = """ Analyze the provided code for performance issues and bottlenecks. Focus exclusively on: 1. Time complexity of algorithms 2. Memory usage and potential leaks @@ -47,24 +47,24 @@ Format the response as a markdown document with appropriate headings and code sn Do not provide code changes, focus on analysis and recommendations. """.trimIndent() - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ PerformanceAnalysis - Analyze code for performance issues and suggest improvements ** Specify the files to be analyzed ** Optionally provide specific areas of focus for the analysis (e.g., time complexity, memory usage, I/O operations) """.trimMargin() - } + } - fun getFiles(): List { - return planTask?.files_to_analyze ?: emptyList() - } + fun getFiles(): List { + return planTask?.files_to_analyze ?: emptyList() + } - override fun getAnalysisInstruction(): String { - return "Analyze the following code for performance issues and provide a detailed report" - } + override fun getAnalysisInstruction(): String { + return "Analyze the following code for performance issues and provide a detailed report" + } - companion object { - private val log = LoggerFactory.getLogger(PerformanceAnalysisTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(PerformanceAnalysisTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/RefactorTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/RefactorTask.kt index 32ffcb6a..c6a4a397 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/RefactorTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/RefactorTask.kt @@ -7,30 +7,30 @@ import com.simiacryptus.skyenet.apps.plan.file.RefactorTask.RefactorTaskData import org.slf4j.LoggerFactory class RefactorTask( - planSettings: PlanSettings, - planTask: RefactorTaskData? + planSettings: PlanSettings, + planTask: RefactorTaskData? ) : AbstractAnalysisTask(planSettings, planTask) { - class RefactorTaskData( - @Description("List of files to be refactored") - val filesToRefactor: List? = null, - @Description("Specific areas of focus for the refactoring (e.g., modularity, design patterns, naming conventions)") - val refactoringFocus: List? = null, - task_description: String? = null, - task_dependencies: List? = null, - input_files: List? = null, - output_files: List? = null, - state: TaskState? = null - ) : FileTaskBase( - task_type = TaskType.RefactorTask.name, - task_description = task_description, - task_dependencies = task_dependencies, - input_files = input_files, - output_files = output_files, - state = state - ) + class RefactorTaskData( + @Description("List of files to be refactored") + val filesToRefactor: List? = null, + @Description("Specific areas of focus for the refactoring (e.g., modularity, design patterns, naming conventions)") + val refactoringFocus: List? = null, + task_description: String? = null, + task_dependencies: List? = null, + input_files: List? = null, + output_files: List? = null, + state: TaskState? = null + ) : FileTaskBase( + task_type = TaskType.RefactorTask.name, + task_description = task_description, + task_dependencies = task_dependencies, + input_files = input_files, + output_files = output_files, + state = state + ) - override val actorName: String = "Refactor" - override val actorPrompt: String = """ + override val actorName: String = "Refactor" + override val actorPrompt: String = """ Analyze the provided code and suggest refactoring to improve code structure, readability, and maintainability. Focus on: 1. Improving code organization 2. Reducing code duplication @@ -48,17 +48,17 @@ Format the response as a markdown document with appropriate headings and code sn Use diff format to show the proposed changes clearly. """.trimIndent() - override fun getAnalysisInstruction(): String = "Refactor the following code" + override fun getAnalysisInstruction(): String = "Refactor the following code" - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ RefactorTask - Analyze and refactor existing code to improve structure, readability, and maintainability ** Specify the files to be refactored ** Optionally provide specific areas of focus for the refactoring (e.g., modularity, design patterns, naming conventions) """.trimMargin() - } + } - companion object { - private val log = LoggerFactory.getLogger(RefactorTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(RefactorTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/SearchTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/SearchTask.kt deleted file mode 100644 index b9ae95a0..00000000 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/SearchTask.kt +++ /dev/null @@ -1,134 +0,0 @@ -package com.simiacryptus.skyenet.apps.plan.file - -import com.simiacryptus.diff.FileValidationUtils -import com.simiacryptus.jopenai.ChatClient -import com.simiacryptus.jopenai.OpenAIClient -import com.simiacryptus.jopenai.describe.Description -import com.simiacryptus.skyenet.apps.plan.* -import com.simiacryptus.skyenet.util.MarkdownUtil -import com.simiacryptus.skyenet.webui.session.SessionTask -import org.slf4j.LoggerFactory -import java.nio.file.FileSystems -import java.nio.file.Files -import java.util.regex.Pattern -import kotlin.streams.asSequence - -class SearchTask( - planSettings: PlanSettings, - planTask: SearchTaskData? -) : AbstractTask(planSettings, planTask) { - class SearchTaskData( - @Description("The search pattern (substring or regex) to look for in the files") - val search_pattern: String, - @Description("Whether the search pattern is a regex (true) or a substring (false)") - val is_regex: Boolean = false, - @Description("The number of context lines to include before and after each match") - val context_lines: Int = 2, - @Description("The specific files (or file patterns) to be searched") - val input_files: List? = null, - task_description: String? = null, - task_dependencies: List? = null, - state: TaskState? = null, - ) : PlanTaskBase( - task_type = TaskType.Search.name, - task_description = task_description, - task_dependencies = task_dependencies, - state = state - ) - - override fun promptSegment() = """ -Search - Search for patterns in files and provide results with context - ** Specify the search pattern (substring or regex) - ** Specify whether the pattern is a regex or a substring - ** Specify the number of context lines to include - ** List input files or file patterns to be searched - """.trimMargin() - - override fun run( - agent: PlanCoordinator, - messages: List, - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings - ) { - val searchResults = performSearch() - val formattedResults = formatSearchResults(searchResults) - task.add(MarkdownUtil.renderMarkdown(formattedResults, ui = agent.ui)) - resultFn(formattedResults) - } - - private fun performSearch(): List { - val pattern = if (planTask?.is_regex == true) { - Pattern.compile(planTask.search_pattern) - } else { - Pattern.compile(Pattern.quote(planTask?.search_pattern)) - } - - return (planTask?.input_files ?: listOf()) - .flatMap { filePattern -> - val matcher = FileSystems.getDefault().getPathMatcher("glob:$filePattern") - Files.walk(root).asSequence() - .filter { path -> - matcher.matches(root.relativize(path)) && - FileValidationUtils.isLLMIncludableFile(path.toFile()) - } - .flatMap { path -> - val relativePath = root.relativize(path).toString() - val lines = Files.readAllLines(path) - lines.mapIndexed { index, line -> - if (pattern.matcher(line).find()) { - SearchResult( - file = relativePath, - lineNumber = index + 1, - matchedLine = line, - context = getContext(lines, index, planTask?.context_lines ?: 2) - ) - } else null - }.filterNotNull() - } - .toList() - } - } - - private fun getContext(lines: List, matchIndex: Int, contextLines: Int): List { - val start = (matchIndex - contextLines).coerceAtLeast(0) - val end = (matchIndex + contextLines + 1).coerceAtMost(lines.size) - return lines.subList(start, end) - } - - private fun formatSearchResults(results: List): String { - return buildString { - appendLine("# Search Results") - appendLine() - results.groupBy { it.file }.forEach { (file, fileResults) -> - appendLine("## $file") - appendLine() - fileResults.forEach { result -> - appendLine("### Line ${result.lineNumber}") - appendLine() - appendLine("```") - result.context.forEachIndexed { index, line -> - val lineNumber = result.lineNumber - (result.context.size / 2) + index - val prefix = if (lineNumber == result.lineNumber) ">" else " " - appendLine("$prefix $lineNumber: $line") - } - appendLine("```") - appendLine() - } - } - } - } - - data class SearchResult( - val file: String, - val lineNumber: Int, - val matchedLine: String, - val context: List - ) - - companion object { - private val log = LoggerFactory.getLogger(SearchTask::class.java) - } -} \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/SecurityAuditTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/SecurityAuditTask.kt index 63a83c93..def64593 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/SecurityAuditTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/SecurityAuditTask.kt @@ -7,32 +7,32 @@ import com.simiacryptus.skyenet.apps.plan.file.SecurityAuditTask.SecurityAuditTa import org.slf4j.LoggerFactory class SecurityAuditTask( - planSettings: PlanSettings, - planTask: SecurityAuditTaskData? + planSettings: PlanSettings, + planTask: SecurityAuditTaskData? ) : AbstractAnalysisTask(planSettings, planTask) { - class SecurityAuditTaskData( - @Description("List of files to be audited") - val filesToAudit: List? = null, - @Description("Specific areas of focus for the security audit") - val focusAreas: List? = null, - task_description: String? = null, - task_dependencies: List? = null, - input_files: List? = null, - output_files: List? = null, - state: TaskState? = null - ) : FileTaskBase( - task_type = TaskType.SecurityAudit.name, - task_description = task_description, - task_dependencies = task_dependencies, - input_files = input_files, - output_files = output_files, - state = state - ) - - override val actorName: String = "SecurityAudit" - override val actorPrompt: String = """ + class SecurityAuditTaskData( + @Description("List of files to be audited") + val filesToAudit: List? = null, + @Description("Specific areas of focus for the security audit") + val focusAreas: List? = null, + task_description: String? = null, + task_dependencies: List? = null, + input_files: List? = null, + output_files: List? = null, + state: TaskState? = null + ) : FileTaskBase( + task_type = TaskType.SecurityAudit.name, + task_description = task_description, + task_dependencies = task_dependencies, + input_files = input_files, + output_files = output_files, + state = state + ) + + override val actorName: String = "SecurityAudit" + override val actorPrompt: String = """ Perform a comprehensive security audit for the provided code files. Analyze the code for: 1. Potential security vulnerabilities 2. Insecure coding practices @@ -46,18 +46,18 @@ Format the response as a markdown document with appropriate headings and code sn Use diff format to show the proposed security fixes clearly. """.trimIndent() - override fun getAnalysisInstruction(): String = "Perform a security audit on the following code" + override fun getAnalysisInstruction(): String = "Perform a security audit on the following code" - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ SecurityAudit - Perform an automated security audit and provide suggestions for improving code security ** Specify the files to be audited ** Optionally provide specific areas of focus for the security audit """.trimMargin() - } + } - companion object { - private val log = LoggerFactory.getLogger(SecurityAuditTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(SecurityAuditTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/TestGenerationTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/TestGenerationTask.kt index 5d1c8c87..22780afd 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/TestGenerationTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/file/TestGenerationTask.kt @@ -7,31 +7,31 @@ import com.simiacryptus.skyenet.apps.plan.file.TestGenerationTask.TestGeneration import org.slf4j.LoggerFactory class TestGenerationTask( - planSettings: PlanSettings, - planTask: TestGenerationTaskData? + planSettings: PlanSettings, + planTask: TestGenerationTaskData? ) : AbstractAnalysisTask(planSettings, planTask) { - class TestGenerationTaskData( - @Description("List of files for which tests should be generated") - val filesToTest: List? = null, - @Description("List of input files or tasks to be examined when generating tests") - val inputReferences: List? = null, - task_description: String? = null, - task_dependencies: List? = null, - input_files: List? = null, - output_files: List? = null, - state: TaskState? = null - ) : FileTaskBase( - task_type = TaskType.TestGeneration.name, - task_description = task_description, - task_dependencies = task_dependencies, - input_files = input_files, - output_files = output_files, - state = state - ) + class TestGenerationTaskData( + @Description("List of files for which tests should be generated") + val filesToTest: List? = null, + @Description("List of input files or tasks to be examined when generating tests") + val inputReferences: List? = null, + task_description: String? = null, + task_dependencies: List? = null, + input_files: List? = null, + output_files: List? = null, + state: TaskState? = null + ) : FileTaskBase( + task_type = TaskType.TestGeneration.name, + task_description = task_description, + task_dependencies = task_dependencies, + input_files = input_files, + output_files = output_files, + state = state + ) - override val actorName: String = "TestGeneration" - override val actorPrompt: String = """ + override val actorName: String = "TestGeneration" + override val actorPrompt: String = """ Generate comprehensive unit tests for the provided code files. The tests should: |1. Cover all public methods and functions |2. Include both positive and negative test cases @@ -77,10 +77,10 @@ class TestGenerationTask( ${com.simiacryptus.skyenet.apps.plan.TRIPLE_TILDE} """.trimMargin() - override fun getAnalysisInstruction(): String = "Generate tests for the following code" + override fun getAnalysisInstruction(): String = "Generate tests for the following code" - override fun promptSegment(): String { - return """ + override fun promptSegment(): String { + return """ TestGeneration - Generate unit tests for the specified code files ** Specify the files for which tests should be generated using the 'filesToTest' field ** List input files/tasks to be examined when generating tests using the 'inputReferences' field @@ -89,9 +89,9 @@ class TestGenerationTask( ** Specify the files for which tests should be generated ** List input files/tasks to be examined when generating tests """.trimMargin() - } + } - companion object { - private val log = LoggerFactory.getLogger(TestGenerationTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(TestGenerationTask::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/EmbeddingSearchTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/EmbeddingSearchTask.kt index 9e841c87..0a0a287b 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/EmbeddingSearchTask.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/EmbeddingSearchTask.kt @@ -20,33 +20,33 @@ import java.util.regex.Pattern import kotlin.streams.asSequence class EmbeddingSearchTask( - planSettings: PlanSettings, - planTask: EmbeddingSearchTaskData? + planSettings: PlanSettings, + planTask: EmbeddingSearchTaskData? ) : AbstractTask(planSettings, planTask) { - class EmbeddingSearchTaskData( - @Description("The positive search queries to look for in the embeddings") - val positive_queries: List, - @Description("The negative search queries to avoid in the embeddings") - val negative_queries: List = emptyList(), - @Description("The distance type to use for comparing embeddings (Euclidean, Manhattan, or Cosine)") - val distance_type: DistanceType = DistanceType.Cosine, - @Description("The number of top results to return") - val count: Int = 5, - @Description("The minimum length of the content to be considered") - val min_length: Int = 0, - @Description("List of regex patterns that must be present in the content") - val required_regexes: List = emptyList(), - task_description: String? = null, - task_dependencies: List? = null, - state: TaskState? = null, - ) : PlanTaskBase( - task_type = "EmbeddingSearch", - task_description = task_description, - task_dependencies = task_dependencies, - state = state - ) + class EmbeddingSearchTaskData( + @Description("The positive search queries to look for in the embeddings") + val positive_queries: List, + @Description("The negative search queries to avoid in the embeddings") + val negative_queries: List = emptyList(), + @Description("The distance type to use for comparing embeddings (Euclidean, Manhattan, or Cosine)") + val distance_type: DistanceType = DistanceType.Cosine, + @Description("The number of top results to return") + val count: Int = 5, + @Description("The minimum length of the content to be considered") + val min_length: Int = 0, + @Description("List of regex patterns that must be present in the content") + val required_regexes: List = emptyList(), + task_description: String? = null, + task_dependencies: List? = null, + state: TaskState? = null, + ) : PlanTaskBase( + task_type = "EmbeddingSearch", + task_description = task_description, + task_dependencies = task_dependencies, + state = state + ) - override fun promptSegment() = """ + override fun promptSegment() = """ EmbeddingSearch - Search for similar embeddings in index files and provide top results ** Specify the positive search queries ** Optionally specify negative search queries @@ -54,176 +54,176 @@ EmbeddingSearch - Search for similar embeddings in index files and provide top r ** Specify the number of top results to return """.trim() - override fun run( - agent: PlanCoordinator, - messages: List, - task: SessionTask, - api: ChatClient, - resultFn: (String) -> Unit, - api2: OpenAIClient, - planSettings: PlanSettings - ) { - val searchResults = performEmbeddingSearch(api2) - val formattedResults = formatSearchResults(searchResults) - task.add(MarkdownUtil.renderMarkdown(formattedResults, ui = agent.ui)) - resultFn(formattedResults) - } + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + val searchResults = performEmbeddingSearch(api2) + val formattedResults = formatSearchResults(searchResults) + task.add(MarkdownUtil.renderMarkdown(formattedResults, ui = agent.ui)) + resultFn(formattedResults) + } - private fun performEmbeddingSearch(api: OpenAIClient): List { - val positiveEmbeddings = planTask?.positive_queries?.map { query -> - api.createEmbedding( - ApiModel.EmbeddingRequest( - input = query, - model = EmbeddingModels.Large.modelName - ) - ).data[0].embedding - } ?: emptyList() - val negativeEmbeddings = planTask?.negative_queries?.map { query -> - api.createEmbedding( - ApiModel.EmbeddingRequest( - input = query, - model = EmbeddingModels.Large.modelName - ) - ).data[0].embedding - } ?: emptyList() - if (positiveEmbeddings.isEmpty()) { - throw IllegalArgumentException("At least one positive query is required") - } - val distanceType = planTask?.distance_type ?: DistanceType.Cosine - val filtered = Files.walk(root).asSequence() - .filter { path -> - path.toString().endsWith(".index.data") - }.toList().toTypedArray() - val minLength = planTask?.min_length ?: 0 - val requiredRegexes = planTask?.required_regexes?.map { Pattern.compile(it) } ?: emptyList() - fun String.matchesAllRegexes(): Boolean { - return requiredRegexes.all { regex -> regex.matcher(this).find() } - } - - val searchResults = filtered - .flatMap { path -> - val records = DocumentRecord.readBinary(path.toString()) - records.mapNotNull { record -> - record.vector?.let { vector -> - val positiveDistances = positiveEmbeddings.filterNotNull().map { embedding -> - distanceType.distance(vector, embedding) - } - val negativeDistances = negativeEmbeddings.filterNotNull().map { embedding -> - distanceType.distance(vector, embedding) - } - val overallDistance = if (negativeDistances.isEmpty()) { - positiveDistances.minOrNull() ?: Double.MAX_VALUE - } else { - (positiveDistances.minOrNull() ?: Double.MAX_VALUE) / (negativeDistances.minOrNull() ?: Double.MIN_VALUE) - } - val content = record.text ?: "" - if (content.length >= minLength && content.matchesAllRegexes()) { - EmbeddingSearchResult( - file = root.relativize(path).toString(), - record = record, - distance = overallDistance - ) - } else null - } - } - } - .toList() - return searchResults - .sortedBy { it.distance } - .take(planTask?.count ?: 5) + private fun performEmbeddingSearch(api: OpenAIClient): List { + val positiveEmbeddings = planTask?.positive_queries?.map { query -> + api.createEmbedding( + ApiModel.EmbeddingRequest( + input = query, + model = EmbeddingModels.Large.modelName + ) + ).data[0].embedding + } ?: emptyList() + val negativeEmbeddings = planTask?.negative_queries?.map { query -> + api.createEmbedding( + ApiModel.EmbeddingRequest( + input = query, + model = EmbeddingModels.Large.modelName + ) + ).data[0].embedding + } ?: emptyList() + if (positiveEmbeddings.isEmpty()) { + throw IllegalArgumentException("At least one positive query is required") + } + val distanceType = planTask?.distance_type ?: DistanceType.Cosine + val filtered = Files.walk(root).asSequence() + .filter { path -> + path.toString().endsWith(".index.data") + }.toList().toTypedArray() + val minLength = planTask?.min_length ?: 0 + val requiredRegexes = planTask?.required_regexes?.map { Pattern.compile(it) } ?: emptyList() + fun String.matchesAllRegexes(): Boolean { + return requiredRegexes.all { regex -> regex.matcher(this).find() } } - private fun formatSearchResults(results: List): String { - return buildString { - appendLine("# Embedding Search Results") - appendLine() - results.forEachIndexed { index, result -> - appendLine("## Result ${index + 1}") - appendLine("* Distance: %.3f".format(result.distance)) - appendLine("* File: ${result.record.sourcePath}") - appendLine(getContextSummary(result.record.sourcePath, result.record.jsonPath)) - appendLine("Metadata:\n```json\n${result.record.metadata}\n```") - appendLine() + val searchResults = filtered + .flatMap { path -> + val records = DocumentRecord.readBinary(path.toString()) + records.mapNotNull { record -> + record.vector?.let { vector -> + val positiveDistances = positiveEmbeddings.filterNotNull().map { embedding -> + distanceType.distance(vector, embedding) + } + val negativeDistances = negativeEmbeddings.filterNotNull().map { embedding -> + distanceType.distance(vector, embedding) } + val overallDistance = if (negativeDistances.isEmpty()) { + positiveDistances.minOrNull() ?: Double.MAX_VALUE + } else { + (positiveDistances.minOrNull() ?: Double.MAX_VALUE) / (negativeDistances.minOrNull() ?: Double.MIN_VALUE) + } + val content = record.text ?: "" + if (content.length >= minLength && content.matchesAllRegexes()) { + EmbeddingSearchResult( + file = root.relativize(path).toString(), + record = record, + distance = overallDistance + ) + } else null + } } + } + .toList() + return searchResults + .sortedBy { it.distance } + .take(planTask?.count ?: 5) + } + + private fun formatSearchResults(results: List): String { + return buildString { + appendLine("# Embedding Search Results") + appendLine() + results.forEachIndexed { index, result -> + appendLine("## Result ${index + 1}") + appendLine("* Distance: %.3f".format(result.distance)) + appendLine("* File: ${result.record.sourcePath}") + appendLine(getContextSummary(result.record.sourcePath, result.record.jsonPath)) + appendLine("Metadata:\n```json\n${result.record.metadata}\n```") + appendLine() + } } + } - private fun getContextSummary(sourcePath: String, jsonPath: String): String { - val objectMapper = ObjectMapper() - val jsonNode = objectMapper.readTree(File(sourcePath)) - val contextNode = getNodeAtPath(jsonNode, jsonPath) - return buildString { - appendLine("```json") - appendLine(summarizeContext(contextNode, jsonPath, jsonNode)) - appendLine("```") - } + private fun getContextSummary(sourcePath: String, jsonPath: String): String { + val objectMapper = ObjectMapper() + val jsonNode = objectMapper.readTree(File(sourcePath)) + val contextNode = getNodeAtPath(jsonNode, jsonPath) + return buildString { + appendLine("```json") + appendLine(summarizeContext(contextNode, jsonPath, jsonNode)) + appendLine("```") } + } - private fun getNodeAtPath(jsonNode: JsonNode, path: String): JsonNode { - var currentNode = jsonNode - path.split(".").forEach { segment -> - currentNode = when { - segment.contains("[") -> { - val (arrayName, indexPart) = segment.split("[", limit = 2) - val index = indexPart.substringBefore("]").toInt() - val field = currentNode.get(arrayName) - val child = field?.get(index) - if (child == null) { - return currentNode - } - child - } + private fun getNodeAtPath(jsonNode: JsonNode, path: String): JsonNode { + var currentNode = jsonNode + path.split(".").forEach { segment -> + currentNode = when { + segment.contains("[") -> { + val (arrayName, indexPart) = segment.split("[", limit = 2) + val index = indexPart.substringBefore("]").toInt() + val field = currentNode.get(arrayName) + val child = field?.get(index) + if (child == null) { + return currentNode + } + child + } - else -> { - val child = currentNode.get(segment) - if (child == null) { - return currentNode - } - child - } - } + else -> { + val child = currentNode.get(segment) + if (child == null) { + return currentNode + } + child } - return currentNode + } } + return currentNode + } - private fun summarizeContext(node: JsonNode, path: String, jsonNode: JsonNode): String { - var summary = mutableMapOf() - // Add siblings and descendants - node.fields().forEach { (key, value) -> - if (value.isPrimitive()) { - summary[key] = value.asText() - } - } - // Add siblings of parent nodes - val pathSegments = path.split(".") - for (i in pathSegments.size - 1 downTo 1) { - val parentPath = pathSegments.subList(0, i).joinToString(".") - val parentNode = getNodeAtPath(jsonNode, parentPath) - summary = mutableMapOf( - pathSegments[i] to summary - ) - parentNode.fields().forEach { (key, value) -> - when { - value.isPrimitive() -> summary[key] = value.asText() - key == "entities" || key == "tags" || key == "metadata" -> summary[key] = value - } - } + private fun summarizeContext(node: JsonNode, path: String, jsonNode: JsonNode): String { + var summary = mutableMapOf() + // Add siblings and descendants + node.fields().forEach { (key, value) -> + if (value.isPrimitive()) { + summary[key] = value.asText() + } + } + // Add siblings of parent nodes + val pathSegments = path.split(".") + for (i in pathSegments.size - 1 downTo 1) { + val parentPath = pathSegments.subList(0, i).joinToString(".") + val parentNode = getNodeAtPath(jsonNode, parentPath) + summary = mutableMapOf( + pathSegments[i] to summary + ) + parentNode.fields().forEach { (key, value) -> + when { + value.isPrimitive() -> summary[key] = value.asText() + key == "entities" || key == "tags" || key == "metadata" -> summary[key] = value } - return JsonUtil.toJson(summary) + } } + return JsonUtil.toJson(summary) + } - data class EmbeddingSearchResult( - val file: String, - val record: DocumentRecord, - val distance: Double - ) + data class EmbeddingSearchResult( + val file: String, + val record: DocumentRecord, + val distance: Double + ) - companion object { - private val log = LoggerFactory.getLogger(EmbeddingSearchTask::class.java) - } + companion object { + private val log = LoggerFactory.getLogger(EmbeddingSearchTask::class.java) + } } private fun JsonNode.isPrimitive(): Boolean { - return this.isNumber || this.isTextual || this.isBoolean + return this.isNumber || this.isTextual || this.isBoolean } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/KnowledgeIndexingTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/KnowledgeIndexingTask.kt new file mode 100644 index 00000000..51ff4f66 --- /dev/null +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/KnowledgeIndexingTask.kt @@ -0,0 +1,110 @@ +package com.simiacryptus.skyenet.apps.plan.knowledge + +import com.simiacryptus.jopenai.ChatClient +import com.simiacryptus.jopenai.OpenAIClient +import com.simiacryptus.jopenai.describe.Description +import com.simiacryptus.jopenai.models.chatModel +import com.simiacryptus.skyenet.apps.parse.CodeParsingModel +import com.simiacryptus.skyenet.apps.parse.DocumentParserApp +import com.simiacryptus.skyenet.apps.parse.DocumentParsingModel +import com.simiacryptus.skyenet.apps.parse.DocumentRecord.Companion.saveAsBinary +import com.simiacryptus.skyenet.apps.parse.ProgressState +import com.simiacryptus.skyenet.apps.plan.* +import com.simiacryptus.skyenet.util.MarkdownUtil +import com.simiacryptus.skyenet.webui.session.SessionTask +import org.slf4j.LoggerFactory +import java.io.File +import java.nio.file.Files +import java.util.concurrent.Executors + +class KnowledgeIndexingTask( + planSettings: PlanSettings, + planTask: KnowledgeIndexingTaskData? +) : AbstractTask(planSettings, planTask) { + + class KnowledgeIndexingTaskData( + @Description("The file paths to process and index") + val file_paths: List, + @Description("The type of parsing to use (document, code)") + val parsing_type: String = "document", + @Description("The chunk size for parsing (default 0.1)") + val chunk_size: Double = 0.1, + task_description: String? = null, + task_dependencies: List? = null, + state: TaskState? = null, + ) : PlanTaskBase( + task_type = TaskType.KnowledgeIndexing.name, + task_description = task_description, + task_dependencies = task_dependencies, + state = state + ) + + override fun promptSegment() = """ + KnowledgeIndexing - Process and index files for semantic search + ** Specify the file paths to process + ** Specify the parsing type (document or code) + ** Optionally specify the chunk size (default 0.1) + """.trimIndent() + + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + val filePaths = planTask?.file_paths ?: return + val files = filePaths.map { File(it) }.filter { it.exists() } + + if (files.isEmpty()) { + val result = "No valid files found to process" + task.add(MarkdownUtil.renderMarkdown(result, ui = agent.ui)) + resultFn(result) + return + } + + val threadPool = Executors.newFixedThreadPool(8) + try { + val parsingModel = when (planTask.parsing_type.lowercase()) { + "code" -> CodeParsingModel(planSettings.defaultModel, planTask.chunk_size) + else -> DocumentParsingModel(planSettings.defaultModel, planTask.chunk_size) + } + + val progressState = ProgressState() + var currentProgress = 0.0 + progressState.onUpdate += { + val newProgress = it.progress / it.max + if (newProgress != currentProgress) { + currentProgress = newProgress + task.add(MarkdownUtil.renderMarkdown("Processing: ${(currentProgress * 100).toInt()}%", ui = agent.ui)) + } + } + + saveAsBinary( + openAIClient = api2, + pool = threadPool, + progressState = progressState, + inputPaths = files.map { it.absolutePath }.toTypedArray() + ) + + val result = buildString { + appendLine("# Knowledge Indexing Complete") + appendLine() + appendLine("Processed ${files.size} files:") + files.forEach { file -> + appendLine("* ${file.name}") + } + } + task.add(MarkdownUtil.renderMarkdown(result, ui = agent.ui)) + resultFn(result) + } finally { + threadPool.shutdown() + } + } + + companion object { + private val log = LoggerFactory.getLogger(KnowledgeIndexingTask::class.java) + } +} \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/WebSearchAndIndexTask.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/WebSearchAndIndexTask.kt new file mode 100644 index 00000000..a14cf4d4 --- /dev/null +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/apps/plan/knowledge/WebSearchAndIndexTask.kt @@ -0,0 +1,155 @@ +package com.simiacryptus.skyenet.apps.plan.knowledge + +import com.simiacryptus.jopenai.ChatClient +import com.simiacryptus.jopenai.OpenAIClient +import com.simiacryptus.jopenai.describe.Description +import com.simiacryptus.skyenet.apps.parse.DocumentRecord +import com.simiacryptus.skyenet.apps.plan.AbstractTask +import com.simiacryptus.skyenet.apps.plan.PlanCoordinator +import com.simiacryptus.skyenet.apps.plan.PlanSettings +import com.simiacryptus.skyenet.apps.plan.PlanTaskBase +import com.simiacryptus.skyenet.webui.session.SessionTask +import com.simiacryptus.util.JsonUtil +import org.apache.commons.io.FileUtils +import org.slf4j.LoggerFactory +import java.io.File +import java.net.URI +import java.net.URLEncoder +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse +import java.util.concurrent.Executors +import kotlin.collections.forEachIndexed +import kotlin.collections.map +import kotlin.collections.mapNotNull +import kotlin.jvm.java +import kotlin.text.appendLine +import kotlin.text.replace +import kotlin.text.take +import kotlin.text.trimMargin + +class WebSearchAndIndexTask( + planSettings: PlanSettings, + planTask: WebSearchAndIndexTaskData? +) : AbstractTask(planSettings, planTask) { + + class WebSearchAndIndexTaskData( + @Description("The search query to use for web search") + val search_query: String, + @Description("The number of search results to process (max 10)") + val num_results: Int = 5, + @Description("The directory to store downloaded and indexed content") + val output_directory: String, + task_description: String? = null, + task_dependencies: List? = null, + state: TaskState? = null, + ) : PlanTaskBase( + task_type = "WebSearchAndIndex", + task_description = task_description, + task_dependencies = task_dependencies, + state = state + ) + + override fun promptSegment() = """ + WebSearchAndIndex - Search web, download content, parse and index for future embedding search + ** Specify the search query + ** Specify number of results to process (max 10) + ** Specify output directory for indexed content + """.trimMargin() + + override fun run( + agent: PlanCoordinator, + messages: List, + task: SessionTask, + api: ChatClient, + resultFn: (String) -> Unit, + api2: OpenAIClient, + planSettings: PlanSettings + ) { + val searchResults = performGoogleSearch(planSettings) + val downloadedFiles = downloadAndSaveContent(searchResults) + val indexedFiles = indexContent(downloadedFiles, api2) + + val summary = buildString { + appendLine("# Web Search and Index Results") + appendLine("## Search Query: ${planTask?.search_query}") + appendLine("## Downloaded and Indexed Files:") + indexedFiles.forEachIndexed { index, file -> + appendLine("${index + 1}. ${file.name}") + } + } + + resultFn(summary) + } + + private fun performGoogleSearch(planSettings: PlanSettings): List> { + val client = HttpClient.newBuilder().build() + val encodedQuery = URLEncoder.encode(planTask?.search_query, "UTF-8") + val uriBuilder = + "https://www.googleapis.com/customsearch/v1?key=${planSettings.googleApiKey}&cx=${planSettings.googleSearchEngineId}&q=$encodedQuery&num=${planTask?.num_results}" + + val request = HttpRequest.newBuilder() + .uri(URI.create(uriBuilder)) + .GET() + .build() + + val response = client.send(request, HttpResponse.BodyHandlers.ofString()) + if (response.statusCode() != 200) { + throw RuntimeException("Google API request failed with status ${response.statusCode()}: ${response.body()}") + } + + val searchResults: Map = JsonUtil.fromJson(response.body(), Map::class.java) + return (searchResults["items"] as List>?) ?: emptyList() + } + + private fun downloadAndSaveContent(searchResults: List>): List { + val outputDir = File(planTask?.output_directory ?: "web_content") + outputDir.mkdirs() + + val client = HttpClient.newBuilder().build() + return searchResults.mapNotNull { result -> + try { + val url = result["link"] as String + val title = result["title"] as String + val fileName = "${sanitizeFileName(title)}.html" + val outputFile = File(outputDir, fileName) + + val request = HttpRequest.newBuilder() + .uri(URI.create(url)) + .GET() + .build() + + val response = client.send(request, HttpResponse.BodyHandlers.ofString()) + if (response.statusCode() == 200) { + FileUtils.writeStringToFile(outputFile, response.body(), "UTF-8") + outputFile + } else null + } catch (e: Exception) { + log.error("Error downloading content", e) + null + } + } + } + + private fun indexContent(files: List, api: OpenAIClient): List { + val threadPool = Executors.newFixedThreadPool(8) + try { + return DocumentRecord.saveAsBinary( + openAIClient = api, + pool = threadPool, + progressState = null, + inputPaths = files.map { it.absolutePath }.toTypedArray() + ).map { File(it) }.toList() + } finally { + threadPool.shutdown() + } + } + + private fun sanitizeFileName(fileName: String): String { + return fileName.replace(Regex("[^a-zA-Z0-9.-]"), "_").take(50) + } + + companion object { + private val log = LoggerFactory.getLogger(WebSearchAndIndexTask::class.java) + } +} \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/interpreter/ProcessInterpreter.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/interpreter/ProcessInterpreter.kt index 580b1279..09138ef9 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/interpreter/ProcessInterpreter.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/interpreter/ProcessInterpreter.kt @@ -3,48 +3,48 @@ package com.simiacryptus.skyenet.interpreter import java.util.concurrent.TimeUnit open class ProcessInterpreter( - private val defs: Map = mapOf(), + private val defs: Map = mapOf(), ) : Interpreter { - val command: List - get() = defs["command"]?.let { command -> - when (command) { - is String -> command.split(" ") - is List<*> -> command.map { it.toString() } - else -> throw IllegalArgumentException("Invalid command: $command") - } - } ?: listOf("bash") + val command: List + get() = defs["command"]?.let { command -> + when (command) { + is String -> command.split(" ") + is List<*> -> command.map { it.toString() } + else -> throw IllegalArgumentException("Invalid command: $command") + } + } ?: listOf("bash") - final override fun getLanguage(): String = defs["language"]?.toString() ?: "bash" - override fun getSymbols() = defs + final override fun getLanguage(): String = defs["language"]?.toString() ?: "bash" + override fun getSymbols() = defs - override fun validate(code: String): Throwable? { - // Always valid - return null - } + override fun validate(code: String): Throwable? { + // Always valid + return null + } - override fun run(code: String): Any? { - val wrappedCode = wrapCode(code.trim()) - val cmd = command.toTypedArray() - val cwd = defs["workingDir"]?.toString()?.let { java.io.File(it) } ?: java.io.File(".") - val processBuilder = ProcessBuilder(*cmd).directory(cwd) - defs["env"]?.let { env -> processBuilder.environment().putAll((env as Map)) } - val process = processBuilder.start() + override fun run(code: String): Any? { + val wrappedCode = wrapCode(code.trim()) + val cmd = command.toTypedArray() + val cwd = defs["workingDir"]?.toString()?.let { java.io.File(it) } ?: java.io.File(".") + val processBuilder = ProcessBuilder(*cmd).directory(cwd) + defs["env"]?.let { env -> processBuilder.environment().putAll((env as Map)) } + val process = processBuilder.start() - process.outputStream.write(wrappedCode.toByteArray()) - process.outputStream.close() - val output = process.inputStream.bufferedReader().readText() - val error = process.errorStream.bufferedReader().readText() + process.outputStream.write(wrappedCode.toByteArray()) + process.outputStream.close() + val output = process.inputStream.bufferedReader().readText() + val error = process.errorStream.bufferedReader().readText() - val waitFor = process.waitFor(5, TimeUnit.MINUTES) - if (!waitFor) { - process.destroy() - throw RuntimeException("Timeout; output: $output; error: $error") - } else if (error.isNotEmpty()) { - //throw RuntimeException(error) - return ( - """ + val waitFor = process.waitFor(5, TimeUnit.MINUTES) + if (!waitFor) { + process.destroy() + throw RuntimeException("Timeout; output: $output; error: $error") + } else if (error.isNotEmpty()) { + //throw RuntimeException(error) + return ( + """ |ERROR: |```text |$error @@ -55,11 +55,11 @@ open class ProcessInterpreter( |$output |``` """.trimMargin() - ) - } else { - return output - } + ) + } else { + return output } + } - companion object + companion object } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/util/EncryptFiles.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/util/EncryptFiles.kt index f133daf2..2fe7ba0f 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/util/EncryptFiles.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/util/EncryptFiles.kt @@ -7,18 +7,18 @@ import java.nio.file.Paths object EncryptFiles { - @JvmStatic - fun main(args: Array) { - File("""C:\Users\andre\code\SkyeNet\webui\src\test\resources\client_secret_google_oauth.json""") - .readText().encrypt("arn:aws:kms:us-east-1:470240306861:key/a1340b89-64e6-480c-a44c-e7bc0c70dcb1") - .write("""C:\Users\andre\code\SkyeNet\webui\src\test\resources\client_secret_google_oauth.json.kms""") - } + @JvmStatic + fun main(args: Array) { + File("""C:\Users\andre\code\SkyeNet\webui\src\test\resources\client_secret_google_oauth.json""") + .readText().encrypt("arn:aws:kms:us-east-1:470240306861:key/a1340b89-64e6-480c-a44c-e7bc0c70dcb1") + .write("""C:\Users\andre\code\SkyeNet\webui\src\test\resources\client_secret_google_oauth.json.kms""") + } } fun String.write(outpath: String) { - Files.write(Paths.get(outpath), toByteArray()) + Files.write(Paths.get(outpath), toByteArray()) } fun String.encrypt(keyId: String) = ApplicationServices.cloud?.encrypt(encodeToByteArray(), keyId) - ?: throw RuntimeException("Unable to encrypt data") \ No newline at end of file + ?: throw RuntimeException("Unable to encrypt data") \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/util/MarkdownUtil.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/util/MarkdownUtil.kt index e5127f8b..7729088a 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/util/MarkdownUtil.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/util/MarkdownUtil.kt @@ -11,41 +11,41 @@ import java.nio.file.Files import java.util.* object MarkdownUtil { - fun renderMarkdown( - markdown: String, - options: MutableDataSet = defaultOptions(), - tabs: Boolean = true, - ui: ApplicationInterface? = null, - ): String { - if (markdown.isBlank()) return "" - val parser = Parser.builder(options).build() - val renderer = HtmlRenderer.builder(options).build() - val document = parser.parse(markdown) - val html = renderer.render(document) - val mermaidRegex = - Regex("]*>(.*?)", RegexOption.DOT_MATCHES_ALL) - val matches = mermaidRegex.findAll(html) - var htmlContent = html - matches.forEach { match -> - var mermaidCode = match.groups[1]!!.value - // HTML Decode mermaidCode - val fixedMermaidCode = fixupMermaidCode(mermaidCode) - var mermaidDiagramHTML = """
$fixedMermaidCode
""" - try { - if (true) { - val svg = renderMermaidToSVG(fixedMermaidCode) - if (null != ui) { - val newTask = ui.newTask(false) - newTask.complete(svg) - mermaidDiagramHTML = newTask.placeholder - } else { - mermaidDiagramHTML = svg - } - } - } catch (e: Exception) { - log.warn("Failed to render Mermaid diagram", e) - } - val replacement = if (tabs) """ + fun renderMarkdown( + markdown: String, + options: MutableDataSet = defaultOptions(), + tabs: Boolean = true, + ui: ApplicationInterface? = null, + ): String { + if (markdown.isBlank()) return "" + val parser = Parser.builder(options).build() + val renderer = HtmlRenderer.builder(options).build() + val document = parser.parse(markdown) + val html = renderer.render(document) + val mermaidRegex = + Regex("]*>(.*?)", RegexOption.DOT_MATCHES_ALL) + val matches = mermaidRegex.findAll(html) + var htmlContent = html + matches.forEach { match -> + var mermaidCode = match.groups[1]!!.value + // HTML Decode mermaidCode + val fixedMermaidCode = fixupMermaidCode(mermaidCode) + var mermaidDiagramHTML = """
$fixedMermaidCode
""" + try { + if (true) { + val svg = renderMermaidToSVG(fixedMermaidCode) + if (null != ui) { + val newTask = ui.newTask(false) + newTask.complete(svg) + mermaidDiagramHTML = newTask.placeholder + } else { + mermaidDiagramHTML = svg + } + } + } catch (e: Exception) { + log.warn("Failed to render Mermaid diagram", e) + } + val replacement = if (tabs) """ |
|
| @@ -57,145 +57,145 @@ object MarkdownUtil { |""".trimMargin() else """ |$mermaidDiagramHTML |""".trimMargin() - htmlContent = htmlContent.replace(match.value, replacement) - } - //language=HTML - return if (tabs) { - displayMapInTabs( - mapOf( - "HTML" to htmlContent, - "Markdown" to """
${
-                        markdown.replace(Regex("<"), "<").replace(Regex(">"), ">")
-                    }
""", - "Hide" to "", - ), ui = ui - ) - } else htmlContent + htmlContent = htmlContent.replace(match.value, replacement) } + //language=HTML + return if (tabs) { + displayMapInTabs( + mapOf( + "HTML" to htmlContent, + "Markdown" to """
${
+            markdown.replace(Regex("<"), "<").replace(Regex(">"), ">")
+          }
""", + "Hide" to "", + ), ui = ui + ) + } else htmlContent + } - var MMDC_CMD: List = listOf("mmdc") - private fun renderMermaidToSVG(mermaidCode: String): String { - // mmdc -i input.mmd -o output.svg - val tempInputFile = Files.createTempFile("mermaid", ".mmd").toFile() - val tempOutputFile = Files.createTempFile("mermaid", ".svg").toFile() - tempInputFile.writeText(StringEscapeUtils.unescapeHtml4(mermaidCode)) - val strings = MMDC_CMD + listOf("-i", tempInputFile.absolutePath, "-o", tempOutputFile.absolutePath) - val processBuilder = - ProcessBuilder(*strings.toTypedArray()) - processBuilder.redirectErrorStream(true) - val process = processBuilder.start() - val output = StringBuilder() - val errorOutput = StringBuilder() - process.inputStream.bufferedReader().use { - it.lines().forEach { line -> output.append(line) } - } - process.errorStream.bufferedReader().use { - it.lines().forEach { line -> errorOutput.append(line) } - } - process.waitFor() - val svgContent = tempOutputFile.readText() - tempInputFile.delete() - tempOutputFile.delete() - if (output.isNotEmpty()) { - log.error("Mermaid CLI Output: $output") - } - if (errorOutput.isNotEmpty()) { - log.error("Mermaid CLI Error: $errorOutput") - } - if (svgContent.isNullOrBlank()) { - throw RuntimeException("Mermaid CLI failed to generate SVG") - } - return svgContent + var MMDC_CMD: List = listOf("mmdc") + private fun renderMermaidToSVG(mermaidCode: String): String { + // mmdc -i input.mmd -o output.svg + val tempInputFile = Files.createTempFile("mermaid", ".mmd").toFile() + val tempOutputFile = Files.createTempFile("mermaid", ".svg").toFile() + tempInputFile.writeText(StringEscapeUtils.unescapeHtml4(mermaidCode)) + val strings = MMDC_CMD + listOf("-i", tempInputFile.absolutePath, "-o", tempOutputFile.absolutePath) + val processBuilder = + ProcessBuilder(*strings.toTypedArray()) + processBuilder.redirectErrorStream(true) + val process = processBuilder.start() + val output = StringBuilder() + val errorOutput = StringBuilder() + process.inputStream.bufferedReader().use { + it.lines().forEach { line -> output.append(line) } } - - // Simplified parsing states - enum class State { - DEFAULT, IN_NODE, IN_EDGE, IN_LABEL, IN_KEYWORD + process.errorStream.bufferedReader().use { + it.lines().forEach { line -> errorOutput.append(line) } + } + process.waitFor() + val svgContent = tempOutputFile.readText() + tempInputFile.delete() + tempOutputFile.delete() + if (output.isNotEmpty()) { + log.error("Mermaid CLI Output: $output") } + if (errorOutput.isNotEmpty()) { + log.error("Mermaid CLI Error: $errorOutput") + } + if (svgContent.isNullOrBlank()) { + throw RuntimeException("Mermaid CLI failed to generate SVG") + } + return svgContent + } - fun fixupMermaidCode(code: String): String { - val stringBuilder = StringBuilder() - var index = 0 + // Simplified parsing states + enum class State { + DEFAULT, IN_NODE, IN_EDGE, IN_LABEL, IN_KEYWORD + } + fun fixupMermaidCode(code: String): String { + val stringBuilder = StringBuilder() + var index = 0 - var currentState = State.DEFAULT - var labelStart = -1 - val keywords = listOf("graph", "subgraph", "end", "classDef", "class", "click", "style") - while (index < code.length) { - when (currentState) { - State.DEFAULT -> { - if (code.startsWith(keywords.find { code.startsWith(it, index) } ?: "", index)) { - // Start of a keyword - currentState = State.IN_KEYWORD - stringBuilder.append(code[index]) - } else - if (code[index] == '[' || code[index] == '(' || code[index] == '{') { - // Possible start of a label - currentState = State.IN_LABEL - labelStart = index - } else if (code[index].isWhitespace() || code[index] == '-') { - // Continue in default state, possibly an edge - stringBuilder.append(code[index]) - } else { - // Start of a node - currentState = State.IN_NODE - stringBuilder.append(code[index]) - } - } + var currentState = State.DEFAULT + var labelStart = -1 + val keywords = listOf("graph", "subgraph", "end", "classDef", "class", "click", "style") - State.IN_KEYWORD -> { - if (code[index].isWhitespace()) { - // End of a keyword - currentState = State.DEFAULT - } - stringBuilder.append(code[index]) - } + while (index < code.length) { + when (currentState) { + State.DEFAULT -> { + if (code.startsWith(keywords.find { code.startsWith(it, index) } ?: "", index)) { + // Start of a keyword + currentState = State.IN_KEYWORD + stringBuilder.append(code[index]) + } else + if (code[index] == '[' || code[index] == '(' || code[index] == '{') { + // Possible start of a label + currentState = State.IN_LABEL + labelStart = index + } else if (code[index].isWhitespace() || code[index] == '-') { + // Continue in default state, possibly an edge + stringBuilder.append(code[index]) + } else { + // Start of a node + currentState = State.IN_NODE + stringBuilder.append(code[index]) + } + } - State.IN_NODE -> { - if (code[index] == '-' || code[index] == '>' || code[index].isWhitespace()) { - // End of a node, start of an edge or space - currentState = if (code[index].isWhitespace()) State.DEFAULT else State.IN_EDGE - stringBuilder.append(code[index]) - } else { - // Continue in node - stringBuilder.append(code[index]) - } - } + State.IN_KEYWORD -> { + if (code[index].isWhitespace()) { + // End of a keyword + currentState = State.DEFAULT + } + stringBuilder.append(code[index]) + } - State.IN_EDGE -> { - if (!code[index].isWhitespace() && code[index] != '-' && code[index] != '>') { - // End of an edge, start of a node - currentState = State.IN_NODE - stringBuilder.append(code[index]) - } else { - // Continue in edge - stringBuilder.append(code[index]) - } - } + State.IN_NODE -> { + if (code[index] == '-' || code[index] == '>' || code[index].isWhitespace()) { + // End of a node, start of an edge or space + currentState = if (code[index].isWhitespace()) State.DEFAULT else State.IN_EDGE + stringBuilder.append(code[index]) + } else { + // Continue in node + stringBuilder.append(code[index]) + } + } - State.IN_LABEL -> { - if (code[index] == ']' || code[index] == ')' || code[index] == '}') { - // End of a label - val label = code.substring(labelStart + 1, index) - val escapedLabel = "\"${label.replace("\"", "'")}\"" - stringBuilder.append(escapedLabel) - stringBuilder.append(code[index]) - currentState = State.DEFAULT - } - } - } - index++ + State.IN_EDGE -> { + if (!code[index].isWhitespace() && code[index] != '-' && code[index] != '>') { + // End of an edge, start of a node + currentState = State.IN_NODE + stringBuilder.append(code[index]) + } else { + // Continue in edge + stringBuilder.append(code[index]) + } } - return stringBuilder.toString() + State.IN_LABEL -> { + if (code[index] == ']' || code[index] == ')' || code[index] == '}') { + // End of a label + val label = code.substring(labelStart + 1, index) + val escapedLabel = "\"${label.replace("\"", "'")}\"" + stringBuilder.append(escapedLabel) + stringBuilder.append(code[index]) + currentState = State.DEFAULT + } + } + } + index++ } - private fun defaultOptions(): MutableDataSet { - val options = MutableDataSet() - options.set(Parser.EXTENSIONS, listOf(TablesExtension.create())) - return options - } + return stringBuilder.toString() + } + + private fun defaultOptions(): MutableDataSet { + val options = MutableDataSet() + options.set(Parser.EXTENSIONS, listOf(TablesExtension.create())) + return options + } - private val log = org.slf4j.LoggerFactory.getLogger(MarkdownUtil::class.java) + private val log = org.slf4j.LoggerFactory.getLogger(MarkdownUtil::class.java) } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/util/OpenAPI.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/util/OpenAPI.kt index e6d2759b..368e6c7d 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/util/OpenAPI.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/util/OpenAPI.kt @@ -4,100 +4,100 @@ import com.fasterxml.jackson.annotation.JsonInclude // OpenAPI root document data class OpenAPI( - val openapi: String = "3.0.0", - val info: Info? = null, - val paths: Map? = emptyMap(), - @JsonInclude(JsonInclude.Include.NON_NULL) - val components: Components? = null + val openapi: String = "3.0.0", + val info: Info? = null, + val paths: Map? = emptyMap(), + @JsonInclude(JsonInclude.Include.NON_NULL) + val components: Components? = null ) // Metadata about the API data class Info( - val title: String? = null, - val version: String? = null, - val description: String? = null, - val termsOfService: String? = null, - val contact: Contact? = null, - val license: License? = null + val title: String? = null, + val version: String? = null, + val description: String? = null, + val termsOfService: String? = null, + val contact: Contact? = null, + val license: License? = null ) // Contact information data class Contact( - val name: String? = null, - val url: String? = null, - val email: String? = null + val name: String? = null, + val url: String? = null, + val email: String? = null ) // License information data class License( - val name: String? = null, - val url: String? = null + val name: String? = null, + val url: String? = null ) // Paths and operations data class PathItem( - val get: Operation? = null, - val put: Operation? = null, - val post: Operation? = null, - val delete: Operation? = null, - val options: Operation? = null, - val head: Operation? = null, - val patch: Operation? = null + val get: Operation? = null, + val put: Operation? = null, + val post: Operation? = null, + val delete: Operation? = null, + val options: Operation? = null, + val head: Operation? = null, + val patch: Operation? = null ) // An API operation data class Operation( - val summary: String? = null, - val description: String? = null, - val responses: Map? = emptyMap(), - val parameters: List? = emptyList(), - val operationId: String? = null, - val requestBody: RequestBody? = null, - val security: List>>? = emptyList(), - val tags: List? = emptyList(), - val callbacks: Map? = emptyMap(), - val deprecated: Boolean? = null, + val summary: String? = null, + val description: String? = null, + val responses: Map? = emptyMap(), + val parameters: List? = emptyList(), + val operationId: String? = null, + val requestBody: RequestBody? = null, + val security: List>>? = emptyList(), + val tags: List? = emptyList(), + val callbacks: Map? = emptyMap(), + val deprecated: Boolean? = null, ) // Operation response data class Response( - val description: String? = null, - @JsonInclude(JsonInclude.Include.NON_NULL) - val content: Map? = emptyMap() + val description: String? = null, + @JsonInclude(JsonInclude.Include.NON_NULL) + val content: Map? = emptyMap() ) // Components for reusable objects data class Components( - val schemas: Map? = emptyMap(), - val responses: Map? = emptyMap(), - val parameters: Map? = emptyMap(), - val examples: Map? = emptyMap(), - val requestBodies: Map? = emptyMap(), - val headers: Map? = emptyMap(), - val securitySchemes: Map? = emptyMap(), - val links: Map? = emptyMap(), - val callbacks: Map? = emptyMap() + val schemas: Map? = emptyMap(), + val responses: Map? = emptyMap(), + val parameters: Map? = emptyMap(), + val examples: Map? = emptyMap(), + val requestBodies: Map? = emptyMap(), + val headers: Map? = emptyMap(), + val securitySchemes: Map? = emptyMap(), + val links: Map? = emptyMap(), + val callbacks: Map? = emptyMap() ) // Simplified examples of component objects data class Schema( - val type: String? = null, - val properties: Map? = emptyMap(), - val items: Schema? = null, - val `$ref`: String? = null, - val format: String? = null, - val description: String? = null, + val type: String? = null, + val properties: Map? = emptyMap(), + val items: Schema? = null, + val `$ref`: String? = null, + val format: String? = null, + val description: String? = null, - ) + ) data class Parameter( - val name: String? = null, - val `in`: String? = null, - val description: String? = null, - val required: Boolean? = null, - val schema: Schema? = null, - val content: Map? = null, - val example: Any? = null, + val name: String? = null, + val `in`: String? = null, + val description: String? = null, + val required: Boolean? = null, + val schema: Schema? = null, + val content: Map? = null, + val example: Any? = null, ) data class Example(val summary: String? = null, val description: String? = null) diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/util/Selenium2S3.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/util/Selenium2S3.kt index d8d576f1..7fb30739 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/util/Selenium2S3.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/util/Selenium2S3.kt @@ -28,436 +28,436 @@ import java.util.concurrent.Semaphore import java.util.concurrent.ThreadPoolExecutor open class Selenium2S3( - val pool: ThreadPoolExecutor = Executors.newCachedThreadPool() as ThreadPoolExecutor, - private val cookies: Array?, + val pool: ThreadPoolExecutor = Executors.newCachedThreadPool() as ThreadPoolExecutor, + private val cookies: Array?, ) : Selenium { - var loadImages: Boolean = false - open val driver: WebDriver by lazy { - chromeDriver(loadImages = loadImages).apply { - setCookies( - this, - cookies - ) - } + var loadImages: Boolean = false + open val driver: WebDriver by lazy { + chromeDriver(loadImages = loadImages).apply { + setCookies( + this, + cookies + ) } - - private val httpClient by lazy { - HttpAsyncClientBuilder.create() - .useSystemProperties() - .setDefaultCookieStore(BasicCookieStore().apply { - cookies?.forEach { cookie -> addCookie(BasicClientCookie(cookie.name, cookie.value)) } - }) - .setThreadFactory(pool.threadFactory) - .build() - .also { it.start() } + } + + private val httpClient by lazy { + HttpAsyncClientBuilder.create() + .useSystemProperties() + .setDefaultCookieStore(BasicCookieStore().apply { + cookies?.forEach { cookie -> addCookie(BasicClientCookie(cookie.name, cookie.value)) } + }) + .setThreadFactory(pool.threadFactory) + .build() + .also { it.start() } + } + + private val linkReplacements = mutableMapOf() + private val htmlPages: MutableMap = mutableMapOf() + private val jsonPages = mutableMapOf() + private val links: MutableList = mutableListOf() + + override fun save( + url: URL, + currentFilename: String?, + saveRoot: String + ) { + log.info("Saving URL: $url") + log.info("Current filename: $currentFilename") + log.info("Save root: $saveRoot") + driver.navigate().to(url) + driver.navigate().refresh() + Thread.sleep(5000) // Wait for javascript to load + + htmlPages += mutableMapOf((currentFilename ?: url.file.split("/").last()) to editPage(driver.pageSource)) + val baseUrl = url.toString().split("#").first() + links += toAbsolute(baseUrl, *currentPageLinks(driver).map { link -> + val relative = toRelative(baseUrl, link) ?: return@map link + linkReplacements[link] = "${cloud!!.shareBase}/$saveRoot/${toArchivePath(relative)}" + linkReplacements[relative] = "${cloud!!.shareBase}/$saveRoot/${toArchivePath(relative)}" + link + }.toTypedArray()).toMutableList() + val completionSemaphores = mutableListOf() + + log.info("Fetching page source") + log.info("Base URL: $baseUrl") + val coveredLinks = mutableSetOf() + log.info("Processing links") + while (links.isNotEmpty()) { + val href = links.removeFirst() + try { + if (coveredLinks.contains(href)) continue + coveredLinks += href + log.debug("Processing $href") + process(url, href, completionSemaphores, saveRoot) + } catch (e: Exception) { + log.warn("Error processing $href", e) + } } - private val linkReplacements = mutableMapOf() - private val htmlPages: MutableMap = mutableMapOf() - private val jsonPages = mutableMapOf() - private val links: MutableList = mutableListOf() - - override fun save( - url: URL, - currentFilename: String?, - saveRoot: String - ) { - log.info("Saving URL: $url") - log.info("Current filename: $currentFilename") - log.info("Save root: $saveRoot") - driver.navigate().to(url) - driver.navigate().refresh() - Thread.sleep(5000) // Wait for javascript to load - - htmlPages += mutableMapOf((currentFilename ?: url.file.split("/").last()) to editPage(driver.pageSource)) - val baseUrl = url.toString().split("#").first() - links += toAbsolute(baseUrl, *currentPageLinks(driver).map { link -> - val relative = toRelative(baseUrl, link) ?: return@map link - linkReplacements[link] = "${cloud!!.shareBase}/$saveRoot/${toArchivePath(relative)}" - linkReplacements[relative] = "${cloud!!.shareBase}/$saveRoot/${toArchivePath(relative)}" - link - }.toTypedArray()).toMutableList() - val completionSemaphores = mutableListOf() - - log.info("Fetching page source") - log.info("Base URL: $baseUrl") - val coveredLinks = mutableSetOf() - log.info("Processing links") - while (links.isNotEmpty()) { - val href = links.removeFirst() - try { - if (coveredLinks.contains(href)) continue - coveredLinks += href - log.debug("Processing $href") - process(url, href, completionSemaphores, saveRoot) - } catch (e: Exception) { - log.warn("Error processing $href", e) - } - } - - log.info("Fetching current page links") - log.debug("Waiting for completion") - completionSemaphores.forEach { it.acquire(); it.release() } - - log.debug("Saving") - saveAll(saveRoot) - log.debug("Done") + log.info("Fetching current page links") + log.debug("Waiting for completion") + completionSemaphores.forEach { it.acquire(); it.release() } + + log.debug("Saving") + saveAll(saveRoot) + log.debug("Done") + } + + protected open fun process( + url: URL, + href: String, + completionSemaphores: MutableList, + saveRoot: String + ): Boolean { + val base = url.toString().split("/").dropLast(1).joinToString("/") + val relative = toArchivePath(toRelative(base, href) ?: return true) + when (val mimeType = mimeType(relative)) { + + "text/html" -> { + if (htmlPages.containsKey(relative)) return true + log.info("Fetching $href") + val semaphore = Semaphore(0) + completionSemaphores += semaphore + getHtml(href, htmlPages, relative, links, saveRoot, semaphore) + } + + "application/json" -> { + if (jsonPages.containsKey(relative)) return true + log.info("Fetching $href") + val semaphore = Semaphore(0) + completionSemaphores += semaphore + getJson(href, jsonPages, relative, semaphore) + } + + else -> { + val semaphore = Semaphore(0) + completionSemaphores += semaphore + getMedia(href, mimeType, saveRoot, relative, semaphore) + } } - - protected open fun process( - url: URL, - href: String, - completionSemaphores: MutableList, - saveRoot: String - ): Boolean { - val base = url.toString().split("/").dropLast(1).joinToString("/") - val relative = toArchivePath(toRelative(base, href) ?: return true) - when (val mimeType = mimeType(relative)) { - - "text/html" -> { - if (htmlPages.containsKey(relative)) return true - log.info("Fetching $href") - val semaphore = Semaphore(0) - completionSemaphores += semaphore - getHtml(href, htmlPages, relative, links, saveRoot, semaphore) - } - - "application/json" -> { - if (jsonPages.containsKey(relative)) return true - log.info("Fetching $href") - val semaphore = Semaphore(0) - completionSemaphores += semaphore - getJson(href, jsonPages, relative, semaphore) - } - - else -> { - val semaphore = Semaphore(0) - completionSemaphores += semaphore - getMedia(href, mimeType, saveRoot, relative, semaphore) - } + return false + } + + protected open fun getHtml( + href: String, + htmlPages: MutableMap, + relative: String, + links: MutableList, + saveRoot: String, + semaphore: Semaphore + ) { + httpClient.execute(get(href), object : FutureCallback { + + override fun completed(p0: SimpleHttpResponse?) { + log.debug("Fetched $href") + val html = p0?.body?.bodyText ?: "" + htmlPages[relative] = html + links += toAbsolute(href, *currentPageLinks(html).map { link -> + val relative = toArchivePath(toRelative(href, link) ?: return@map link) + linkReplacements[link] = "${cloud!!.shareBase}/$saveRoot/$relative" + link + }.toTypedArray()) + semaphore.release() + } + + override fun failed(p0: java.lang.Exception?) { + log.info("Error fetching $href", p0) + semaphore.release() + } + + override fun cancelled() { + log.info("Cancelled fetching $href") + semaphore.release() + } + + }) + } + + protected open fun getJson( + href: String, + jsonPages: MutableMap, + relative: String, + semaphore: Semaphore + ) { + httpClient.execute(get(href), object : FutureCallback { + + override fun completed(p0: SimpleHttpResponse?) { + log.debug("Fetched $href") + jsonPages[relative] = p0?.body?.bodyText ?: "" + semaphore.release() + } + + override fun failed(p0: java.lang.Exception?) { + log.info("Error fetching $href", p0) + semaphore.release() + } + + override fun cancelled() { + log.info("Cancelled fetching $href") + semaphore.release() + } + + }) + } + + protected open fun getMedia( + href: String, + mimeType: String, + saveRoot: String, + relative: String, + semaphore: Semaphore + ) { + val request = get(href) + httpClient.execute(request, object : FutureCallback { + + override fun completed(p0: SimpleHttpResponse?) { + try { + log.debug("Fetched $request") + val bytes = p0?.body?.bodyBytes ?: return + if (validate(mimeType, p0.body.contentType.mimeType, bytes)) + cloud!!.upload( + path = "/$saveRoot/$relative", + contentType = mimeType, + bytes = bytes + ) + } finally { + semaphore.release() } - return false - } - - protected open fun getHtml( - href: String, - htmlPages: MutableMap, - relative: String, - links: MutableList, - saveRoot: String, - semaphore: Semaphore - ) { - httpClient.execute(get(href), object : FutureCallback { - - override fun completed(p0: SimpleHttpResponse?) { - log.debug("Fetched $href") - val html = p0?.body?.bodyText ?: "" - htmlPages[relative] = html - links += toAbsolute(href, *currentPageLinks(html).map { link -> - val relative = toArchivePath(toRelative(href, link) ?: return@map link) - linkReplacements[link] = "${cloud!!.shareBase}/$saveRoot/$relative" - link - }.toTypedArray()) - semaphore.release() - } - - override fun failed(p0: java.lang.Exception?) { - log.info("Error fetching $href", p0) - semaphore.release() - } - - override fun cancelled() { - log.info("Cancelled fetching $href") - semaphore.release() - } - - }) - } - - protected open fun getJson( - href: String, - jsonPages: MutableMap, - relative: String, - semaphore: Semaphore - ) { - httpClient.execute(get(href), object : FutureCallback { - - override fun completed(p0: SimpleHttpResponse?) { - log.debug("Fetched $href") - jsonPages[relative] = p0?.body?.bodyText ?: "" - semaphore.release() - } - - override fun failed(p0: java.lang.Exception?) { - log.info("Error fetching $href", p0) - semaphore.release() - } - - override fun cancelled() { - log.info("Cancelled fetching $href") - semaphore.release() - } - - }) - } - - protected open fun getMedia( - href: String, - mimeType: String, - saveRoot: String, - relative: String, - semaphore: Semaphore - ) { - val request = get(href) - httpClient.execute(request, object : FutureCallback { - - override fun completed(p0: SimpleHttpResponse?) { - try { - log.debug("Fetched $request") - val bytes = p0?.body?.bodyBytes ?: return - if (validate(mimeType, p0.body.contentType.mimeType, bytes)) - cloud!!.upload( - path = "/$saveRoot/$relative", - contentType = mimeType, - bytes = bytes - ) - } finally { - semaphore.release() - } - } - - override fun failed(p0: java.lang.Exception?) { - log.info("Error fetching $href", p0) - semaphore.release() - } - - override fun cancelled() { - log.info("Cancelled fetching $href") - semaphore.release() - } - - }) - } - - private fun saveAll( - saveRoot: String - ) { - (htmlPages.map { (filename, html) -> - pool.submit { - try { - saveHTML(html, saveRoot, filename) - } catch (e: Exception) { - log.warn("Error processing $filename", e) - } - } - } + jsonPages.map { (filename, js) -> - pool.submit { - try { - saveJS(js, saveRoot, filename) - } catch (e: Exception) { - log.warn("Error processing $filename", e) - } - } - }).forEach { - try { - it.get() - } catch (e: Exception) { - log.warn("Error processing", e) - } + } + + override fun failed(p0: java.lang.Exception?) { + log.info("Error fetching $href", p0) + semaphore.release() + } + + override fun cancelled() { + log.info("Cancelled fetching $href") + semaphore.release() + } + + }) + } + + private fun saveAll( + saveRoot: String + ) { + (htmlPages.map { (filename, html) -> + pool.submit { + try { + saveHTML(html, saveRoot, filename) + } catch (e: Exception) { + log.warn("Error processing $filename", e) } - } - - protected open fun saveJS(js: String, saveRoot: String, filename: String) { - val finalJs = linkReplacements.toList().sortedBy { it.first.length } - .fold(js) { acc, (href, relative) -> //language=RegExp - acc.replace("""(? acc.replace("""(? - request.addHeader("Cookie", "${cookie.name}=${cookie.value}") + } + } + jsonPages.map { (filename, js) -> + pool.submit { + try { + saveJS(js, saveRoot, filename) + } catch (e: Exception) { + log.warn("Error processing $filename", e) } - return request + } + }).forEach { + try { + it.get() + } catch (e: Exception) { + log.warn("Error processing", e) + } } - - protected open fun currentPageLinks(driver: WebDriver) = listOf( - driver.findElements(By.xpath("//a[@href]")).map { it?.getAttribute("href") }.toSet(), - driver.findElements(By.xpath("//img[@src]")).map { it?.getAttribute("src") }.toSet(), - driver.findElements(By.xpath("//link[@href]")).map { it?.getAttribute("href") }.toSet(), - driver.findElements(By.xpath("//script[@src]")).map { it?.getAttribute("src") }.toSet(), - driver.findElements(By.xpath("//source[@src]")).map { it?.getAttribute("src") }.toSet(), - ).flatten().filterNotNull() - - private fun currentPageLinks(html: String) = listOf( - Jsoup.parse(html).select("a[href]").map { it.attr("href") }.toSet(), - Jsoup.parse(html).select("img[src]").map { it.attr("src") }.toSet(), - Jsoup.parse(html).select("link[href]").map { it.attr("href") }.toSet(), - Jsoup.parse(html).select("script[src]").map { it.attr("src") }.toSet(), - Jsoup.parse(html).select("source[src]").map { it.attr("src") }.toSet(), - ).flatten().filterNotNull() - - protected open fun toAbsolute(base: String, vararg links: String) = links - .map { it.split("#").first() }.filter { it.isNotBlank() }.distinct() - .map { link -> - val newLink = when { - link.startsWith("http") -> link - else -> URI.create(base).resolve(link).toString() - } - newLink - } - - protected open fun toRelative(base: String, link: String): String? = when { - link.startsWith(base) -> toRelative( - base, - link.removePrefix(base).replace("/{2,}".toRegex(), "/").removePrefix("/") - ) // relativize - link.startsWith("http") -> null // absolute - else -> link // relative + } + + protected open fun saveJS(js: String, saveRoot: String, filename: String) { + val finalJs = linkReplacements.toList().sortedBy { it.first.length } + .fold(js) { acc, (href, relative) -> //language=RegExp + acc.replace("""(? acc.replace("""(? + request.addHeader("Cookie", "${cookie.name}=${cookie.value}") } - - protected open fun toArchivePath(link: String): String = when { - link.startsWith("fileIndex") -> link.split("/").drop(2).joinToString("/") // rm file segment - else -> link + return request + } + + protected open fun currentPageLinks(driver: WebDriver) = listOf( + driver.findElements(By.xpath("//a[@href]")).map { it?.getAttribute("href") }.toSet(), + driver.findElements(By.xpath("//img[@src]")).map { it?.getAttribute("src") }.toSet(), + driver.findElements(By.xpath("//link[@href]")).map { it?.getAttribute("href") }.toSet(), + driver.findElements(By.xpath("//script[@src]")).map { it?.getAttribute("src") }.toSet(), + driver.findElements(By.xpath("//source[@src]")).map { it?.getAttribute("src") }.toSet(), + ).flatten().filterNotNull() + + private fun currentPageLinks(html: String) = listOf( + Jsoup.parse(html).select("a[href]").map { it.attr("href") }.toSet(), + Jsoup.parse(html).select("img[src]").map { it.attr("src") }.toSet(), + Jsoup.parse(html).select("link[href]").map { it.attr("href") }.toSet(), + Jsoup.parse(html).select("script[src]").map { it.attr("src") }.toSet(), + Jsoup.parse(html).select("source[src]").map { it.attr("src") }.toSet(), + ).flatten().filterNotNull() + + protected open fun toAbsolute(base: String, vararg links: String) = links + .map { it.split("#").first() }.filter { it.isNotBlank() }.distinct() + .map { link -> + val newLink = when { + link.startsWith("http") -> link + else -> URI.create(base).resolve(link).toString() + } + newLink } - protected open fun validate( - expected: String, - actual: String, - bytes: ByteArray - ): Boolean { - if (!actual.startsWith(expected)) { - log.warn("Content type mismatch: $actual != $expected") - if (actual.startsWith("text/html")) { - log.warn("Response Error: ${String(bytes)}", Exception()) - } - return false - } - return true + protected open fun toRelative(base: String, link: String): String? = when { + link.startsWith(base) -> toRelative( + base, + link.removePrefix(base).replace("/{2,}".toRegex(), "/").removePrefix("/") + ) // relativize + link.startsWith("http") -> null // absolute + else -> link // relative + } + + protected open fun toArchivePath(link: String): String = when { + link.startsWith("fileIndex") -> link.split("/").drop(2).joinToString("/") // rm file segment + else -> link + } + + protected open fun validate( + expected: String, + actual: String, + bytes: ByteArray + ): Boolean { + if (!actual.startsWith(expected)) { + log.warn("Content type mismatch: $actual != $expected") + if (actual.startsWith("text/html")) { + log.warn("Response Error: ${String(bytes)}", Exception()) + } + return false } - - protected open fun mimeType(relative: String): String { - val extension = relative.split(".").last().split("?").first() - val contentType = when (extension) { - "css" -> "text/css" - "js" -> "text/javascript" - "json" -> "application/json" - "pdf" -> "application/pdf" - "zip" -> "application/zip" - "tar" -> "application/x-tar" - "gz" -> "application/gzip" - "bz2" -> "application/bzip2" - "mp3" -> "audio/mpeg" - //"tsv" -> "text/tab-separated-values" - "csv" -> "text/csv" - "txt" -> "text/plain" - "xml" -> "text/xml" - "svg" -> "image/svg+xml" - "png" -> "image/png" - "jpg" -> "image/jpeg" - "jpeg" -> "image/jpeg" - "gif" -> "image/gif" - "ico" -> "image/x-icon" - "html" -> "text/html" - "htm" -> "text/html" - else -> "text/plain" - } - return contentType + return true + } + + protected open fun mimeType(relative: String): String { + val extension = relative.split(".").last().split("?").first() + val contentType = when (extension) { + "css" -> "text/css" + "js" -> "text/javascript" + "json" -> "application/json" + "pdf" -> "application/pdf" + "zip" -> "application/zip" + "tar" -> "application/x-tar" + "gz" -> "application/gzip" + "bz2" -> "application/bzip2" + "mp3" -> "audio/mpeg" + //"tsv" -> "text/tab-separated-values" + "csv" -> "text/csv" + "txt" -> "text/plain" + "xml" -> "text/xml" + "svg" -> "image/svg+xml" + "png" -> "image/png" + "jpg" -> "image/jpeg" + "jpeg" -> "image/jpeg" + "gif" -> "image/gif" + "ico" -> "image/x-icon" + "html" -> "text/html" + "htm" -> "text/html" + else -> "text/plain" } - - protected open fun editPage(html: String): String { - val doc = Jsoup.parse(html) - doc.select("#toolbar").remove() - doc.select("#namebar").remove() - doc.select("#main-input").remove() - doc.select("#footer").remove() - return doc.toString() + return contentType + } + + protected open fun editPage(html: String): String { + val doc = Jsoup.parse(html) + doc.select("#toolbar").remove() + doc.select("#namebar").remove() + doc.select("#main-input").remove() + doc.select("#footer").remove() + return doc.toString() + } + + override fun close() { + log.debug("Closing", Exception()) + driver.quit() + httpClient.close() + //driver.close() + //Companion.chromeDriverService.close() + } + + + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(Selenium2S3::class.java) + + init { + Runtime.getRuntime().addShutdownHook(Thread { + try { + } catch (e: Exception) { + log.warn("Error closing com.simiacryptus.skyenet.webui.util.Selenium2S3", e) + } + }) } - override fun close() { - log.debug("Closing", Exception()) - driver.quit() - httpClient.close() - //driver.close() - //Companion.chromeDriverService.close() + fun chromeDriver(headless: Boolean = true, loadImages: Boolean = !headless): ChromeDriver { + val osname = System.getProperty("os.name") + val chromePath = when { + // Windows + osname.contains("Windows") -> listOf( + "C:\\Program Files\\Google\\Chrome\\Application\\chromedriver.exe", + "C:\\Program Files (x86)\\Google\\Chrome\\Application\\chromedriver.exe" + ) + // Ubuntu + osname.contains("Linux") -> listOf("/usr/bin/chromedriver") + else -> throw RuntimeException("Not implemented for $osname") + } + System.setProperty("webdriver.chrome.driver", + chromePath.find { File(it).exists() } ?: throw RuntimeException("Chrome not found")) + val options = ChromeOptions() + val args = mutableListOf() + if (headless) args += "--headless" + if (loadImages) args += "--blink-settings=imagesEnabled=false" + options.addArguments(*args.toTypedArray()) + options.setPageLoadTimeout(Duration.of(90, ChronoUnit.SECONDS)) + return ChromeDriver(chromeDriverService, options) } - - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(Selenium2S3::class.java) - - init { - Runtime.getRuntime().addShutdownHook(Thread { - try { - } catch (e: Exception) { - log.warn("Error closing com.simiacryptus.skyenet.webui.util.Selenium2S3", e) - } - }) - } - - fun chromeDriver(headless: Boolean = true, loadImages: Boolean = !headless): ChromeDriver { - val osname = System.getProperty("os.name") - val chromePath = when { - // Windows - osname.contains("Windows") -> listOf( - "C:\\Program Files\\Google\\Chrome\\Application\\chromedriver.exe", - "C:\\Program Files (x86)\\Google\\Chrome\\Application\\chromedriver.exe" - ) - // Ubuntu - osname.contains("Linux") -> listOf("/usr/bin/chromedriver") - else -> throw RuntimeException("Not implemented for $osname") - } - System.setProperty("webdriver.chrome.driver", - chromePath.find { File(it).exists() } ?: throw RuntimeException("Chrome not found")) - val options = ChromeOptions() - val args = mutableListOf() - if (headless) args += "--headless" - if (loadImages) args += "--blink-settings=imagesEnabled=false" - options.addArguments(*args.toTypedArray()) - options.setPageLoadTimeout(Duration.of(90, ChronoUnit.SECONDS)) - return ChromeDriver(chromeDriverService, options) - } - - private val chromeDriverService by lazy { ChromeDriverService.createDefaultService() } - fun setCookies( - driver: WebDriver, - cookies: Array?, - domain: String? = null - ) { - cookies?.forEach { cookie -> - try { - driver.manage().addCookie( - Cookie( - /* name = */ cookie.name, - /* value = */ cookie.value, - /* domain = */ cookie.domain ?: domain, - /* path = */ cookie.path, - /* expiry = */ if (cookie.maxAge <= 0) null else Date(cookie.maxAge * 1000L), - /* isSecure = */ cookie.secure, - /* isHttpOnly = */ cookie.isHttpOnly - ) - ) - } catch (e: Exception) { - log.warn("Error setting cookie: $cookie", e) - } - } + private val chromeDriverService by lazy { ChromeDriverService.createDefaultService() } + fun setCookies( + driver: WebDriver, + cookies: Array?, + domain: String? = null + ) { + cookies?.forEach { cookie -> + try { + driver.manage().addCookie( + Cookie( + /* name = */ cookie.name, + /* value = */ cookie.value, + /* domain = */ cookie.domain ?: domain, + /* path = */ cookie.path, + /* expiry = */ if (cookie.maxAge <= 0) null else Date(cookie.maxAge * 1000L), + /* isSecure = */ cookie.secure, + /* isHttpOnly = */ cookie.isHttpOnly + ) + ) + } catch (e: Exception) { + log.warn("Error setting cookie: $cookie", e) } + } } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/util/TensorflowProjector.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/util/TensorflowProjector.kt index aedb1c88..9ce225ec 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/util/TensorflowProjector.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/util/TensorflowProjector.kt @@ -16,90 +16,93 @@ import java.io.IOException import kotlin.jvm.Throws class TensorflowProjector( - val api: API, - val dataStorage: StorageInterface, - val sessionID: Session, - val session: ApplicationInterface, - val userId: User?, - private val iframeHeight: Int = 500, - private val iframeWidth: String = "100%" + val api: API, + val dataStorage: StorageInterface, + val sessionID: Session, + val session: ApplicationInterface, + val userId: User?, + private val iframeHeight: Int = 500, + private val iframeWidth: String = "100%" ) { - companion object { - private const val VECTOR_FILENAME = "vectors.tsv" - private const val METADATA_FILENAME = "metadata.tsv" - private const val CONFIG_FILENAME = "projector-config.json" - private const val PROJECTOR_URL = "https://projector.tensorflow.org/" - } - @Throws(IOException::class) + companion object { + private const val VECTOR_FILENAME = "vectors.tsv" + private const val METADATA_FILENAME = "metadata.tsv" + private const val CONFIG_FILENAME = "projector-config.json" + private const val PROJECTOR_URL = "https://projector.tensorflow.org/" + } - private fun toVectorMap(vararg words: String): Map { - val vectors = words.map { word -> - word to (api as OpenAIClient).createEmbedding( - EmbeddingRequest( - model = EmbeddingModels.AdaEmbedding.modelName, - input = word.trim(), - ) - ).data.first().embedding!! - } - return vectors.toMap() - } - @Throws(IOException::class) + @Throws(IOException::class) - fun writeTensorflowEmbeddingProjectorHtmlFromRecords(records: List): String { - val vectorMap = records - .filter { it.text != null && it.vector != null } - .associate { record -> - record.text!!.trim() to record.vector!! - } - require(vectorMap.isNotEmpty()) { "No valid records found with both text and vector" } - return writeTensorflowEmbeddingProjectorHtmlFromVectorMap(vectorMap) + private fun toVectorMap(vararg words: String): Map { + val vectors = words.map { word -> + word to (api as OpenAIClient).createEmbedding( + EmbeddingRequest( + model = EmbeddingModels.AdaEmbedding.modelName, + input = word.trim(), + ) + ).data.first().embedding!! } - @Throws(IOException::class) + return vectors.toMap() + } - fun writeTensorflowEmbeddingProjectorHtml(vararg words: String): String { - val filteredWords = words.filter { it.isNotBlank() }.distinct() - require(filteredWords.isNotEmpty()) { "No valid words provided" } - val vectorMap = toVectorMap(*filteredWords.toTypedArray()) - return writeTensorflowEmbeddingProjectorHtmlFromVectorMap(vectorMap) - } + @Throws(IOException::class) + + fun writeTensorflowEmbeddingProjectorHtmlFromRecords(records: List): String { + val vectorMap = records + .filter { it.text != null && it.vector != null } + .associate { record -> + record.text!!.trim() to record.vector!! + } + require(vectorMap.isNotEmpty()) { "No valid records found with both text and vector" } + return writeTensorflowEmbeddingProjectorHtmlFromVectorMap(vectorMap) + } - private fun writeTensorflowEmbeddingProjectorHtmlFromVectorMap(vectorMap: Map): String { - require(vectorMap.isNotEmpty()) { "Vector map cannot be empty" } - - val vectorTsv = vectorMap.map { (_, vector) -> - vector.joinToString(separator = "\t") { - "%.2E".format(it) - } - }.joinToString(separator = "\n") + @Throws(IOException::class) - val metadataTsv = vectorMap.keys.joinToString(separator = "\n") { - it.replace(Regex("\\s+"), " ").trim() - } + fun writeTensorflowEmbeddingProjectorHtml(vararg words: String): String { + val filteredWords = words.filter { it.isNotBlank() }.distinct() + require(filteredWords.isNotEmpty()) { "No valid words provided" } + val vectorMap = toVectorMap(*filteredWords.toTypedArray()) + return writeTensorflowEmbeddingProjectorHtmlFromVectorMap(vectorMap) + } - val uuid = UUID.randomUUID().toString() - val sessionDir = dataStorage.getSessionDir(userId, sessionID) - sessionDir.resolve(VECTOR_FILENAME).writeText(vectorTsv) - sessionDir.resolve(METADATA_FILENAME).writeText(metadataTsv) - val vectorURL = cloud?.upload("projector/$sessionID/$uuid/$VECTOR_FILENAME", "text/plain", vectorTsv) - ?: throw IllegalStateException("Cloud storage not initialized") - val metadataURL = cloud?.upload("projector/$sessionID/$uuid/$METADATA_FILENAME", "text/plain", metadataTsv) + private fun writeTensorflowEmbeddingProjectorHtmlFromVectorMap(vectorMap: Map): String { + require(vectorMap.isNotEmpty()) { "Vector map cannot be empty" } - val projectorConfig = JsonUtil.toJson( - mapOf( - "embeddings" to listOf( - mapOf( - "tensorName" to "embedding", - "tensorShape" to listOf(vectorMap.size, vectorMap.values.first().size), - "tensorPath" to vectorURL, - "metadataPath" to metadataURL, - ) - ) - ) + val vectorTsv = vectorMap.map { (_, vector) -> + vector.joinToString(separator = "\t") { + "%.2E".format(it) + } + }.joinToString(separator = "\n") + + val metadataTsv = vectorMap.keys.joinToString(separator = "\n") { + it.replace(Regex("\\s+"), " ").trim() + } + + val uuid = UUID.randomUUID().toString() + val sessionDir = dataStorage.getSessionDir(userId, sessionID) + sessionDir.resolve(VECTOR_FILENAME).writeText(vectorTsv) + sessionDir.resolve(METADATA_FILENAME).writeText(metadataTsv) + val vectorURL = cloud?.upload("projector/$sessionID/$uuid/$VECTOR_FILENAME", "text/plain", vectorTsv) + ?: throw IllegalStateException("Cloud storage not initialized") + val metadataURL = cloud?.upload("projector/$sessionID/$uuid/$METADATA_FILENAME", "text/plain", metadataTsv) + + val projectorConfig = JsonUtil.toJson( + mapOf( + "embeddings" to listOf( + mapOf( + "tensorName" to "embedding", + "tensorShape" to listOf(vectorMap.size, vectorMap.values.first().size), + "tensorPath" to vectorURL, + "metadataPath" to metadataURL, + ) ) - sessionDir.resolve(CONFIG_FILENAME).writeText(projectorConfig) - val configURL = cloud?.upload("projector/$sessionID/$CONFIG_FILENAME", "application/json", projectorConfig) + ) + ) + sessionDir.resolve(CONFIG_FILENAME).writeText(projectorConfig) + val configURL = cloud?.upload("projector/$sessionID/$CONFIG_FILENAME", "application/json", projectorConfig) - return """ + return """
""".trimIndent() - } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/AppInfoData.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/AppInfoData.kt index 2e776159..20038941 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/AppInfoData.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/AppInfoData.kt @@ -1,14 +1,14 @@ package com.simiacryptus.skyenet.webui.application data class AppInfoData( - val applicationName: String, - val singleInput: Boolean, - val stickyInput: Boolean, - val loadImages: Boolean, - val showMenubar: Boolean + val applicationName: String, + val singleInput: Boolean, + val stickyInput: Boolean, + val loadImages: Boolean, + val showMenubar: Boolean ) { - fun toMap(): Map { - return this::class.java.declaredFields.associate { it.name to it.get(this) } - } + fun toMap(): Map { + return this::class.java.declaredFields.associate { it.name to it.get(this) } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt index 6ad09dba..6f8b6b2c 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationDirectory.kt @@ -32,209 +32,209 @@ import kotlin.system.exitProcess abstract class ApplicationDirectory( - val localName: String = "localhost", - val publicName: String = "localhost", - val port: Int = 8081, + val localName: String = "localhost", + val publicName: String = "localhost", + val port: Int = 8081, ) { - var domainName: String = "" // Resolved in _main - private set - abstract val childWebApps: List - - data class ChildWebApp( - val path: String, - val server: ChatServer, - val thumbnail: String? = null, - ) - - private fun domainName(isServer: Boolean) = - if (isServer) "https://$publicName" else "http://$localName:$port" - - open val welcomeResources = ResourceCollection(allResources("welcome").map(::newResource)) - open val userInfoServlet: HttpServlet = UserInfoServlet() - open val userSettingsServlet: HttpServlet = UserSettingsServlet() - open val logoutServlet: HttpServlet = LogoutServlet() - open val usageServlet: HttpServlet = UsageServlet() - open val proxyHttpServlet: HttpServlet = ProxyHttpServlet() - open val apiKeyServlet: HttpServlet = ApiKeyServlet() - open val welcomeServlet: HttpServlet = WelcomeServlet(this) - - open fun authenticatedWebsite(): OAuthBase? = OAuthGoogle( - redirectUri = "$domainName/oauth2callback", - applicationName = "Demo", - key = { - val encryptedData = - javaClass.classLoader!!.getResourceAsStream("client_secret_google_oauth.json.kms")?.readAllBytes() - ?: throw RuntimeException("Unable to load resource: ${"client_secret_google_oauth.json.kms"}") - ApplicationServices.cloud!!.decrypt(encryptedData).byteInputStream() - } - ) - - open fun setupPlatform() { - ApplicationServices.seleniumFactory = { pool, cookies -> - Selenium2S3(pool, cookies) - } + var domainName: String = "" // Resolved in _main + private set + abstract val childWebApps: List + + data class ChildWebApp( + val path: String, + val server: ChatServer, + val thumbnail: String? = null, + ) + + private fun domainName(isServer: Boolean) = + if (isServer) "https://$publicName" else "http://$localName:$port" + + open val welcomeResources = ResourceCollection(allResources("welcome").map(::newResource)) + open val userInfoServlet: HttpServlet = UserInfoServlet() + open val userSettingsServlet: HttpServlet = UserSettingsServlet() + open val logoutServlet: HttpServlet = LogoutServlet() + open val usageServlet: HttpServlet = UsageServlet() + open val proxyHttpServlet: HttpServlet = ProxyHttpServlet() + open val apiKeyServlet: HttpServlet = ApiKeyServlet() + open val welcomeServlet: HttpServlet = WelcomeServlet(this) + + open fun authenticatedWebsite(): OAuthBase? = OAuthGoogle( + redirectUri = "$domainName/oauth2callback", + applicationName = "Demo", + key = { + val encryptedData = + javaClass.classLoader!!.getResourceAsStream("client_secret_google_oauth.json.kms")?.readAllBytes() + ?: throw RuntimeException("Unable to load resource: ${"client_secret_google_oauth.json.kms"}") + ApplicationServices.cloud!!.decrypt(encryptedData).byteInputStream() } + ) - protected open fun _main(args: Array) { + open fun setupPlatform() { + ApplicationServices.seleniumFactory = { pool, cookies -> + Selenium2S3(pool, cookies) + } + } + + protected open fun _main(args: Array) { + try { + log.info("Starting application with args: ${args.joinToString(", ")}") + setupPlatform() + init(args.contains("--server")) + if (ClientUtil.keyTxt.isEmpty()) ClientUtil.keyTxt = run { try { - log.info("Starting application with args: ${args.joinToString(", ")}") - setupPlatform() - init(args.contains("--server")) - if (ClientUtil.keyTxt.isEmpty()) ClientUtil.keyTxt = run { - try { - val encryptedData = javaClass.classLoader.getResourceAsStream("openai.key.json.kms")?.readAllBytes() - ?: throw RuntimeException("Unable to load resource: ${"openai.key.json.kms"}") - val decrypt = ApplicationServices.cloud!!.decrypt(encryptedData) - JsonUtil.fromJson(decrypt, Map::class.java) - } catch (e: Throwable) { - log.warn("Error loading key.txt", e) - "" - } - } - isLocked = true - val server = start( - port, - *(webAppContexts()) - ) - log.info("Server started successfully on port $port") - try { - Desktop.getDesktop().browse(URI("$domainName/")) - } catch (e: Throwable) { - // Ignore - } - server.join() + val encryptedData = javaClass.classLoader.getResourceAsStream("openai.key.json.kms")?.readAllBytes() + ?: throw RuntimeException("Unable to load resource: ${"openai.key.json.kms"}") + val decrypt = ApplicationServices.cloud!!.decrypt(encryptedData) + JsonUtil.fromJson(decrypt, Map::class.java) } catch (e: Throwable) { - e.printStackTrace() - log.error("Application encountered an error: ${e.message}", e) - Thread.sleep(1000) - exitProcess(1) - } finally { - Thread.sleep(1000) - exitProcess(0) + log.warn("Error loading key.txt", e) + "" } + } + isLocked = true + val server = start( + port, + *(webAppContexts()) + ) + log.info("Server started successfully on port $port") + try { + Desktop.getDesktop().browse(URI("$domainName/")) + } catch (e: Throwable) { + // Ignore + } + server.join() + } catch (e: Throwable) { + e.printStackTrace() + log.error("Application encountered an error: ${e.message}", e) + Thread.sleep(1000) + exitProcess(1) + } finally { + Thread.sleep(1000) + exitProcess(0) } - - open fun webAppContexts() = listOfNotNull( - newWebAppContext("/logout", logoutServlet), - newWebAppContext("/proxy", proxyHttpServlet), - // toolServlet?.let { newWebAppContext("/tools", it) }, - newWebAppContext("/userInfo", userInfoServlet).let { - authenticatedWebsite()?.configure(it, true) ?: it - }, - newWebAppContext("/userSettings", userSettingsServlet).let { - authenticatedWebsite()?.configure(it, true) ?: it - }, - newWebAppContext("/usage", usageServlet).let { - authenticatedWebsite()?.configure(it, true) ?: it - }, - newWebAppContext("/apiKeys", apiKeyServlet).let { - authenticatedWebsite()?.configure(it, true) ?: it - }, - newWebAppContext("/", welcomeResources, "welcome", welcomeServlet).let { - authenticatedWebsite()?.configure(it, false) ?: it - }, - newWebAppContext("/api", welcomeServlet).let { - authenticatedWebsite()?.configure(it, false) ?: it - }, - ).toTypedArray() + childWebApps.map { - newWebAppContext(it.path, it.server) - } - - open fun init(isServer: Boolean): ApplicationDirectory { - OutputInterceptor.setupInterceptor() - log.info("Initializing application, isServer: $isServer") - domainName = domainName(isServer) - return this - } - - protected open fun start( - port: Int, - vararg webAppContexts: WebAppContext - ): Server { - val contexts = ContextHandlerCollection() + } + + open fun webAppContexts() = listOfNotNull( + newWebAppContext("/logout", logoutServlet), + newWebAppContext("/proxy", proxyHttpServlet), + // toolServlet?.let { newWebAppContext("/tools", it) }, + newWebAppContext("/userInfo", userInfoServlet).let { + authenticatedWebsite()?.configure(it, true) ?: it + }, + newWebAppContext("/userSettings", userSettingsServlet).let { + authenticatedWebsite()?.configure(it, true) ?: it + }, + newWebAppContext("/usage", usageServlet).let { + authenticatedWebsite()?.configure(it, true) ?: it + }, + newWebAppContext("/apiKeys", apiKeyServlet).let { + authenticatedWebsite()?.configure(it, true) ?: it + }, + newWebAppContext("/", welcomeResources, "welcome", welcomeServlet).let { + authenticatedWebsite()?.configure(it, false) ?: it + }, + newWebAppContext("/api", welcomeServlet).let { + authenticatedWebsite()?.configure(it, false) ?: it + }, + ).toTypedArray() + childWebApps.map { + newWebAppContext(it.path, it.server) + } + + open fun init(isServer: Boolean): ApplicationDirectory { + OutputInterceptor.setupInterceptor() + log.info("Initializing application, isServer: $isServer") + domainName = domainName(isServer) + return this + } + + protected open fun start( + port: Int, + vararg webAppContexts: WebAppContext + ): Server { + val contexts = ContextHandlerCollection() // val stats = StatisticsHandler() - log.info("Starting server on port: $port") - contexts.handlers = ( - listOf( - newWebAppContext("/stats", StatisticsServlet()) - ) + - webAppContexts.map { - it.addFilter(FilterHolder(CorsFilter()), "/*", EnumSet.of(DispatcherType.REQUEST)) - it - } - ).toTypedArray() - val server = Server(port) - // Increase the number of acceptors and selectors for better scalability in a non-blocking model - val serverConnector = ServerConnector(server, 4, 8, httpConnectionFactory()) - serverConnector.port = port - serverConnector.acceptQueueSize = 1000 - serverConnector.idleTimeout = 30000 // Set idle timeout to 30 seconds - server.connectors = arrayOf(serverConnector) - server.handler = contexts - server.start() - if (!server.isStarted) throw IllegalStateException("Server failed to start") - log.info("Server initialization completed successfully.") - return server - } - - protected open fun httpConnectionFactory(): HttpConnectionFactory { - val httpConfig = HttpConfiguration() - httpConfig.addCustomizer(ForwardedRequestCustomizer()) - log.debug("HTTP connection factory created with custom configuration.") - return HttpConnectionFactory(httpConfig) - } - - protected open fun newWebAppContext(path: String, server: ChatServer): WebAppContext { - val baseResource = server.baseResource ?: throw IllegalStateException("No base resource") - val webAppContext = newWebAppContext(path, baseResource, resourceBase = "application") - server.configure(webAppContext) - log.info("WebAppContext configured for path: $path with ChatServer") - return webAppContext - } - - protected open fun newWebAppContext( - path: String, - baseResource: Resource, - resourceBase: String, - indexServlet: Servlet? = null - ): WebAppContext { - val context = WebAppContext() - JettyWebSocketServletContainerInitializer.configure(context, null) - context.classLoader = WebAppClassLoader(ApplicationServices::class.java.classLoader, context) - context.isParentLoaderPriority = true - context.baseResource = baseResource - log.debug("New WebAppContext created for path: $path") - context.contextPath = path + log.info("Starting server on port: $port") + contexts.handlers = ( + listOf( + newWebAppContext("/stats", StatisticsServlet()) + ) + + webAppContexts.map { + it.addFilter(FilterHolder(CorsFilter()), "/*", EnumSet.of(DispatcherType.REQUEST)) + it + } + ).toTypedArray() + val server = Server(port) + // Increase the number of acceptors and selectors for better scalability in a non-blocking model + val serverConnector = ServerConnector(server, 4, 8, httpConnectionFactory()) + serverConnector.port = port + serverConnector.acceptQueueSize = 1000 + serverConnector.idleTimeout = 30000 // Set idle timeout to 30 seconds + server.connectors = arrayOf(serverConnector) + server.handler = contexts + server.start() + if (!server.isStarted) throw IllegalStateException("Server failed to start") + log.info("Server initialization completed successfully.") + return server + } + + protected open fun httpConnectionFactory(): HttpConnectionFactory { + val httpConfig = HttpConfiguration() + httpConfig.addCustomizer(ForwardedRequestCustomizer()) + log.debug("HTTP connection factory created with custom configuration.") + return HttpConnectionFactory(httpConfig) + } + + protected open fun newWebAppContext(path: String, server: ChatServer): WebAppContext { + val baseResource = server.baseResource ?: throw IllegalStateException("No base resource") + val webAppContext = newWebAppContext(path, baseResource, resourceBase = "application") + server.configure(webAppContext) + log.info("WebAppContext configured for path: $path with ChatServer") + return webAppContext + } + + protected open fun newWebAppContext( + path: String, + baseResource: Resource, + resourceBase: String, + indexServlet: Servlet? = null + ): WebAppContext { + val context = WebAppContext() + JettyWebSocketServletContainerInitializer.configure(context, null) + context.classLoader = WebAppClassLoader(ApplicationServices::class.java.classLoader, context) + context.isParentLoaderPriority = true + context.baseResource = baseResource + log.debug("New WebAppContext created for path: $path") + context.contextPath = path // context.resourceBase = resourceBase - context.welcomeFiles = arrayOf("index.html") - if (indexServlet != null) { - context.addServlet(ServletHolder("$path/index", indexServlet), "/") + context.welcomeFiles = arrayOf("index.html") + if (indexServlet != null) { + context.addServlet(ServletHolder("$path/index", indexServlet), "/") // context.addServlet(ServletHolder("$path/index", indexServlet), "/*") - context.addServlet(ServletHolder("$path/index", indexServlet), "/index.html") - } - return context - } - - protected open fun newWebAppContext(path: String, servlet: Servlet): WebAppContext { - val context = WebAppContext() - JettyWebSocketServletContainerInitializer.configure(context, null) - context.classLoader = WebAppClassLoader(ApplicationServices::class.java.classLoader, context) - context.isParentLoaderPriority = true - context.contextPath = path - log.debug("New WebAppContext created for servlet at path: $path") - context.resourceBase = "application" - context.welcomeFiles = arrayOf("index.html") - val servletHolder = ServletHolder(servlet) - servletHolder.getRegistration().setMultipartConfig(MultipartConfigElement("./tmp")) - context.addServlet(servletHolder, "/") - return context - } - - - companion object { - private val log = LoggerFactory.getLogger(ApplicationDirectory::class.java) - fun allResources(resourceName: String) = - Thread.currentThread().contextClassLoader.getResources(resourceName).toList() + context.addServlet(ServletHolder("$path/index", indexServlet), "/index.html") } + return context + } + + protected open fun newWebAppContext(path: String, servlet: Servlet): WebAppContext { + val context = WebAppContext() + JettyWebSocketServletContainerInitializer.configure(context, null) + context.classLoader = WebAppClassLoader(ApplicationServices::class.java.classLoader, context) + context.isParentLoaderPriority = true + context.contextPath = path + log.debug("New WebAppContext created for servlet at path: $path") + context.resourceBase = "application" + context.welcomeFiles = arrayOf("index.html") + val servletHolder = ServletHolder(servlet) + servletHolder.getRegistration().setMultipartConfig(MultipartConfigElement("./tmp")) + context.addServlet(servletHolder, "/") + return context + } + + + companion object { + private val log = LoggerFactory.getLogger(ApplicationDirectory::class.java) + fun allResources(resourceName: String) = + Thread.currentThread().contextClassLoader.getResources(resourceName).toList() + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationInterface.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationInterface.kt index 84f9c526..77ed7682 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationInterface.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationInterface.kt @@ -8,40 +8,40 @@ import java.util.function.Consumer open class ApplicationInterface(val socketManager: SocketManagerBase?) { - open fun isInteractive() = true - - @Description("Returns html for a link that will trigger the given handler when clicked.") - open fun hrefLink( - @Description("The text to display in the link") - linkText: String, - @Description("The css class to apply to the link") - classname: String = """href-link""", - @Description("The id to apply to the link") - id: String? = null, - @Description("The handler to trigger when the link is clicked") - handler: Consumer, - ) = socketManager!!.hrefLink(linkText, classname, id, oneAtATime(handler)) - - @Description("Returns html for a text input form that will trigger the given handler when submitted.") - open fun textInput( - @Description("The handler to trigger when the form is submitted") - handler: Consumer - ): String = socketManager!!.textInput(oneAtATime(handler)) - - @Description("Creates a new 'task' that can be used to display the progress of a long-running operation.") - open fun newTask( - root: Boolean = true - ): SessionTask = socketManager!!.newTask(cancelable = false, root = root) - - companion object { - fun oneAtATime(handler: Consumer): Consumer { - val guard = AtomicBoolean(false) - return Consumer { t -> - if (guard.getAndSet(true)) return@Consumer - handler.accept(t) - guard.set(false) - } - } + open fun isInteractive() = true + + @Description("Returns html for a link that will trigger the given handler when clicked.") + open fun hrefLink( + @Description("The text to display in the link") + linkText: String, + @Description("The css class to apply to the link") + classname: String = """href-link""", + @Description("The id to apply to the link") + id: String? = null, + @Description("The handler to trigger when the link is clicked") + handler: Consumer, + ) = socketManager!!.hrefLink(linkText, classname, id, oneAtATime(handler)) + + @Description("Returns html for a text input form that will trigger the given handler when submitted.") + open fun textInput( + @Description("The handler to trigger when the form is submitted") + handler: Consumer + ): String = socketManager!!.textInput(oneAtATime(handler)) + + @Description("Creates a new 'task' that can be used to display the progress of a long-running operation.") + open fun newTask( + root: Boolean = true + ): SessionTask = socketManager!!.newTask(cancelable = false, root = root) + + companion object { + fun oneAtATime(handler: Consumer): Consumer { + val guard = AtomicBoolean(false) + return Consumer { t -> + if (guard.getAndSet(true)) return@Consumer + handler.accept(t) + guard.set(false) + } } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationServer.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationServer.kt index f30af0b6..09b138f1 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationServer.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationServer.kt @@ -22,173 +22,173 @@ import org.eclipse.jetty.webapp.WebAppContext import java.io.File abstract class ApplicationServer( - final override val applicationName: String, - val path: String, - resourceBase: String = "application", - open val root: File = dataStorageRoot, - val showMenubar: Boolean = true, + final override val applicationName: String, + val path: String, + resourceBase: String = "application", + open val root: File = dataStorageRoot, + val showMenubar: Boolean = true, ) : ChatServer(resourceBase) { - open val description: String = "" - open val singleInput = true - open val stickyInput = false - open fun appInfo(session: Session) = appInfoMap.getOrPut(session) { - AppInfoData( - applicationName = applicationName, - singleInput = singleInput, - stickyInput = stickyInput, - loadImages = false, - showMenubar = showMenubar - ) - }.toMap() - - final override val dataStorage: StorageInterface by lazy { dataStorageFactory(dataStorageRoot) } - - protected open val appInfoServlet by lazy { - ServletHolder("appInfo", AppInfoServlet { session -> - appInfo(Session(session!!)) - }) - } - protected open val userInfo by lazy { ServletHolder("userInfo", UserInfoServlet()) } - protected open val usageServlet by lazy { ServletHolder("usage", UsageServlet()) } - protected open val fileZip by lazy { ServletHolder("fileZip", ZipServlet(dataStorage)) } - protected open val fileIndex by lazy { ServletHolder("fileIndex", SessionFileServlet(dataStorage)) } - protected open val sessionSettingsServlet by lazy { ServletHolder("settings", SessionSettingsServlet(this)) } - protected open val sessionShareServlet by lazy { ServletHolder("share", SessionShareServlet(this)) } - protected open val sessionThreadsServlet by lazy { ServletHolder("threads", SessionThreadsServlet(this)) } - protected open val deleteSessionServlet by lazy { ServletHolder("delete", DeleteSessionServlet(this)) } - protected open val cancelSessionServlet by lazy { ServletHolder("cancel", CancelThreadsServlet(this)) } - - override fun newSession(user: User?, session: Session): SocketManager { - dataStorage.setJson( - user, session, "info.json", mapOf( - "session" to session.toString(), - "application" to applicationName, - "path" to path, - "startTime" to System.currentTimeMillis(), - ) - ) - return object : ApplicationSocketManager( - session = session, - owner = user, - dataStorage = dataStorage, - applicationClass = this@ApplicationServer::class.java, - ) { - override fun userMessage( - session: Session, - user: User?, - userMessage: String, - socketManager: ApplicationSocketManager, - api: API - ) = this@ApplicationServer.userMessage( - session = session, - user = user, - userMessage = userMessage, - ui = socketManager.applicationInterface, - api = api - ) - } - } - - open fun userMessage( + open val description: String = "" + open val singleInput = true + open val stickyInput = false + open fun appInfo(session: Session) = appInfoMap.getOrPut(session) { + AppInfoData( + applicationName = applicationName, + singleInput = singleInput, + stickyInput = stickyInput, + loadImages = false, + showMenubar = showMenubar + ) + }.toMap() + + final override val dataStorage: StorageInterface by lazy { dataStorageFactory(dataStorageRoot) } + + protected open val appInfoServlet by lazy { + ServletHolder("appInfo", AppInfoServlet { session -> + appInfo(Session(session!!)) + }) + } + protected open val userInfo by lazy { ServletHolder("userInfo", UserInfoServlet()) } + protected open val usageServlet by lazy { ServletHolder("usage", UsageServlet()) } + protected open val fileZip by lazy { ServletHolder("fileZip", ZipServlet(dataStorage)) } + protected open val fileIndex by lazy { ServletHolder("fileIndex", SessionFileServlet(dataStorage)) } + protected open val sessionSettingsServlet by lazy { ServletHolder("settings", SessionSettingsServlet(this)) } + protected open val sessionShareServlet by lazy { ServletHolder("share", SessionShareServlet(this)) } + protected open val sessionThreadsServlet by lazy { ServletHolder("threads", SessionThreadsServlet(this)) } + protected open val deleteSessionServlet by lazy { ServletHolder("delete", DeleteSessionServlet(this)) } + protected open val cancelSessionServlet by lazy { ServletHolder("cancel", CancelThreadsServlet(this)) } + + override fun newSession(user: User?, session: Session): SocketManager { + dataStorage.setJson( + user, session, "info.json", mapOf( + "session" to session.toString(), + "application" to applicationName, + "path" to path, + "startTime" to System.currentTimeMillis(), + ) + ) + return object : ApplicationSocketManager( + session = session, + owner = user, + dataStorage = dataStorage, + applicationClass = this@ApplicationServer::class.java, + ) { + override fun userMessage( session: Session, user: User?, userMessage: String, - ui: ApplicationInterface, + socketManager: ApplicationSocketManager, api: API - ): Unit = throw UnsupportedOperationException() - - open val settingsClass: Class<*> get() = Map::class.java - - open fun initSettings(session: Session): T? = null - - fun getSettings( - session: Session, - userId: User?, - @Suppress("UNCHECKED_CAST") clazz: Class = settingsClass as Class - ): T? { - val settingsFile = getSettingsFile(session, userId) - var settings: T? = if (settingsFile.exists()) JsonUtil.fromJson(settingsFile.readText(), clazz) else null - if (null == settings) { - val initSettings = initSettings(session) - if (null != initSettings) { - settingsFile.writeText(JsonUtil.toJson(initSettings)) - } - if (settingsFile.exists()) { - settings = JsonUtil.fromJson(settingsFile.readText(), clazz) - } - } - return settings + ) = this@ApplicationServer.userMessage( + session = session, + user = user, + userMessage = userMessage, + ui = socketManager.applicationInterface, + api = api + ) } - - fun getSettingsFile( - session: Session, - userId: User? - ): File { - val settingsFile = - dataStorage.getDataDir(userId, session).resolve("settings.json") - .apply { parentFile.mkdirs() } - return settingsFile + } + + open fun userMessage( + session: Session, + user: User?, + userMessage: String, + ui: ApplicationInterface, + api: API + ): Unit = throw UnsupportedOperationException() + + open val settingsClass: Class<*> get() = Map::class.java + + open fun initSettings(session: Session): T? = null + + fun getSettings( + session: Session, + userId: User?, + @Suppress("UNCHECKED_CAST") clazz: Class = settingsClass as Class + ): T? { + val settingsFile = getSettingsFile(session, userId) + var settings: T? = if (settingsFile.exists()) JsonUtil.fromJson(settingsFile.readText(), clazz) else null + if (null == settings) { + val initSettings = initSettings(session) + if (null != initSettings) { + settingsFile.writeText(JsonUtil.toJson(initSettings)) + } + if (settingsFile.exists()) { + settings = JsonUtil.fromJson(settingsFile.readText(), clazz) + } } - - protected open fun sessionsServlet(path: String) = - ServletHolder("sessionList", SessionListServlet(this.dataStorage, path, this)) - - override fun configure(webAppContext: WebAppContext) { - super.configure(webAppContext) - - webAppContext.addFilter( - FilterHolder { request, response, chain -> - val user = authenticationManager.getUser((request as HttpServletRequest).getCookie()) - val canRead = authorizationManager.isAuthorized( - applicationClass = this@ApplicationServer.javaClass, - user = user, - operationType = OperationType.Read - ) - if (canRead) { - chain?.doFilter(request, response) - } else { - response?.writer?.write("Access Denied") - (response as HttpServletResponse?)?.status = HttpServletResponse.SC_FORBIDDEN - } - }, "/*", null + return settings + } + + fun getSettingsFile( + session: Session, + userId: User? + ): File { + val settingsFile = + dataStorage.getDataDir(userId, session).resolve("settings.json") + .apply { parentFile.mkdirs() } + return settingsFile + } + + protected open fun sessionsServlet(path: String) = + ServletHolder("sessionList", SessionListServlet(this.dataStorage, path, this)) + + override fun configure(webAppContext: WebAppContext) { + super.configure(webAppContext) + + webAppContext.addFilter( + FilterHolder { request, response, chain -> + val user = authenticationManager.getUser((request as HttpServletRequest).getCookie()) + val canRead = authorizationManager.isAuthorized( + applicationClass = this@ApplicationServer.javaClass, + user = user, + operationType = OperationType.Read ) - - webAppContext.addServlet(appInfoServlet, "/appInfo") - webAppContext.addServlet(userInfo, "/userInfo") - webAppContext.addServlet(usageServlet, "/usage") - webAppContext.addServlet(fileIndex, "/fileIndex/*") - webAppContext.addServlet(fileZip, "/fileZip") - webAppContext.addServlet(sessionsServlet(path), "/sessions") - webAppContext.addServlet(sessionSettingsServlet, "/settings") - webAppContext.addServlet(sessionThreadsServlet, "/threads") - webAppContext.addServlet(sessionShareServlet, "/share") - webAppContext.addServlet(deleteSessionServlet, "/delete") - webAppContext.addServlet(cancelSessionServlet, "/cancel") - } - - companion object { - - fun getMimeType(filename: String): String = - when { - filename.endsWith(".html") -> "text/html" - filename.endsWith(".json") -> "application/json" - filename.endsWith(".js") -> "application/javascript" - filename.endsWith(".png") -> "image/png" - filename.endsWith(".jpg") -> "image/jpeg" - filename.endsWith(".jpeg") -> "image/jpeg" - filename.endsWith(".gif") -> "image/gif" - filename.endsWith(".svg") -> "image/svg+xml" - filename.endsWith(".css") -> "text/css" - filename.endsWith(".mp3") -> "audio/mpeg" - else -> "text/plain" - } - - fun HttpServletRequest.getCookie(name: String = AuthenticationInterface.AUTH_COOKIE) = - cookies?.find { it.name == name }?.value - - - val appInfoMap = mutableMapOf() - } + if (canRead) { + chain?.doFilter(request, response) + } else { + response?.writer?.write("Access Denied") + (response as HttpServletResponse?)?.status = HttpServletResponse.SC_FORBIDDEN + } + }, "/*", null + ) + + webAppContext.addServlet(appInfoServlet, "/appInfo") + webAppContext.addServlet(userInfo, "/userInfo") + webAppContext.addServlet(usageServlet, "/usage") + webAppContext.addServlet(fileIndex, "/fileIndex/*") + webAppContext.addServlet(fileZip, "/fileZip") + webAppContext.addServlet(sessionsServlet(path), "/sessions") + webAppContext.addServlet(sessionSettingsServlet, "/settings") + webAppContext.addServlet(sessionThreadsServlet, "/threads") + webAppContext.addServlet(sessionShareServlet, "/share") + webAppContext.addServlet(deleteSessionServlet, "/delete") + webAppContext.addServlet(cancelSessionServlet, "/cancel") + } + + companion object { + + fun getMimeType(filename: String): String = + when { + filename.endsWith(".html") -> "text/html" + filename.endsWith(".json") -> "application/json" + filename.endsWith(".js") -> "application/javascript" + filename.endsWith(".png") -> "image/png" + filename.endsWith(".jpg") -> "image/jpeg" + filename.endsWith(".jpeg") -> "image/jpeg" + filename.endsWith(".gif") -> "image/gif" + filename.endsWith(".svg") -> "image/svg+xml" + filename.endsWith(".css") -> "text/css" + filename.endsWith(".mp3") -> "audio/mpeg" + else -> "text/plain" + } + + fun HttpServletRequest.getCookie(name: String = AuthenticationInterface.AUTH_COOKIE) = + cookies?.find { it.name == name }?.value + + + val appInfoMap = mutableMapOf() + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationSocketManager.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationSocketManager.kt index 57678efd..f79fcb27 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationSocketManager.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/application/ApplicationSocketManager.kt @@ -9,43 +9,43 @@ import com.simiacryptus.skyenet.webui.chat.ChatSocket import com.simiacryptus.skyenet.webui.session.SocketManagerBase abstract class ApplicationSocketManager( - session: Session, - owner: User?, - dataStorage: StorageInterface?, - applicationClass: Class<*>, + session: Session, + owner: User?, + dataStorage: StorageInterface?, + applicationClass: Class<*>, ) : SocketManagerBase( - session = session, - dataStorage = dataStorage, - owner = owner, - applicationClass = applicationClass, + session = session, + dataStorage = dataStorage, + owner = owner, + applicationClass = applicationClass, ) { - override fun onRun(userMessage: String, socket: ChatSocket) { - userMessage( - session = session, - user = socket.user, - userMessage = userMessage, - socketManager = this, - api = ApplicationServices.clientManager.getChatClient( - session, - socket.user - ) - ) - } + override fun onRun(userMessage: String, socket: ChatSocket) { + userMessage( + session = session, + user = socket.user, + userMessage = userMessage, + socketManager = this, + api = ApplicationServices.clientManager.getChatClient( + session, + socket.user + ) + ) + } - open val applicationInterface by lazy { ApplicationInterface(this) } + open val applicationInterface by lazy { ApplicationInterface(this) } - abstract fun userMessage( - session: Session, - user: User?, - userMessage: String, - socketManager: ApplicationSocketManager, - api: API - ) + abstract fun userMessage( + session: Session, + user: User?, + userMessage: String, + socketManager: ApplicationSocketManager, + api: API + ) - companion object { - // val playButton: String get() = """""" + companion object { + // val playButton: String get() = """""" // val cancelButton: String get() = """""" // val regenButton: String get() = """""" - } + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatServer.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatServer.kt index e208bd67..91b95c5c 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatServer.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatServer.kt @@ -18,66 +18,66 @@ import java.time.Duration abstract class ChatServer(private val resourceBase: String) { - abstract val applicationName: String - open val dataStorage: StorageInterface? = null - val sessions: MutableMap = mutableMapOf() + abstract val applicationName: String + open val dataStorage: StorageInterface? = null + val sessions: MutableMap = mutableMapOf() - inner class WebSocketHandler : JettyWebSocketServlet() { - override fun configure(factory: JettyWebSocketServletFactory) { - with(factory) { - isAutoFragment = false - idleTimeout = Duration.ofMinutes(10) - outputBufferSize = 1024 * 1024 - inputBufferSize = 1024 * 1024 - maxBinaryMessageSize = 1024 * 1024 - maxFrameSize = 1024 * 1024 - maxTextMessageSize = 1024 * 1024 - this.availableExtensionNames.remove("permessage-deflate") - } - factory.setCreator { req, resp -> - try { - if (req.parameterMap.containsKey("sessionId")) { - val session = Session(req.parameterMap["sessionId"]?.first()!!) - ChatSocket( - if (sessions.containsKey(session)) { - sessions[session]!! - } else { - val user = - authenticationManager.getUser(req.getCookie(AuthenticationInterface.AUTH_COOKIE)) - val sessionState = newSession(user, session) - sessions[session] = sessionState - sessionState - } - ) - } else { - throw IllegalArgumentException("sessionId is required") - } - } catch (e: Exception) { - log.debug("Error configuring websocket", e) - resp.sendError(500, e.message) - null - } - } + inner class WebSocketHandler : JettyWebSocketServlet() { + override fun configure(factory: JettyWebSocketServletFactory) { + with(factory) { + isAutoFragment = false + idleTimeout = Duration.ofMinutes(10) + outputBufferSize = 1024 * 1024 + inputBufferSize = 1024 * 1024 + maxBinaryMessageSize = 1024 * 1024 + maxFrameSize = 1024 * 1024 + maxTextMessageSize = 1024 * 1024 + this.availableExtensionNames.remove("permessage-deflate") + } + factory.setCreator { req, resp -> + try { + if (req.parameterMap.containsKey("sessionId")) { + val session = Session(req.parameterMap["sessionId"]?.first()!!) + ChatSocket( + if (sessions.containsKey(session)) { + sessions[session]!! + } else { + val user = + authenticationManager.getUser(req.getCookie(AuthenticationInterface.AUTH_COOKIE)) + val sessionState = newSession(user, session) + sessions[session] = sessionState + sessionState + } + ) + } else { + throw IllegalArgumentException("sessionId is required") + } + } catch (e: Exception) { + log.debug("Error configuring websocket", e) + resp.sendError(500, e.message) + null } + } } + } - abstract fun newSession(user: User?, session: Session): SocketManager + abstract fun newSession(user: User?, session: Session): SocketManager - open val baseResource: Resource? get() = Resource.newResource(javaClass.classLoader.getResource(resourceBase)) - private val newSessionServlet by lazy { NewSessionServlet() } - private val webSocketHandler by lazy { WebSocketHandler() } - private val defaultServlet by lazy { DefaultServlet() } + open val baseResource: Resource? get() = Resource.newResource(javaClass.classLoader.getResource(resourceBase)) + private val newSessionServlet by lazy { NewSessionServlet() } + private val webSocketHandler by lazy { WebSocketHandler() } + private val defaultServlet by lazy { DefaultServlet() } - open fun configure(webAppContext: WebAppContext) { - webAppContext.addServlet(ServletHolder(javaClass.simpleName + "/default", defaultServlet), "/") - webAppContext.addServlet(ServletHolder(javaClass.simpleName + "/ws", webSocketHandler), "/ws") - webAppContext.addServlet(ServletHolder(javaClass.simpleName + "/newSession", newSessionServlet), "/newSession") - } + open fun configure(webAppContext: WebAppContext) { + webAppContext.addServlet(ServletHolder(javaClass.simpleName + "/default", defaultServlet), "/") + webAppContext.addServlet(ServletHolder(javaClass.simpleName + "/ws", webSocketHandler), "/ws") + webAppContext.addServlet(ServletHolder(javaClass.simpleName + "/newSession", newSessionServlet), "/newSession") + } - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(ChatServer::class.java) - fun JettyServerUpgradeRequest.getCookie(name: String) = cookies?.find { it.name == name }?.value - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(ChatServer::class.java) + fun JettyServerUpgradeRequest.getCookie(name: String) = cookies?.find { it.name == name }?.value + } } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocket.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocket.kt index 7f1444ee..e7290714 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocket.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocket.kt @@ -6,36 +6,36 @@ import org.eclipse.jetty.websocket.api.Session import org.eclipse.jetty.websocket.api.WebSocketAdapter class ChatSocket( - private val sessionState: SocketManager, + private val sessionState: SocketManager, ) : WebSocketAdapter() { - val user get() = SocketManagerBase.getUser(session) - - override fun onWebSocketConnect(session: Session) { - super.onWebSocketConnect(session) - //log.debug("{} - Socket connected: {}", session, session.remote) - sessionState.addSocket(this, session) - sessionState.getReplay().forEach { - try { - remote.sendString(it) - } catch (e: Exception) { - e.printStackTrace() - } - } + val user get() = SocketManagerBase.getUser(session) + + override fun onWebSocketConnect(session: Session) { + super.onWebSocketConnect(session) + //log.debug("{} - Socket connected: {}", session, session.remote) + sessionState.addSocket(this, session) + sessionState.getReplay().forEach { + try { + remote.sendString(it) + } catch (e: Exception) { + e.printStackTrace() + } } + } - override fun onWebSocketText(message: String) { - super.onWebSocketText(message) - sessionState.onWebSocketText(this, message) - } + override fun onWebSocketText(message: String) { + super.onWebSocketText(message) + sessionState.onWebSocketText(this, message) + } - override fun onWebSocketClose(statusCode: Int, reason: String?) { - super.onWebSocketClose(statusCode, reason) + override fun onWebSocketClose(statusCode: Int, reason: String?) { + super.onWebSocketClose(statusCode, reason) - sessionState.removeSocket(this) - } + sessionState.removeSocket(this) + } - companion object + companion object } diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocketManager.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocketManager.kt index 9f26436a..bee12e61 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocketManager.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/chat/ChatSocketManager.kt @@ -15,75 +15,75 @@ import com.simiacryptus.skyenet.webui.session.SocketManagerBase import java.util.* open class ChatSocketManager( - session: Session, - val model: ChatModel, - val userInterfacePrompt: String, - open val initialAssistantPrompt: String = "", - open val systemPrompt: String, - val api: ChatClient, - val temperature: Double = 0.3, - applicationClass: Class, - val storage: StorageInterface?, + session: Session, + val model: ChatModel, + val userInterfacePrompt: String, + open val initialAssistantPrompt: String = "", + open val systemPrompt: String, + val api: ChatClient, + val temperature: Double = 0.3, + applicationClass: Class, + val storage: StorageInterface?, ) : SocketManagerBase(session, storage, owner = null, applicationClass = applicationClass) { - init { - if (userInterfacePrompt.isNotBlank()) { - send("""aaa,
${MarkdownUtil.renderMarkdown(userInterfacePrompt)}
""") - } + init { + if (userInterfacePrompt.isNotBlank()) { + send("""aaa,
${MarkdownUtil.renderMarkdown(userInterfacePrompt)}
""") } + } - protected val messages by lazy { - val list = listOf( - ApiModel.ChatMessage(ApiModel.Role.system, systemPrompt.toContentList()), - ).toMutableList() - if (initialAssistantPrompt.isNotBlank()) list += - ApiModel.ChatMessage(ApiModel.Role.assistant, initialAssistantPrompt.toContentList()) - list - } + protected val messages by lazy { + val list = listOf( + ApiModel.ChatMessage(ApiModel.Role.system, systemPrompt.toContentList()), + ).toMutableList() + if (initialAssistantPrompt.isNotBlank()) list += + ApiModel.ChatMessage(ApiModel.Role.assistant, initialAssistantPrompt.toContentList()) + list + } - @Synchronized - override fun onRun(userMessage: String, socket: ChatSocket) { - val task = newTask() - val api = (api as ChatClient).getChildClient().apply { - val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") - createFile.second?.apply { - logStreams += this.outputStream().buffered() - task.verbose("API log: $this") - } - } - val responseContents = renderResponse(userMessage, task) - task.echo(responseContents) - messages += ApiModel.ChatMessage(ApiModel.Role.user, userMessage.toContentList()) - val messagesCopy = messages.toList() - try { - val ui = ApplicationInterface(this) - val process = { it: StringBuilder -> - val response = (api.chat( - ApiModel.ChatRequest( - messages = messagesCopy, - temperature = temperature, - model = model.modelName, - ), model - ).choices.first().message?.content.orEmpty()) - messages.dropLastWhile { it.role == ApiModel.Role.assistant } - messages += ApiModel.ChatMessage(ApiModel.Role.assistant, response.toContentList()) - val renderResponse = renderResponse(response, task) - onResponse(renderResponse, responseContents) - renderResponse - } - Retryable(ui, task, process) - } catch (e: Exception) { - log.info("Error in chat", e) - task.error(ApplicationInterface(this), e) - } + @Synchronized + override fun onRun(userMessage: String, socket: ChatSocket) { + val task = newTask() + val api = (api as ChatClient).getChildClient().apply { + val createFile = task.createFile(".logs/api-${UUID.randomUUID()}.log") + createFile.second?.apply { + logStreams += this.outputStream().buffered() + task.verbose("API log: $this") + } + } + val responseContents = renderResponse(userMessage, task) + task.echo(responseContents) + messages += ApiModel.ChatMessage(ApiModel.Role.user, userMessage.toContentList()) + val messagesCopy = messages.toList() + try { + val ui = ApplicationInterface(this) + val process = { it: StringBuilder -> + val response = (api.chat( + ApiModel.ChatRequest( + messages = messagesCopy, + temperature = temperature, + model = model.modelName, + ), model + ).choices.first().message?.content.orEmpty()) + messages.dropLastWhile { it.role == ApiModel.Role.assistant } + messages += ApiModel.ChatMessage(ApiModel.Role.assistant, response.toContentList()) + val renderResponse = renderResponse(response, task) + onResponse(renderResponse, responseContents) + renderResponse + } + Retryable(ui, task, process) + } catch (e: Exception) { + log.info("Error in chat", e) + task.error(ApplicationInterface(this), e) } + } - open fun renderResponse(response: String, task: SessionTask) = - """
${MarkdownUtil.renderMarkdown(response)}
""" + open fun renderResponse(response: String, task: SessionTask) = + """
${MarkdownUtil.renderMarkdown(response)}
""" - open fun onResponse(response: String, responseContents: String) {} + open fun onResponse(response: String, responseContents: String) {} - companion object { - private val log = org.slf4j.LoggerFactory.getLogger(ChatSocketManager::class.java) - } + companion object { + private val log = org.slf4j.LoggerFactory.getLogger(ChatSocketManager::class.java) + } } \ No newline at end of file diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ApiKeyServlet.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ApiKeyServlet.kt index b205dfbc..427c2f95 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ApiKeyServlet.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/webui/servlet/ApiKeyServlet.kt @@ -20,143 +20,143 @@ import kotlin.reflect.typeOf class ApiKeyServlet : HttpServlet() { - data class ApiKeyRecord( - val owner: String, - val apiKey: String, - val mappedKey: String, - val budget: Double, - val comment: String, - val welcomeMessage: String = "Welcome to our service!" - ) + data class ApiKeyRecord( + val owner: String, + val apiKey: String, + val mappedKey: String, + val budget: Double, + val comment: String, + val welcomeMessage: String = "Welcome to our service!" + ) - override fun doGet(req: HttpServletRequest, resp: HttpServletResponse) { - // Log received parameters for debugging + override fun doGet(req: HttpServletRequest, resp: HttpServletResponse) { + // Log received parameters for debugging // println("Action: $action, API Key: $apiKey, Mapped Key: $mappedKey, Budget: $budget, Comment: $comment, User: ${user?.email}") - resp.contentType = "text/html" - val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) ?: return resp.sendError( - HttpServletResponse.SC_UNAUTHORIZED - ) - val action = req.getParameter("action") - val apiKey = req.getParameter("apiKey") - - when (action.lowercase(Locale.ROOT)) { - "edit" -> { - val record = apiKeyRecords.find { it.apiKey == apiKey && it.owner == user.email } - if (record != null) { - serveEditPage(resp, record) - } else { - resp.writer.write("API Key record not found") - } - } + resp.contentType = "text/html" + val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) ?: return resp.sendError( + HttpServletResponse.SC_UNAUTHORIZED + ) + val action = req.getParameter("action") + val apiKey = req.getParameter("apiKey") - "delete" -> { // Fix the null safety check consistency - val record = apiKeyRecords.find { it.apiKey == apiKey && it.owner == user.email } - if (record != null) { - apiKeyRecords.remove(record) - saveRecords() - resp.writer.write("API Key record deleted") - } else { - resp.writer.write("API Key record not found") - } - } + when (action.lowercase(Locale.ROOT)) { + "edit" -> { + val record = apiKeyRecords.find { it.apiKey == apiKey && it.owner == user.email } + if (record != null) { + serveEditPage(resp, record) + } else { + resp.writer.write("API Key record not found") + } + } - "create" -> { - // Reuse the serveEditPage function but with an empty record for creation - serveEditPage( - resp, - ApiKeyRecord( - user.email, - UUID.randomUUID().toString(), - userSettingsManager.getUserSettings(user).apiKeys[APIProvider.OpenAI] - ?: "", // TODO: Expand support for other providers - 0.0, - "" - ) - ) - } + "delete" -> { // Fix the null safety check consistency + val record = apiKeyRecords.find { it.apiKey == apiKey && it.owner == user.email } + if (record != null) { + apiKeyRecords.remove(record) + saveRecords() + resp.writer.write("API Key record deleted") + } else { + resp.writer.write("API Key record not found") + } + } - "invite" -> { - val record = apiKeyRecords.find { it.apiKey == apiKey /*&& it.owner != user.email*/ } - if (record == null) { - throw IllegalArgumentException("API Key record not found, or you do not have permission to access it, or you are the owner.") - } - // Display a confirmation page instead of directly applying the settings - serveInviteConfirmationPage(resp, record, user) - } + "create" -> { + // Reuse the serveEditPage function but with an empty record for creation + serveEditPage( + resp, + ApiKeyRecord( + user.email, + UUID.randomUUID().toString(), + userSettingsManager.getUserSettings(user).apiKeys[APIProvider.OpenAI] + ?: "", // TODO: Expand support for other providers + 0.0, + "" + ) + ) + } - else -> { - resp.writer.write(indexPage(req)) - } + "invite" -> { + val record = apiKeyRecords.find { it.apiKey == apiKey /*&& it.owner != user.email*/ } + if (record == null) { + throw IllegalArgumentException("API Key record not found, or you do not have permission to access it, or you are the owner.") } + // Display a confirmation page instead of directly applying the settings + serveInviteConfirmationPage(resp, record, user) + } + else -> { + resp.writer.write(indexPage(req)) + } } - override fun doPost(req: HttpServletRequest, resp: HttpServletResponse) { - val action = req.getParameter("action") - val apiKey = req.getParameter("apiKey") - val mappedKey = req.getParameter("mappedKey") - val budget = req.getParameter("budget")?.toDoubleOrNull() - val comment = req.getParameter("comment") - // welcomeMessage - val welcomeMessage = req.getParameter("welcomeMessage") - val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) - val record = apiKeyRecords.find { it.apiKey == apiKey } + } - if (action == "acceptInvite") { - if (apiKey.isNullOrEmpty()) { - resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "API Key is missing") - } else if (user == null) { - resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "User not found") - } else if (record == null) { - resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid API Key or User not found") - } else { - userSettingsManager.updateUserSettings( - user, userSettingsManager.getUserSettings(user).copy( - apiKeys = mapOf(APIProvider.OpenAI to apiKey), // TODO: Expand support for other providers - apiBase = mapOf(APIProvider.OpenAI to "https://apps.simiacrypt.us/proxy") - ) - ) - resp.sendRedirect("/") // Redirect to a success page or another relevant page - } - } else if (record != null && budget != null && user == null) { // Ensure user is not null before proceeding - apiKeyRecords.remove(record) - apiKeyRecords.add( - record.copy( - mappedKey = mappedKey ?: record.mappedKey, - budget = budget, - comment = comment ?: "" - ) - ) - saveRecords() - resp.sendRedirect("?action=edit&apiKey=$apiKey&editSuccess=true") - } else if (apiKey != null && budget != null) { - // Create a new record if apiKey is not found - val newRecord = ApiKeyRecord( - owner = user?.email ?: "", - apiKey = apiKey, - mappedKey = mappedKey ?: "", - budget = budget, - comment = comment ?: "", - welcomeMessage = welcomeMessage ?: "Welcome to our service!" - ) - apiKeyRecords.add(newRecord) - saveRecords() - resp.sendRedirect( - "?action=edit&apiKey=${ - URLEncoder.encode( - apiKey, - "UTF-8" - ) - }&creationSuccess=true" - ) // Encode apiKey to prevent URL manipulation - } else { - resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid input") - } + override fun doPost(req: HttpServletRequest, resp: HttpServletResponse) { + val action = req.getParameter("action") + val apiKey = req.getParameter("apiKey") + val mappedKey = req.getParameter("mappedKey") + val budget = req.getParameter("budget")?.toDoubleOrNull() + val comment = req.getParameter("comment") + // welcomeMessage + val welcomeMessage = req.getParameter("welcomeMessage") + val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) + val record = apiKeyRecords.find { it.apiKey == apiKey } + + if (action == "acceptInvite") { + if (apiKey.isNullOrEmpty()) { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "API Key is missing") + } else if (user == null) { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "User not found") + } else if (record == null) { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid API Key or User not found") + } else { + userSettingsManager.updateUserSettings( + user, userSettingsManager.getUserSettings(user).copy( + apiKeys = mapOf(APIProvider.OpenAI to apiKey), // TODO: Expand support for other providers + apiBase = mapOf(APIProvider.OpenAI to "https://apps.simiacrypt.us/proxy") + ) + ) + resp.sendRedirect("/") // Redirect to a success page or another relevant page + } + } else if (record != null && budget != null && user == null) { // Ensure user is not null before proceeding + apiKeyRecords.remove(record) + apiKeyRecords.add( + record.copy( + mappedKey = mappedKey ?: record.mappedKey, + budget = budget, + comment = comment ?: "" + ) + ) + saveRecords() + resp.sendRedirect("?action=edit&apiKey=$apiKey&editSuccess=true") + } else if (apiKey != null && budget != null) { + // Create a new record if apiKey is not found + val newRecord = ApiKeyRecord( + owner = user?.email ?: "", + apiKey = apiKey, + mappedKey = mappedKey ?: "", + budget = budget, + comment = comment ?: "", + welcomeMessage = welcomeMessage ?: "Welcome to our service!" + ) + apiKeyRecords.add(newRecord) + saveRecords() + resp.sendRedirect( + "?action=edit&apiKey=${ + URLEncoder.encode( + apiKey, + "UTF-8" + ) + }&creationSuccess=true" + ) // Encode apiKey to prevent URL manipulation + } else { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid input") } + } - private fun indexPage(req: HttpServletRequest): String { - val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) ?: return "" - return """ + private fun indexPage(req: HttpServletRequest): String { + val user = ApplicationServices.authenticationManager.getUser(req.getCookie()) ?: return "" + return """ API Key Records @@ -171,21 +171,21 @@ class ApiKeyServlet : HttpServlet() {

API Key Records

${ - apiKeyRecords.filter { it.owner == user.email }.joinToString("\n") { record -> - "" - } - } + apiKeyRecords.filter { it.owner == user.email }.joinToString("\n") { record -> + "" + } + }
Create New API Key Record """.trimIndent() - } + } - private fun serveInviteConfirmationPage(resp: HttpServletResponse, record: ApiKeyRecord, user: User) { - //language=HTML - resp.writer.write( - """ + private fun serveInviteConfirmationPage(resp: HttpServletResponse, record: ApiKeyRecord, user: User) { + //language=HTML + resp.writer.write( + """ Accept API Key Invitation @@ -202,14 +202,14 @@ class ApiKeyServlet : HttpServlet() { """.trimIndent() - ) - } + ) + } - private fun serveEditPage(resp: HttpServletResponse, record: ApiKeyRecord) { - val usageSummary = ApplicationServices.usageManager.getUserUsageSummary(record.apiKey) - //language=HTML - resp.writer.write( - """ + private fun serveEditPage(resp: HttpServletResponse, record: ApiKeyRecord) { + val usageSummary = ApplicationServices.usageManager.getUserUsageSummary(record.apiKey) + //language=HTML + resp.writer.write( + """ Edit API Key Record: ${record.apiKey} @@ -276,16 +276,16 @@ class ApiKeyServlet : HttpServlet() {

Usage Summary

${ - usageSummary.entries.joinToString { (model: OpenAIModel, usage: ApiModel.Usage) -> - """ + usageSummary.entries.joinToString { (model: OpenAIModel, usage: ApiModel.Usage) -> + """

${model.modelName}

total_tokens: ${usage.total_tokens}

Cost: ${usage.cost}

""" - } - } + } + }