From b3397c807c083e634a4414325883e93299dbdb8f Mon Sep 17 00:00:00 2001 From: Andrew Charneski Date: Tue, 21 Mar 2023 16:16:46 -0700 Subject: [PATCH] 1.0.18 (#34) * 1.0.17 * 1.0.18 --- .gitignore | 1 + CHANGELOG.md | 20 + build.gradle.kts | 5 + gradle.properties | 10 +- .../simiacryptus/aicoder/SoftwareProjectAI.kt | 360 +++++++++++++ .../aicoder/actions/code/DescribeAction.kt | 2 +- .../aicoder/actions/code/DocAction.kt | 6 +- .../aicoder/actions/code/ImplementAction.kt | 2 +- .../actions/code/RenameVariablesAction.kt | 1 + .../actions/code/RewordCommentAction.kt | 2 +- .../actions/dev/GenerateProjectAction.kt | 165 ++++-- .../actions/generic/ChatAppendAction.kt | 64 +++ .../aicoder/actions/generic/InsertAction.kt | 2 +- .../actions/generic/ReplaceOptionsAction.kt | 2 +- .../actions/markdown/MarkdownListAction.kt | 2 +- .../markdown/MarkdownNewTableColAction.kt | 4 +- .../markdown/MarkdownNewTableColsAction.kt | 2 +- .../markdown/MarkdownNewTableRowsAction.kt | 2 +- .../actions/markdown/WikiLinksAction.kt | 2 +- .../aicoder/config/AppSettingsComponent.kt | 3 + .../aicoder/config/AppSettingsConfigurable.kt | 1 - .../aicoder/config/AppSettingsState.kt | 12 +- .../aicoder/openai/async/AsyncAPI.kt | 139 ++---- .../aicoder/openai/async/AsyncAPIImpl.kt | 12 +- .../aicoder/openai/core/ChatRequest.kt | 9 - .../aicoder/openai/core/CoreAPI.kt | 425 ---------------- .../aicoder/openai/proxy/ChatProxy.kt | 87 ---- .../aicoder/openai/proxy/CompletionProxy.kt | 46 -- .../aicoder/openai/proxy/GPTProxyBase.kt | 210 -------- .../aicoder/openai/proxy/SoftwareProjectAI.kt | 76 --- .../openai/translate/TranslationRequest.kt | 2 +- .../translate/TranslationRequest_XML.kt | 2 +- .../openai/ui/CompletionRequestWithModel.kt | 2 +- .../openai/ui/InteractiveCompletionRequest.kt | 2 +- .../openai/ui/InteractiveEditRequest.kt | 2 +- .../{CoreAPIImpl.kt => OpenAIClientImpl.kt} | 6 +- .../aicoder/openai/ui/OpenAI_API.kt | 54 +- .../simiacryptus/aicoder/util/BlockComment.kt | 1 + .../simiacryptus/aicoder/util/IndentedText.kt | 6 +- .../simiacryptus/aicoder/util/LineComment.kt | 6 +- .../simiacryptus/aicoder/util/StyleUtil.kt | 2 +- .../simiacryptus/aicoder/util/UITools.kt | 158 +++++- .../simiacryptus/aicoder/util/psi/PsiUtil.kt | 2 +- .../openai/core => openai}/ApiError.kt | 2 +- .../openai/core => openai}/ChatChoice.kt | 2 +- .../openai/core => openai}/ChatMessage.kt | 2 +- .../github/simiacryptus/openai/ChatRequest.kt | 28 ++ .../openai/core => openai}/ChatResponse.kt | 2 +- .../core => openai}/CompletionChoice.kt | 2 +- .../core => openai}/CompletionRequest.kt | 2 +- .../core => openai}/CompletionResponse.kt | 2 +- .../openai/core => openai}/EditRequest.kt | 2 +- .../{aicoder/openai/core => openai}/Engine.kt | 2 +- .../simiacryptus/openai/HttpClientManager.kt | 138 +++++ .../openai/core => openai}/LogProbs.kt | 2 +- .../core => openai}/ModelMaxException.kt | 2 +- .../core => openai}/ModerationException.kt | 2 +- .../simiacryptus/openai/OpenAIClient.kt | 472 ++++++++++++++++++ .../openai/core => openai}/Response.kt | 2 +- .../{aicoder/openai/core => openai}/Usage.kt | 2 +- .../simiacryptus/openai/proxy/ChatProxy.kt | 109 ++++ .../openai/proxy/CompletionProxy.kt | 47 ++ .../simiacryptus/openai/proxy/Description.kt | 4 + .../simiacryptus/openai/proxy/GPTProxyBase.kt | 240 +++++++++ .../openai/proxy/ValidatedObject.kt | 26 + .../{aicoder => }/util/StringTools.kt | 2 +- src/main/resources/META-INF/plugin.xml | 10 + .../simiacryptus/aicoder/ProxyPlay.ws.kts | 22 - .../proxy/AlternateHistorySimulator.kt | 91 ++++ .../simiacryptus/aicoder/proxy/AutoDevelop.kt | 256 +++++++--- .../simiacryptus/aicoder/proxy/AutoNews.kt | 155 +++--- .../aicoder/proxy/ChildrensStory.kt | 166 ++++++ .../simiacryptus/aicoder/proxy/ComicBook.kt | 4 +- .../{DebateJudge.kt => DebateSimulator.kt} | 129 ++--- .../aicoder/proxy/FamilyGuyWriter.kt | 146 ++++++ .../aicoder/proxy/GenerationReportBase.kt | 9 +- .../simiacryptus/aicoder/proxy/ImageTest.kt | 2 +- .../proxy/InternationalEventsSimulator.kt | 73 +++ .../simiacryptus/aicoder/proxy/ProxyTest.kt | 27 +- .../simiacryptus/aicoder/proxy/RecipeBook.kt | 116 +++++ .../simiacryptus/aicoder/proxy/TravelGuide.kt | 60 +++ .../simiacryptus/aicoder/proxy/VideoGame.kt | 4 +- 82 files changed, 2957 insertions(+), 1326 deletions(-) create mode 100644 src/main/kotlin/com/github/simiacryptus/aicoder/SoftwareProjectAI.kt create mode 100644 src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/ChatAppendAction.kt delete mode 100644 src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatRequest.kt delete mode 100644 src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CoreAPI.kt delete mode 100644 src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/ChatProxy.kt delete mode 100644 src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/CompletionProxy.kt delete mode 100644 src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/GPTProxyBase.kt delete mode 100644 src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/SoftwareProjectAI.kt rename src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/{CoreAPIImpl.kt => OpenAIClientImpl.kt} (77%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/ApiError.kt (86%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/ChatChoice.kt (76%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/ChatMessage.kt (81%) create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/ChatRequest.kt rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/ChatResponse.kt (90%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/CompletionChoice.kt (87%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/CompletionRequest.kt (95%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/CompletionResponse.kt (91%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/EditRequest.kt (95%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/Engine.kt (90%) create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/HttpClientManager.kt rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/LogProbs.kt (90%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/ModelMaxException.kt (80%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/ModerationException.kt (54%) create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/OpenAIClient.kt rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/Response.kt (83%) rename src/main/kotlin/com/github/simiacryptus/{aicoder/openai/core => openai}/Usage.kt (76%) create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/proxy/ChatProxy.kt create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/proxy/CompletionProxy.kt create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/proxy/Description.kt create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/proxy/GPTProxyBase.kt create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/proxy/ValidatedObject.kt rename src/main/kotlin/com/github/simiacryptus/{aicoder => }/util/StringTools.kt (97%) delete mode 100644 src/test/kotlin/com/github/simiacryptus/aicoder/ProxyPlay.ws.kts create mode 100644 src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AlternateHistorySimulator.kt create mode 100644 src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ChildrensStory.kt rename src/test/kotlin/com/github/simiacryptus/aicoder/proxy/{DebateJudge.kt => DebateSimulator.kt} (61%) create mode 100644 src/test/kotlin/com/github/simiacryptus/aicoder/proxy/FamilyGuyWriter.kt create mode 100644 src/test/kotlin/com/github/simiacryptus/aicoder/proxy/InternationalEventsSimulator.kt create mode 100644 src/test/kotlin/com/github/simiacryptus/aicoder/proxy/RecipeBook.kt create mode 100644 src/test/kotlin/com/github/simiacryptus/aicoder/proxy/TravelGuide.kt diff --git a/.gitignore b/.gitignore index 2a8b33f6..529c5f24 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ chain.crt api.log.json *.log.java *.log +api.* diff --git a/CHANGELOG.md b/CHANGELOG.md index c36bbd11..558ddc8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,26 @@ ### Added - +## [1.0.18] + +### Improved +- API stability and performance +- Various bug fixes +- Max tokens handling + +## [1.0.17] + +### Added +- Ability to develop entire software projects from scratch (not a joke) +- Support for self-aware artificial intelligence (joke) + +### Removed +- Human value and significance (joke) +- Barriers to information warfare (not a joke) + +### Fixed +- Human nature (joke) + ## [1.0.16] ### Added diff --git a/build.gradle.kts b/build.gradle.kts index 15632859..3133d7b8 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -74,6 +74,11 @@ kover.xmlReport { } tasks { + compileKotlin { + kotlinOptions { + javaParameters = true + } + } compileTestKotlin { kotlinOptions { javaParameters = true diff --git a/gradle.properties b/gradle.properties index 203cbae6..0cbaaee5 100644 --- a/gradle.properties +++ b/gradle.properties @@ -4,7 +4,7 @@ pluginGroup = com.github.simiacryptus pluginName = intellij-aicoder pluginRepositoryUrl = https://github.com/SimiaCryptus/intellij-aicoder # SemVer format -> https://semver.org -pluginVersion = 1.0.16 +pluginVersion = 1.0.18 # Supported build number ranges and IntelliJ Platform versions -> https://plugins.jetbrains.com/docs/intellij/build-number-ranges.html pluginSinceBuild = 203 @@ -22,8 +22,8 @@ pluginUntilBuild = 231.* # IntelliJ IDEA Ultimate platformType = IU -platformPlugins = com.intellij.java, org.intellij.scala:2021.3.22, Pythonid:213.7172.26, org.jetbrains.plugins.go:213.7172.6 -#platformPlugins = com.intellij.java, org.intellij.scala:2022.3.16, Pythonid:223.8214.52, org.jetbrains.plugins.go:223.8214.52 +#platformPlugins = com.intellij.java, org.intellij.scala:2021.3.22, Pythonid:213.7172.26, org.jetbrains.plugins.go:213.7172.6 +platformPlugins = com.intellij.java, org.intellij.scala:2022.3.16, Pythonid:223.8214.52, org.jetbrains.plugins.go:223.8214.52 #platformPlugins = com.intellij.java # PhpStorm @@ -35,8 +35,8 @@ platformPlugins = com.intellij.java, org.intellij.scala:2021.3.22, Pythonid:213. #platformPlugins = JavaScript # https://mvnrepository.com/artifact/com.jetbrains.intellij.idea/ideaIU -platformVersion = 2021.3.3 -#platformVersion = 2022.3.1 +#platformVersion = 2021.3.3 +platformVersion = 2022.3.1 # Gradle Releases -> https://github.com/gradle/gradle/releases gradleVersion = 7.6.1 diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/SoftwareProjectAI.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/SoftwareProjectAI.kt new file mode 100644 index 00000000..a93a5cac --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/SoftwareProjectAI.kt @@ -0,0 +1,360 @@ +package com.github.simiacryptus.aicoder + +import com.github.simiacryptus.openai.proxy.Description +import com.github.simiacryptus.openai.proxy.ValidatedObject +import java.io.File +import java.util.* +import java.util.concurrent.Callable +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Executors +import java.util.concurrent.Future +import java.util.concurrent.atomic.AtomicInteger +import java.util.zip.ZipEntry +import java.util.zip.ZipOutputStream + +interface SoftwareProjectAI { + + fun newProject(description: String): Project + + data class Project( + val name: String? = "", + val description: String? = "", + val language: String? = "", + val features: List? = listOf(), + val libraries: List? = listOf(), + val buildTools: List? = listOf(), + ) : ValidatedObject + + fun getProjectStatements(description: String, project: Project): ProjectStatements + + data class ProjectStatements( + val assumptions: List? = listOf(), + val designPatterns: List? = listOf(), + val requirements: List? = listOf(), + val risks: List? = listOf(), + ) : ValidatedObject + + fun buildProjectDesign(project: Project, requirements: ProjectStatements): ProjectDesign + + data class ProjectDesign( + @Description("Major components e.g. 'core', 'ui', 'api', etc.") + val components: List? = listOf(), + @Description("Documentation files e.g. README.md, LICENSE, etc.") + val documents: List? = listOf(), + @Description("Individual test cases") + val tests: List? = listOf(), + ) : ValidatedObject + + data class ComponentDetails( + val name: String? = "", + val description: String? = "", + val features: List? = listOf(), + ) : ValidatedObject + + data class TestDetails( + val name: String? = "", + val steps: List? = listOf(), + val expectations: List? = listOf(), + ) : ValidatedObject + + data class DocumentationDetails( + val name: String? = "", + val description: String? = "", + val sections: List? = listOf(), + ) : ValidatedObject + + fun buildProjectFileSpecifications( + project: Project, + requirements: ProjectStatements, + design: ProjectDesign, + recursive: Boolean = true + ): List + + fun buildComponentFileSpecifications( + project: Project, + requirements: ProjectStatements, + design: ComponentDetails, + recursive: Boolean = true + ): List + + fun buildTestFileSpecifications( + project: Project, + requirements: ProjectStatements, + design: TestDetails, + recursive: Boolean = true + ): List + + fun buildDocumentationFileSpecifications( + project: Project, + requirements: ProjectStatements, + design: DocumentationDetails, + recursive: Boolean = true + ): List + + data class CodeSpecification( + val description: String? = "", + val requires: List? = listOf(), + val publicProperties: List? = listOf(), + val publicMethodSignatures: List? = listOf(), + val language: String? = "", + val location: FilePath? = FilePath(), + ) : ValidatedObject + + data class DocumentSpecification( + val description: String? = "", + val requires: List? = listOf(), + val sections: List? = listOf(), + val language: String? = "", + val location: FilePath? = FilePath(), + ) : ValidatedObject + + data class TestSpecification( + val description: String? = "", + val requires: List? = listOf(), + val steps: List? = listOf(), + val expectations: List? = listOf(), + val language: String? = "", + val location: FilePath? = FilePath(), + ) : ValidatedObject + + data class FilePath( + @Description("File name relative to project root, e.g. src/main/java/Foo.java") + val file: String? = "", + ) : ValidatedObject { + override fun toString(): String { + return file ?: "" + } + + override fun validate(): Boolean { + if (file?.isBlank() != false) return false + return super.validate() + } + } + + fun implementComponentSpecification( + project: Project, + component: ComponentDetails, + imports: List, + specification: CodeSpecification, + ): SourceCode + + + fun implementTestSpecification( + project: Project, + specification: TestSpecification, + test: TestDetails, + imports: List, + specificationAgain: TestSpecification, + ): SourceCode + + + fun implementDocumentationSpecification( + project: Project, + specification: DocumentSpecification, + documentation: DocumentationDetails, + imports: List, + specificationAgain: DocumentSpecification, + ): SourceCode + + data class SourceCode( + @Description("language of the code, e.g. \"java\" or \"kotlin\"") + val language: String? = "", + @Description("Fully implemented source code") + val code: String? = "", + ) : ValidatedObject + + companion object { + val log = org.slf4j.LoggerFactory.getLogger(SoftwareProjectAI::class.java) + fun parallelImplement( + api: SoftwareProjectAI, + project: Project, + components: Map>?, + documents: Map>?, + tests: Map>?, + drafts: Int, + threads: Int + ): Map = parallelImplementWithAlternates( + api, + project, + components ?: mapOf(), + documents ?: mapOf(), + tests ?: mapOf(), + drafts, + threads + ).mapValues { it.value.maxByOrNull { it.code?.length ?: 0 } } + + fun parallelImplementWithAlternates( + api: SoftwareProjectAI, + project: Project, + components: Map>, + documents: Map>, + tests: Map>, + drafts: Int, + threads: Int, + progress: (Double) -> Unit = {} + ): Map> { + val threadPool = Executors.newFixedThreadPool(threads) + try { + val totalDrafts = (components + tests + documents).values.sumOf { it.size } * drafts + val currentDraft = AtomicInteger(0) + val fileImplCache = ConcurrentHashMap>>>() + val normalizeFileName: (String?) -> String = { + it?.trimStart('/', '.') ?: "" + } + + // Build Components + fun buildCodeSpec( + component: ComponentDetails, + files: List, + file: CodeSpecification + ): List>> { + if (file.location == null) { + return emptyList() + } + return fileImplCache.getOrPut(normalizeFileName(file.location.file)) { + (0 until drafts).map { _ -> + threadPool.submit(Callable { + val implement = api.implementComponentSpecification( + project, + component, + files.filter { file.requires?.contains(it.location) ?: false }.toList(), + file.copy(requires = listOf()) + ) + (currentDraft.incrementAndGet().toDouble() / totalDrafts) + .also { progress(it) } + .also { log.info("Progress: $it") } + file.location to implement + }) + } + } + } + + fun buildComponentDetails( + component: ComponentDetails, + files: List + ): List>> { + return files.flatMap(fun(file: CodeSpecification): List>> { + return buildCodeSpec(component, files, file) + }).toTypedArray().toList() + } + + val componentFutures = components.flatMap { (component, files) -> + buildComponentDetails(component, files) + }.toTypedArray() + + // Build Documents + fun buildDocumentSpec( + documentation: DocumentationDetails, + files: List, + file: DocumentSpecification + ): List>> { + if (file.location == null) { + return emptyList() + } + return fileImplCache.getOrPut(normalizeFileName(file.location.file)) { + (0 until drafts).map { _ -> + threadPool.submit(Callable { + val implement = api.implementDocumentationSpecification( + project, + file.copy(requires = listOf()), + documentation, + files.filter { file.requires?.contains(it.location) ?: false }.toList(), + file.copy(requires = listOf()) + ) + (currentDraft.incrementAndGet().toDouble() / totalDrafts) + .also { progress(it) } + .also { log.info("Progress: $it") } + file.location to implement + }) + } + } + } + + fun buildDocumentDetails( + documentation: DocumentationDetails, + files: List + ): List>> { + return files.flatMap(fun(file: DocumentSpecification): List>> { + return buildDocumentSpec(documentation, files, file) + }).toTypedArray().toList() + } + + val documentFutures = documents.flatMap { (documentation, files) -> + buildDocumentDetails(documentation, files) + }.toTypedArray() + + // Build Tests + fun buildTestSpec( + test: TestDetails, + files: List, + file: TestSpecification + ): List>> { + if (file.location == null) { + return emptyList() + } + return fileImplCache.getOrPut(normalizeFileName(file.location.file)) { + (0 until drafts).map { _ -> + threadPool.submit(Callable { + val implement = api.implementTestSpecification( + project, + file.copy(requires = listOf()), + test, + files.filter { file.requires?.contains(it.location) ?: false }.toList(), + file.copy(requires = listOf()) + ) + (currentDraft.incrementAndGet().toDouble() / totalDrafts) + .also { progress(it) } + .also { log.info("Progress: $it") } + file.location to implement + }) + } + } + } + + fun buildTestDetails( + test: TestDetails, + files: List + ): List>> { + return files.flatMap(fun(file: TestSpecification): List>> { + return buildTestSpec(test, files, file) + }).toTypedArray().toList() + } + + val testFutures = tests.flatMap { (test, files) -> + buildTestDetails(test, files) + }.toTypedArray() + + return (getAll(componentFutures) + getAll(documentFutures) + getAll(testFutures)).mapValues { + it.value.map { it.second }.sortedBy { it.code?.length ?: 0 } + } + } finally { + threadPool.shutdown() + } + } + + private fun getAll(testFutures: Array>>) = + testFutures.map { + try { + Optional.ofNullable(it.get()) + } catch (e: Throwable) { + Optional.empty() + } + }.filter { !it.isEmpty }.map { it.get() }.groupBy { it.first } + + fun write( + zipArchiveFile: File, + implementations: Map + ) { + ZipOutputStream(zipArchiveFile.outputStream()).use { zip -> + implementations.forEach { (file, sourceCodes) -> + zip.putNextEntry(ZipEntry(file.toString())) + zip.write(sourceCodes!!.code?.toByteArray() ?: byteArrayOf()) + zip.closeEntry() + } + } + } + } + + +} + diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/DescribeAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/DescribeAction.kt index ed2aed76..b2fba639 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/DescribeAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/DescribeAction.kt @@ -3,7 +3,7 @@ package com.github.simiacryptus.aicoder.actions.code import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.util.ComputerLanguage import com.github.simiacryptus.aicoder.util.IndentedText -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools import com.intellij.openapi.actionSystem.AnAction import com.intellij.openapi.actionSystem.AnActionEvent diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/DocAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/DocAction.kt index a3635720..310cab64 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/DocAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/DocAction.kt @@ -3,7 +3,7 @@ package com.github.simiacryptus.aicoder.actions.code import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.util.ComputerLanguage import com.github.simiacryptus.aicoder.util.IndentedText -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools import com.github.simiacryptus.aicoder.util.UITools.getInstruction import com.github.simiacryptus.aicoder.util.UITools.redoableRequest @@ -62,7 +62,9 @@ class DocAction : AnAction() { }, { docString -> replaceString( - document, startOffset, endOffset, + document, + startOffset, + endOffset, docString ) } diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/ImplementAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/ImplementAction.kt index 5be51131..c9321148 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/ImplementAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/ImplementAction.kt @@ -3,7 +3,7 @@ package com.github.simiacryptus.aicoder.actions.code import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.util.ComputerLanguage import com.github.simiacryptus.aicoder.util.IndentedText -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools import com.github.simiacryptus.aicoder.util.UITools.redoableRequest import com.github.simiacryptus.aicoder.util.UITools.replaceString diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/RenameVariablesAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/RenameVariablesAction.kt index affe4144..7f82cc96 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/RenameVariablesAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/RenameVariablesAction.kt @@ -5,6 +5,7 @@ import com.github.simiacryptus.aicoder.util.* import com.github.simiacryptus.aicoder.util.UITools.replaceString import com.github.simiacryptus.aicoder.util.UITools.showCheckboxDialog import com.github.simiacryptus.aicoder.util.psi.PsiUtil +import com.github.simiacryptus.util.StringTools import com.intellij.openapi.actionSystem.AnAction import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/RewordCommentAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/RewordCommentAction.kt index 26465a98..3430a22d 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/RewordCommentAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/code/RewordCommentAction.kt @@ -2,7 +2,7 @@ package com.github.simiacryptus.aicoder.actions.code import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.util.ComputerLanguage -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools import com.github.simiacryptus.aicoder.util.psi.PsiUtil import com.intellij.openapi.actionSystem.AnAction diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/dev/GenerateProjectAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/dev/GenerateProjectAction.kt index 4945899e..83a445e9 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/dev/GenerateProjectAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/dev/GenerateProjectAction.kt @@ -1,54 +1,161 @@ package com.github.simiacryptus.aicoder.actions.dev import com.github.simiacryptus.aicoder.config.AppSettingsState -import com.github.simiacryptus.aicoder.openai.proxy.ChatProxy -import com.github.simiacryptus.aicoder.openai.proxy.SoftwareProjectAI +import com.github.simiacryptus.aicoder.config.Name import com.github.simiacryptus.aicoder.util.UITools +import com.github.simiacryptus.aicoder.SoftwareProjectAI +import com.github.simiacryptus.openai.proxy.ChatProxy import com.intellij.openapi.actionSystem.AnAction import com.intellij.openapi.actionSystem.AnActionEvent import java.io.File +import javax.swing.JCheckBox import javax.swing.JTextArea +import javax.swing.JTextField class GenerateProjectAction : AnAction() { - fun createProjectFiles(e: AnActionEvent, description: String) { - val outputDir = File(UITools.getSelectedFolder(e)!!.canonicalPath) - val api = ChatProxy(apiKey = AppSettingsState.instance.apiKey, base = AppSettingsState.instance.apiBase).create( - SoftwareProjectAI::class.java - ) - val project = api.newProject(description) - val requirements = api.getProjectStatements(project) - val projectDesign = api.buildProjectDesign(project, requirements) - val files = api.buildProjectFileSpecifications(project, requirements, projectDesign) - for (file in files.files) { - val sourceCode = api.implement( - project, - files.files.map { it.location }.filter { file.requires.contains(it) }.toList(), - file - ) - val outFile = - outputDir.resolve(file.location.path.replace('\\','/').trimEnd('/') + "/${file.location.name}.${file.location.extension}") - outFile.parentFile.mkdirs() - outFile.writeText(sourceCode.code) - log.warn("Wrote ${outFile.canonicalPath}") - } - - } - override fun update(e: AnActionEvent) { e.presentation.isEnabledAndVisible = isEnabled(e) super.update(e) } - data class SettingsUI(val description: JTextArea = JTextArea()) - data class Settings(var description: String = "") + @Suppress("UNUSED") + class SettingsUI { + @Name("Project Description") + val description: JTextArea = JTextArea() + @Name("Drafts Per File") + val drafts: JTextField = JTextField("2") + val saveAlternates: JCheckBox = JCheckBox("Save Alternates") + } + + data class Settings( + var description: String = "", + var drafts: Int = 2, + var saveAlternates: Boolean = false + ) override fun actionPerformed(e: AnActionEvent) { UITools.showDialog(e, SettingsUI::class.java, Settings::class.java) { config -> - createProjectFiles(e, config.description) + handleImplement(e, config) } } + private fun handleImplement( + e: AnActionEvent, + config: Settings + ) = Thread { + val selectedFolder = UITools.getSelectedFolder(e)!! + val api = ChatProxy( + apiKey = AppSettingsState.instance.apiKey, + base = AppSettingsState.instance.apiBase, + logLevel = AppSettingsState.instance.apiLogLevel, + maxTokens = AppSettingsState.instance.maxTokens, + ).create( + SoftwareProjectAI::class.java + ) + val project = UITools.run( + e.project, "Parsing Request", true + ) { + val newProject = api.newProject(""" + ${config.description} + """.trimIndent().trim()) + if (it.isCanceled) throw InterruptedException() + newProject + } + val requirements = UITools.run( + e.project, "Specifying Project", true + ) { + val projectStatements = api.getProjectStatements(config.description, project) + if (it.isCanceled) throw InterruptedException() + projectStatements + } + val projectDesign = UITools.run( + e.project, "Designing Project", true + ) { + val buildProjectDesign = api.buildProjectDesign(project, requirements) + if (it.isCanceled) throw InterruptedException() + buildProjectDesign + } + val files = UITools.run( + e.project, "Specifying Files", true + ) { + val buildProjectFileSpecifications = + api.buildProjectFileSpecifications(project, requirements, projectDesign) + if (it.isCanceled) throw InterruptedException() + buildProjectFileSpecifications + } + + val components = + UITools.run( + e.project, "Specifying Components", true + ) { + projectDesign.components?.map { it to api.buildComponentFileSpecifications(project, requirements, it) } + ?.toMap() + } + + val documents = + UITools.run( + e.project, "Specifying Documents", true + ) { + projectDesign.documents?.map { + it to api.buildDocumentationFileSpecifications( + project, + requirements, + it + ) + }?.toMap() + } + + val tests = UITools.run( + e.project, "Specifying Tests", true + ) { projectDesign.tests?.map { it to api.buildTestFileSpecifications(project, requirements, it) }?.toMap() } + + val sourceCodeMap = UITools.run( + e.project, "Implementing Files", true + ) { + SoftwareProjectAI.parallelImplementWithAlternates( + api, + project, + components ?: emptyMap(), + documents ?: emptyMap(), + tests ?: emptyMap(), + config.drafts, + AppSettingsState.instance.apiThreads + ) { progress -> + if (it.isCanceled) throw InterruptedException() + it.fraction = progress + } + } + UITools.run(e.project, "Writing Files", false) { + val outputDir = File(selectedFolder.canonicalPath!!) + sourceCodeMap.forEach { (file, sourceCode) -> + val relative = file.file + ?.trimEnd('/') + ?.trimStart('/', '.') ?: "" + if (File(relative).isRooted) { + log.warn("Invalid path: $relative") + } else { + val outFile = outputDir.resolve(relative) + outFile.parentFile.mkdirs() + val best = sourceCode.maxByOrNull { it.code?.length ?: 0 }!! + outFile.writeText(best?.code ?: "") + log.debug("Wrote ${outFile.canonicalPath} (Resolved from $relative)") + if (config.saveAlternates) + for ((index, alternate) in sourceCode.filter { it != best }.withIndex()) { + val outFileAlternate = + outputDir.resolve( + relative + ".${index + 1}" + ) + outFileAlternate.parentFile.mkdirs() + outFileAlternate.writeText(alternate?.code ?: "") + log.debug("Wrote ${outFileAlternate.canonicalPath} (Resolved from $relative)") + } + } + } + selectedFolder.refresh(false, true) + } + }.start() + private fun isEnabled(e: AnActionEvent): Boolean { if (UITools.isSanctioned()) return false if (!AppSettingsState.instance.devActions) return false diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/ChatAppendAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/ChatAppendAction.kt new file mode 100644 index 00000000..fc12a418 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/ChatAppendAction.kt @@ -0,0 +1,64 @@ +package com.github.simiacryptus.aicoder.actions.generic + +import com.github.simiacryptus.aicoder.config.AppSettingsState +import com.github.simiacryptus.aicoder.util.UITools +import com.github.simiacryptus.aicoder.util.UITools.hasSelection +import com.github.simiacryptus.aicoder.util.UITools.insertString +import com.github.simiacryptus.aicoder.util.UITools.redoableRequest +import com.github.simiacryptus.openai.ChatMessage +import com.intellij.openapi.actionSystem.AnAction +import com.intellij.openapi.actionSystem.AnActionEvent +import com.intellij.openapi.actionSystem.CommonDataKeys +import java.util.* + +/** + * The GenericAppend IntelliJ action allows users to quickly append a prompt to the end of a selected text. + * To use, select some text and then select the GenericAppend action from the editor context menu. + * The action will insert the completion at the end of the selected text. + */ +class ChatAppendAction : AnAction() { + override fun update(e: AnActionEvent) { + e.presentation.isEnabledAndVisible = isEnabled(e) + super.update(e) + } + + override fun actionPerformed(event: AnActionEvent) { + val caret = event.getData(CommonDataKeys.CARET) + val before: CharSequence? = Objects.requireNonNull(caret)!!.selectedText + val settings = AppSettingsState.instance + val request = settings.createChatRequest() //.appendPrompt(before ?: "") + request.messages = arrayOf( + ChatMessage( + ChatMessage.Role.system, + "Append text to the end of the user's prompt" + ), + ChatMessage( + ChatMessage.Role.user, + before.toString() + ) + ) + val document = event.getRequiredData(CommonDataKeys.EDITOR).document + val selectionEnd = caret!!.selectionEnd + redoableRequest( + request, "", event + ) { newText: CharSequence? -> + insertString( + document, selectionEnd, + newText!! + ) + } + } + + companion object { + @Suppress("unused") + private fun isEnabled(e: AnActionEvent): Boolean { + if (UITools.isSanctioned()) return false + return hasSelection(e) + } + } +} + + + + + diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/InsertAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/InsertAction.kt index a96f73de..1d942618 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/InsertAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/InsertAction.kt @@ -1,7 +1,7 @@ package com.github.simiacryptus.aicoder.actions.generic import com.github.simiacryptus.aicoder.config.AppSettingsState -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools import com.intellij.openapi.actionSystem.AnAction import com.intellij.openapi.actionSystem.AnActionEvent diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/ReplaceOptionsAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/ReplaceOptionsAction.kt index 360cbad8..df9d02cd 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/ReplaceOptionsAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/generic/ReplaceOptionsAction.kt @@ -1,7 +1,7 @@ package com.github.simiacryptus.aicoder.actions.generic import com.github.simiacryptus.aicoder.config.AppSettingsState -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools import com.github.simiacryptus.aicoder.util.UITools.showRadioButtonDialog import com.intellij.openapi.actionSystem.AnAction diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownListAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownListAction.kt index ca3a18a4..914fb266 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownListAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownListAction.kt @@ -2,7 +2,7 @@ package com.github.simiacryptus.aicoder.actions.markdown import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.util.ComputerLanguage -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools import com.github.simiacryptus.aicoder.util.UITools.getIndent import com.github.simiacryptus.aicoder.util.UITools.getInstruction diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableColAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableColAction.kt index de5baef7..a52c10d4 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableColAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableColAction.kt @@ -1,11 +1,11 @@ package com.github.simiacryptus.aicoder.actions.markdown import com.github.simiacryptus.aicoder.config.AppSettingsState -import com.github.simiacryptus.aicoder.openai.core.CompletionRequest import com.github.simiacryptus.aicoder.util.ComputerLanguage -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools import com.github.simiacryptus.aicoder.util.psi.PsiUtil +import com.github.simiacryptus.openai.CompletionRequest import com.intellij.openapi.actionSystem.AnAction import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableColsAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableColsAction.kt index 8cbc0fe9..bad0ef2a 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableColsAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableColsAction.kt @@ -2,7 +2,7 @@ package com.github.simiacryptus.aicoder.actions.markdown import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.util.ComputerLanguage -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools import com.github.simiacryptus.aicoder.util.psi.PsiUtil import com.intellij.openapi.actionSystem.AnAction diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableRowsAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableRowsAction.kt index afe575ee..cdfc9d10 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableRowsAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/MarkdownNewTableRowsAction.kt @@ -2,7 +2,7 @@ package com.github.simiacryptus.aicoder.actions.markdown import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.util.ComputerLanguage -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools import com.github.simiacryptus.aicoder.util.psi.PsiUtil import com.intellij.openapi.actionSystem.AnAction diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/WikiLinksAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/WikiLinksAction.kt index 696b0cc1..985e1bbd 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/WikiLinksAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/markdown/WikiLinksAction.kt @@ -3,7 +3,7 @@ package com.github.simiacryptus.aicoder.actions.markdown import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.openai.ui.OpenAI_API import com.github.simiacryptus.aicoder.util.ComputerLanguage -import com.github.simiacryptus.aicoder.util.StringTools.replaceAllNonOverlapping +import com.github.simiacryptus.util.StringTools.replaceAllNonOverlapping import com.github.simiacryptus.aicoder.util.UITools import com.github.simiacryptus.aicoder.util.UITools.getInstruction import com.github.simiacryptus.aicoder.util.UITools.redoableRequest diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsComponent.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsComponent.kt index 066f42cf..1bcbe417 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsComponent.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsComponent.kt @@ -86,6 +86,9 @@ class AppSettingsComponent { @Name("Edit Model") val model_edit = OpenAI_API.modelSelector + @Name("Chat Model") + val model_chat = OpenAI_API.modelSelector + @Name("API Threads") val apiThreads = JBTextField() diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsConfigurable.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsConfigurable.kt index ec55d73d..2d01dd4e 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsConfigurable.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsConfigurable.kt @@ -2,7 +2,6 @@ package com.github.simiacryptus.aicoder.config import com.github.simiacryptus.aicoder.util.UITools import com.intellij.openapi.options.Configurable -import com.intellij.util.ui.FormBuilder import org.jetbrains.annotations.Nls import java.util.* import javax.swing.JComponent diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsState.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsState.kt index 3f068cdd..dc7e15b2 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsState.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsState.kt @@ -1,9 +1,10 @@ package com.github.simiacryptus.aicoder.config -import com.github.simiacryptus.aicoder.openai.core.CompletionRequest -import com.github.simiacryptus.aicoder.openai.core.EditRequest import com.github.simiacryptus.aicoder.openai.translate.TranslationRequest import com.github.simiacryptus.aicoder.openai.translate.TranslationRequestTemplate +import com.github.simiacryptus.openai.ChatRequest +import com.github.simiacryptus.openai.CompletionRequest +import com.github.simiacryptus.openai.EditRequest import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.components.PersistentStateComponent import com.intellij.openapi.components.State @@ -31,6 +32,7 @@ class AppSettingsState : PersistentStateComponent { var apiKey = "" var model_completion = "text-davinci-003" var model_edit = "text-davinci-edit-001" + var model_chat = "gpt-3.5-turbo-0301" var maxTokens = 1000 var temperature = 0.1 var style = "" @@ -59,6 +61,10 @@ class AppSettingsState : PersistentStateComponent { return EditRequest(this) } + fun createChatRequest(): ChatRequest { + return ChatRequest(this) + } + override fun getState(): AppSettingsState { return this } @@ -79,6 +85,7 @@ class AppSettingsState : PersistentStateComponent { if (apiKey != that.apiKey) return false if (model_completion != that.model_completion) return false if (model_edit != that.model_edit) return false + if (model_chat != that.model_chat) return false if (translationRequestTemplate != that.translationRequestTemplate) return false if (apiLogLevel != that.apiLogLevel) return false if (devActions != that.devActions) return false @@ -91,6 +98,7 @@ class AppSettingsState : PersistentStateComponent { apiKey, model_completion, model_edit, + model_chat, maxTokens, temperature, translationRequestTemplate, diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPI.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPI.kt index 6bb3954d..4983183e 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPI.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPI.kt @@ -1,16 +1,10 @@ package com.github.simiacryptus.aicoder.openai.async -import com.fasterxml.jackson.annotation.JsonInclude -import com.fasterxml.jackson.databind.MapperFeature -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.databind.SerializationFeature import com.github.simiacryptus.aicoder.config.AppSettingsState -import com.github.simiacryptus.aicoder.openai.core.CompletionRequest -import com.github.simiacryptus.aicoder.openai.core.CompletionResponse -import com.github.simiacryptus.aicoder.openai.core.CoreAPI -import com.github.simiacryptus.aicoder.openai.core.EditRequest -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.github.simiacryptus.aicoder.util.UITools +import com.github.simiacryptus.aicoder.util.UITools.run +import com.github.simiacryptus.openai.* import com.google.common.util.concurrent.* import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.progress.ProgressIndicator @@ -28,7 +22,7 @@ import java.util.function.Consumer import java.util.stream.Collectors open class AsyncAPI( - val coreAPI: CoreAPI, + val openAIClient: OpenAIClient, private val suppressProgress: Boolean = false ) { @@ -49,33 +43,7 @@ open class AsyncAPI( ) { override fun compute(indicator: ProgressIndicator): CompletionResponse { try { - if (editRequest.input == null) { - log( - settings.apiLogLevel, String.format( - "Text Edit Request\nInstruction:\n\t%s\n", - editRequest.instruction.replace("\n", "\n\t") - ) - ) - } else { - log( - settings.apiLogLevel, String.format( - "Text Edit Request\nInstruction:\n\t%s\nInput:\n\t%s\n", - editRequest.instruction.replace("\n", "\n\t"), - editRequest.input!!.replace("\n", "\n\t") - ) - ) - } - val request: String = - StringTools.restrictCharacterSet( - mapper.writeValueAsString(editRequest), - allowedCharset - ) - val result = coreAPI.post(settings.apiBase + "/edits", request) - val completionResponse = coreAPI.processCompletionResponse(result) - coreAPI.logComplete( - completionResponse.firstChoice.orElse("").toString().trim { it <= ' ' } - ) - return completionResponse + return openAIClient.edit(editRequest) } catch (e: IOException) { throw RuntimeException(e) } catch (e: InterruptedException) { @@ -111,7 +79,7 @@ open class AsyncAPI( ) ) { _: Any? -> run( - object : Task.WithResult( + task = object : Task.WithResult( project, "Text Completion", canBeCancelled @@ -129,7 +97,7 @@ open class AsyncAPI( ) threadRef.getAndSet(Thread.currentThread()) try { - return coreAPI.complete(completionRequest, model) + return openAIClient.complete(completionRequest, model) } catch (e: IOException) { log.error(e) throw RuntimeException(e) @@ -163,10 +131,10 @@ open class AsyncAPI( private fun moderateAsync(project: Project?, text: String): ListenableFuture<*> { return run( - object : Task.WithResult, Exception?>(project, "Moderation", false) { + task = object : Task.WithResult, Exception?>(project, "Moderation", false) { override fun compute(indicator: ProgressIndicator): ListenableFuture<*> { return pool.submit { - coreAPI.moderate(text) + openAIClient.moderate(text) } } }, @@ -174,47 +142,37 @@ open class AsyncAPI( ) } - fun run(task: Task.WithResult, retries: Int): T { - return try { - if (!suppressProgress) { - ProgressManager.getInstance().run(task) - } else { - task.run(AbstractProgressIndicatorBase()) - task.result - } - } catch (e: RuntimeException) { - if (isInterruptedException(e)) throw e - if (retries > 0) { - log.warn("Retrying request", e) - run(task, retries - 1) - } else { - throw e - } - } catch (e: InterruptedException) { - throw RuntimeException(e) - } catch (e: Exception) { - if (isInterruptedException(e)) throw RuntimeException(e) - if (retries > 0) { - log.warn("Retrying request", e) - try { - Thread.sleep(15000) - } catch (ex: InterruptedException) { - Thread.currentThread().interrupt() - } - run(task, retries - 1) - } else { - throw RuntimeException(e) - } + fun chat(project: Project?, newRequest: ChatRequest, settings: AppSettingsState?): ListenableFuture { + return map( + moderateAsync( + project, + StringTools.restrictCharacterSet(newRequest.messages.map { "${it.role?.name ?: "?"}: ${it.content}" }.joinToString { "\n" }, allowedCharset) + ) + ) { _: Any? -> + run( + task = object : Task.WithResult( + project, + "Chat", + true + ) { + override fun compute(indicator: ProgressIndicator): ChatResponse { + try { + newRequest.max_tokens = settings!!.maxTokens + newRequest.temperature = settings.temperature + newRequest.model = settings.model_chat + return openAIClient.chat(newRequest) + } catch (e: IOException) { + throw RuntimeException(e) + } catch (e: InterruptedException) { + throw RuntimeException(e) + } + } + }, + 3 + ) } } - private fun isInterruptedException(e: Throwable?): Boolean { - if (e is InterruptedException) return true - return if (e!!.cause != null && e.cause !== e) isInterruptedException( - e.cause - ) else false - } - companion object { private val apiThreads = AppSettingsState.instance.apiThreads val log = Logger.getInstance(AsyncAPI::class.java) @@ -222,10 +180,10 @@ open class AsyncAPI( fun log(level: LogLevel, msg: String) { val message = msg.trim { it <= ' ' }.replace("\n", "\n\t") when (level) { - LogLevel.Error -> CoreAPI.log.error(message) - LogLevel.Warn -> CoreAPI.log.warn(message) - LogLevel.Info -> CoreAPI.log.info(message) - else -> CoreAPI.log.debug(message) + LogLevel.Error -> OpenAIClient.log.error(message) + LogLevel.Warn -> OpenAIClient.log.warn(message) + LogLevel.Info -> OpenAIClient.log.info(message) + else -> OpenAIClient.log.debug(message) } } @@ -234,20 +192,10 @@ open class AsyncAPI( o: com.google.common.base.Function ): ListenableFuture = Futures.transform(moderateAsync, o, pool) - val mapper: ObjectMapper - get() { - val mapper = ObjectMapper() - mapper - .enable(SerializationFeature.INDENT_OUTPUT) - .enable(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS) - .enable(MapperFeature.USE_STD_BEAN_NAMING) - .setSerializationInclusion(JsonInclude.Include.NON_NULL) - .activateDefaultTyping(mapper.polymorphicTypeValidator) - return mapper - } - val threadFactory: ThreadFactory = ThreadFactoryBuilder().setNameFormat("API Thread %d").build() + + val threadFactory: ThreadFactory = ThreadFactoryBuilder().setNameFormat("API Thread %d").build() val pool: ListeningExecutorService = MoreExecutors.listeningDecorator( ThreadPoolExecutor( apiThreads, @@ -272,6 +220,7 @@ open class AsyncAPI( } }, pool) } + } val allowedCharset: Charset = Charset.forName("ASCII") diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPIImpl.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPIImpl.kt index c9380f22..1822d51d 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPIImpl.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPIImpl.kt @@ -2,13 +2,12 @@ package com.github.simiacryptus.aicoder.openai.async import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.openai.ui.OpenAI_API -import com.github.simiacryptus.aicoder.openai.core.CoreAPI +import com.github.simiacryptus.openai.OpenAIClient import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.progress.ProgressIndicator import com.jetbrains.rd.util.AtomicReference -import java.io.IOException -class AsyncAPIImpl(core: CoreAPI, appSettingsState: AppSettingsState) : AsyncAPI( +class AsyncAPIImpl(core: OpenAIClient, appSettingsState: AppSettingsState) : AsyncAPI( core, appSettingsState.suppressProgress ) { @@ -18,15 +17,12 @@ class AsyncAPIImpl(core: CoreAPI, appSettingsState: AppSettingsState) : AsyncAPI val thread = threadRef.get() if (null != thread) { thread.interrupt() - try { - coreAPI.clients[thread]!!.close() - } catch (e: IOException) { - log.warn("Error closing client: " + e.message) - } + openAIClient.closeClient(thread) } } } + companion object { @JvmStatic private val log = Logger.getInstance(OpenAI_API::class.java) diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatRequest.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatRequest.kt deleted file mode 100644 index f15f56fa..00000000 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatRequest.kt +++ /dev/null @@ -1,9 +0,0 @@ -package com.github.simiacryptus.aicoder.openai.core - -class ChatRequest @Suppress("unused") constructor() { - var messages = arrayOf() - var model: String? = null - var temperature = 0.0 - var max_tokens = 0 - var stop: Array? = null -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CoreAPI.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CoreAPI.kt deleted file mode 100644 index 04e1a1f4..00000000 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CoreAPI.kt +++ /dev/null @@ -1,425 +0,0 @@ -package com.github.simiacryptus.aicoder.openai.core - -import com.fasterxml.jackson.annotation.JsonInclude -import com.fasterxml.jackson.core.JsonProcessingException -import com.fasterxml.jackson.databind.MapperFeature -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.databind.SerializationFeature -import com.fasterxml.jackson.databind.node.ObjectNode -import com.github.simiacryptus.aicoder.openai.async.AsyncAPI -import com.github.simiacryptus.aicoder.openai.core.* -import com.github.simiacryptus.aicoder.openai.ui.OpenAI_API -import com.github.simiacryptus.aicoder.util.StringTools -import com.github.simiacryptus.aicoder.util.UITools -import com.google.gson.Gson -import com.google.gson.JsonObject -import com.intellij.openapi.diagnostic.Logger -import com.jetbrains.rd.util.LogLevel -import org.apache.http.HttpResponse -import org.apache.http.client.methods.HttpGet -import org.apache.http.client.methods.HttpPost -import org.apache.http.client.methods.HttpRequestBase -import org.apache.http.entity.ContentType -import org.apache.http.entity.StringEntity -import org.apache.http.entity.mime.HttpMultipartMode -import org.apache.http.entity.mime.MultipartEntityBuilder -import org.apache.http.impl.client.CloseableHttpClient -import org.apache.http.impl.client.HttpClientBuilder -import org.apache.http.util.EntityUtils -import java.awt.image.BufferedImage -import java.io.IOException -import java.net.URL -import java.nio.charset.Charset -import java.util.* -import java.util.Map -import java.util.concurrent.ConcurrentHashMap -import java.util.regex.Pattern -import javax.imageio.ImageIO -import kotlin.collections.MutableMap -import kotlin.collections.first -import kotlin.collections.joinToString -import kotlin.collections.map -import kotlin.collections.set - -open class CoreAPI( - val apiBase: String, - var key: String, - val logLevel: LogLevel -) { - fun getEngines(): Array { - val engines = mapper.readValue( - get(OpenAI_API.settingsState!!.apiBase + "/engines"), - ObjectNode::class.java - ) - val data = engines["data"] - val items = - arrayOfNulls(data.size()) - for (i in 0 until data.size()) { - items[i] = data[i]["id"].asText() - } - Arrays.sort(items) - return items - } - - val clients: MutableMap = ConcurrentHashMap() - - @Throws(IOException::class, InterruptedException::class) - fun post(url: String, body: String): String { - return post(url, body, 3) - } - - - fun logComplete(completionResult: CharSequence) { - log( - logLevel, String.format( - "Text Completion Completion:\n\t%s", - completionResult.toString().replace("\n", "\n\t") - ) - ) - } - - fun logStart(completionRequest: CompletionRequest) { - if (completionRequest.suffix == null) { - log( - logLevel, String.format( - "Text Completion Request\nPrefix:\n\t%s\n", - completionRequest.prompt.replace("\n", "\n\t") - ) - ) - } else { - log( - logLevel, String.format( - "Text Completion Request\nPrefix:\n\t%s\nSuffix:\n\t%s\n", - completionRequest.prompt.replace("\n", "\n\t"), - completionRequest.suffix!!.replace("\n", "\n\t") - ) - ) - } - } - - @Throws(IOException::class, InterruptedException::class) - fun post(url: String, json: String, retries: Int): String { - return post(jsonRequest(url, json), retries) - } - - fun post( - request: HttpPost, - retries: Int - ): String { - try { - val client = HttpClientBuilder.create() - try { - client.build().use { httpClient -> - clients[Thread.currentThread()] = httpClient - val response: HttpResponse = httpClient.execute(request) - val entity = response.entity - return EntityUtils.toString(entity) - } - } finally { - clients.remove(Thread.currentThread()) - } - } catch (e: IOException) { - if (retries > 0) { - log.warn("Error posting request, retrying in 15 seconds", e) - Thread.sleep(15000) - return post(request, retries - 1) - } - throw e - } - } - - fun jsonRequest(url: String, json: String): HttpPost { - val request = HttpPost(url) - request.addHeader("Content-Type", "application/json") - request.addHeader("Accept", "application/json") - authorize(request) - request.entity = StringEntity(json) - return request - } - - @Throws(IOException::class) - fun authorize(request: HttpRequestBase) { - var apiKey: CharSequence = key - if (apiKey.length == 0) { - synchronized(OpenAI_API.javaClass) { - apiKey = key - if (apiKey.length == 0) { - apiKey = UITools.queryAPIKey()!! - key = apiKey.toString() - } - } - } - request.addHeader("Authorization", "Bearer $apiKey") - } - - /** - * Gets the response from the given URL. - * - * @param url The URL to GET the response from. - * @return The response from the given URL. - * @throws IOException If an I/O error occurs. - */ - @Throws(IOException::class) - operator fun get(url: String?): String { - val client = HttpClientBuilder.create() - val request = HttpGet(url) - request.addHeader("Content-Type", "application/json") - request.addHeader("Accept", "application/json") - authorize(request) - client.build().use { httpClient -> - val response: HttpResponse = httpClient.execute(request) - val entity = response.entity - return EntityUtils.toString(entity) - } - } - - fun text_to_speech(wavAudio: ByteArray, prompt: String = ""): String { - val url = apiBase + "/audio/transcriptions" - val request = HttpPost(url) - request.addHeader("Accept", "application/json") - authorize(request) - val entity = MultipartEntityBuilder.create() - entity.setMode(HttpMultipartMode.RFC6532) - entity.addBinaryBody("file", wavAudio, ContentType.create("audio/x-wav"), "audio.wav") - entity.addTextBody("model", "whisper-1") - if (!prompt.isEmpty()) entity.addTextBody("prompt", prompt) - request.entity = entity.build() - val response = post(request, 3) - val jsonObject = Gson().fromJson(response, JsonObject::class.java) - if (jsonObject.has("error")) { - val errorObject = jsonObject.getAsJsonObject("error") - throw RuntimeException(IOException(errorObject["message"].asString)) - } - return jsonObject.get("text").asString!! - } - - fun text_to_image(prompt: String = "", resolution: Int = 1024, count: Int = 1): List { - val url = apiBase + "/images/generations" - val request = HttpPost(url) - request.addHeader("Accept", "application/json") - request.addHeader("Content-Type", "application/json") - authorize(request) - val jsonObject = JsonObject() - jsonObject.addProperty("prompt", prompt) - jsonObject.addProperty("n", count) - jsonObject.addProperty("size", "${resolution}x$resolution") - request.entity = StringEntity(jsonObject.toString()) - val response = post(request, 3) - val jsonObject2 = Gson().fromJson(response, JsonObject::class.java) - if (jsonObject2.has("error")) { - val errorObject = jsonObject2.getAsJsonObject("error") - throw RuntimeException(IOException(errorObject["message"].asString)) - } - val dataArray = jsonObject2.getAsJsonArray("data") - val images = ArrayList() - for (i in 0 until dataArray.size()) { - images.add(ImageIO.read(URL(dataArray[i].asJsonObject.get("url").asString))) - } - return images - } - - @Throws(IOException::class) - fun processCompletionResponse(result: String): CompletionResponse { - checkError(result) - val response = mapper.readValue( - result, - CompletionResponse::class.java - ) - if (response.usage != null) { - incrementTokens(response.usage!!.total_tokens) - } - return response - } - - @Throws(IOException::class) - fun processChatResponse(result: String): ChatResponse { - checkError(result) - val response = mapper.readValue( - result, - ChatResponse::class.java - ) - if (response.usage != null) { - incrementTokens(response.usage!!.total_tokens) - } - return response - } - - private val maxTokenErrorMessage = Pattern.compile( - """This model's maximum context length is (\d+) tokens. However, you requested (\d+) tokens \((\d+) in the messages, (\d+) in the completion\). Please reduce the length of the messages or completion.""" - ) - - private fun checkError(result: String) { - try { - val jsonObject = Gson().fromJson( - result, - JsonObject::class.java - ) - if (jsonObject.has("error")) { - val errorObject = jsonObject.getAsJsonObject("error") - val errorMessage = errorObject["message"].asString - val matcher = maxTokenErrorMessage.matcher(errorMessage) - if (matcher.find()) { - val modelMax = matcher.group(1).toInt() - val request = matcher.group(2).toInt() - val messages = matcher.group(3).toInt() - val completion = matcher.group(4).toInt() - throw ModelMaxException(modelMax, request, messages, completion) - } - throw IOException(errorMessage) - } - } catch (e: com.google.gson.JsonSyntaxException) { - throw IOException("Invalid JSON response: $result") - } - } - - open fun incrementTokens(totalTokens: Int) {} - - companion object { - val log = Logger.getInstance(CoreAPI::class.java) - - fun log(level: LogLevel, msg: String) { - val message = msg.trim { it <= ' ' }.replace("\n", "\n\t") - when (level) { - LogLevel.Error -> log.error(message) - LogLevel.Warn -> log.warn(message) - LogLevel.Info -> log.info(message) - else -> log.debug(message) - } - } - } - - val allowedCharset = Charset.forName("ASCII") - val mapper: ObjectMapper - get() { - val mapper = ObjectMapper() - mapper - .enable(SerializationFeature.INDENT_OUTPUT) - .enable(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS) - .enable(MapperFeature.USE_STD_BEAN_NAMING) - .setSerializationInclusion(JsonInclude.Include.NON_NULL) - .activateDefaultTyping(mapper.polymorphicTypeValidator) - return mapper - } - - fun complete( - completionRequest: CompletionRequest, - model: String - ): CompletionResponse { - logStart(completionRequest) - val completionResponse = try { - val request: String = - StringTools.restrictCharacterSet( - AsyncAPI.mapper.writeValueAsString(completionRequest), - allowedCharset - ) - val result = - post(apiBase + "/engines/" + model + "/completions", request) - processCompletionResponse(result) - } catch (e: ModelMaxException) { - completionRequest.max_tokens = (e.modelMax - e.messages) - 1 - val request: String = - StringTools.restrictCharacterSet( - AsyncAPI.mapper.writeValueAsString(completionRequest), - allowedCharset - ) - val result = - post(apiBase + "/engines/" + model + "/completions", request) - processCompletionResponse(result) - } - val completionResult = StringTools.stripPrefix( - completionResponse.firstChoice.orElse("").toString().trim { it <= ' ' }, - completionRequest.prompt.trim { it <= ' ' }) - logComplete(completionResult) - return completionResponse - } - - fun chat( - completionRequest: ChatRequest - ): ChatResponse { - logStart(completionRequest) - val url = apiBase + "/chat/completions" - val completionResponse = try { - val result = post( - url, StringTools.restrictCharacterSet( - AsyncAPI.mapper.writeValueAsString(completionRequest), - allowedCharset - ) - ) - processChatResponse(result) - } catch (e: ModelMaxException) { - completionRequest.max_tokens = (e.modelMax - e.messages) - 1 - val result = post( - url, StringTools.restrictCharacterSet( - AsyncAPI.mapper.writeValueAsString(completionRequest), - allowedCharset - ) - ) - processChatResponse(result) - } - val completionResult = completionResponse.choices.first().message!!.content!!.trim { it <= ' ' } - logComplete(completionResult) - return completionResponse - } - - private fun logStart(completionRequest: ChatRequest) { - log( - logLevel, String.format( - "Text Completion Request\nPrefix:\n\t%s\n", - completionRequest.messages.map { it.content }.joinToString { "\n" }.replace("\n", "\n\t") - ) - ) - } - - fun moderate(text: String) { - val body: String = try { - AsyncAPI.mapper.writeValueAsString( - Map.of( - "input", - StringTools.restrictCharacterSet(text, allowedCharset) - ) - ) - } catch (e: JsonProcessingException) { - throw RuntimeException(e) - } - val result: String = try { - this.post(apiBase + "/moderations", body) - } catch (e: IOException) { - throw RuntimeException(e) - } catch (e: InterruptedException) { - throw RuntimeException(e) - } - val jsonObject = - Gson().fromJson( - result, - JsonObject::class.java - ) - if (jsonObject.has("error")) { - val errorObject = jsonObject.getAsJsonObject("error") - throw RuntimeException(IOException(errorObject["message"].asString)) - } - val moderationResult = - jsonObject.getAsJsonArray("results")[0].asJsonObject - AsyncAPI.log( - LogLevel.Debug, - String.format( - "Moderation Request\nText:\n%s\n\nResult:\n%s", - text.replace("\n", "\n\t"), - result - ) - ) - if (moderationResult["flagged"].asBoolean) { - val categoriesObj = - moderationResult["categories"].asJsonObject - throw RuntimeException( - ModerationException( - "Moderation flagged this request due to " + categoriesObj.keySet() - .stream().filter { c: String? -> - categoriesObj[c].asBoolean - }.reduce { a: String, b: String -> "$a, $b" } - .orElse("???") - ) - ) - } - } - -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/ChatProxy.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/ChatProxy.kt deleted file mode 100644 index e50ae166..00000000 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/ChatProxy.kt +++ /dev/null @@ -1,87 +0,0 @@ -package com.github.simiacryptus.aicoder.openai.proxy - -import com.github.simiacryptus.aicoder.openai.core.ChatMessage -import com.github.simiacryptus.aicoder.openai.core.ChatRequest -import com.github.simiacryptus.aicoder.openai.core.CoreAPI -import com.jetbrains.rd.util.LogLevel -import java.time.LocalDateTime -import java.time.format.DateTimeFormatter - -class ChatProxy( - apiKey: String, - private val model: String = "gpt-3.5-turbo", - private val maxTokens: Int = 3500, - private val temperature: Double = 0.7, - private val verbose: Boolean = false, - base: String = "https://api.openai.com/v1", - apiLog: String = "api.${ - LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd_HH-mm-ss")) - }.log.json" -) : GPTProxyBase(apiLog) { - val api: CoreAPI - - init { - api = CoreAPI(base, apiKey, LogLevel.Debug) - } - - override fun complete(prompt: ProxyRequest, vararg examples: ProxyRecord): String { - if (verbose) println(prompt) - val request = ChatRequest() - request.messages = ( - listOf( - ChatMessage( - ChatMessage.Role.system, """ - |You are a JSON-RPC Service serving the following method: - |${prompt.methodName} - |Requests contain the following arguments: - |${prompt.argList.keys.joinToString("\n ")} - |Responses are of type: - |${prompt.responseType} - |Responses are expected to be a single JSON object - |All input arguments are optional - |""".trimMargin().trim() - ) - ) + - examples.flatMap { - listOf( - ChatMessage(ChatMessage.Role.user, argsToString(it.argList)), - ChatMessage(ChatMessage.Role.assistant, it.response) - ) - } + - listOf(ChatMessage(ChatMessage.Role.user, argsToString(prompt.argList))) - ).toTypedArray() - request.model = model - request.max_tokens = maxTokens - request.temperature = temperature - val completion = api.chat(request).response.get().toString() - if (verbose) println(completion) - val trimPrefix = trimPrefix(completion) - val trimSuffix = trimSuffix(trimPrefix.first) - return trimSuffix.first - } - - private fun trimPrefix(completion: String): Pair { - val start = completion.indexOf('{') - if (start < 0) { - return completion to "" - } else { - val substring = completion.substring(start) - return substring to completion.substring(0, start) - } - } - - private fun trimSuffix(completion: String): Pair { - val end = completion.lastIndexOf('}') - if (end < 0) { - return completion to "" - } else { - val substring = completion.substring(0, end + 1) - return substring to completion.substring(end + 1) - } - } - - private fun argsToString(argList: Map) = - "{" + argList.entries.joinToString(",\n", transform = { (argName, argValue) -> - """"$argName": $argValue""" - }) + "}" -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/CompletionProxy.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/CompletionProxy.kt deleted file mode 100644 index 897a49ea..00000000 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/CompletionProxy.kt +++ /dev/null @@ -1,46 +0,0 @@ -package com.github.simiacryptus.aicoder.openai.proxy - -import com.github.simiacryptus.aicoder.openai.core.CompletionRequest -import com.github.simiacryptus.aicoder.openai.core.CoreAPI -import com.jetbrains.rd.util.LogLevel - -class CompletionProxy( - apiKey: String, - private val model: String = "text-davinci-003", - private val maxTokens: Int = 1000, - private val temperature: Double = 0.7, - private val verbose: Boolean = false, - base: String = "https://api.openai.com/v1", - apiLog: String -) : GPTProxyBase(apiLog) { - val api: CoreAPI - - init { - api = CoreAPI(base, apiKey, LogLevel.Debug) - } - - override fun complete(prompt: ProxyRequest, vararg examples: ProxyRecord): String { - if(verbose) println(prompt) - val request = CompletionRequest() - val argList = prompt.argList - val methodName = prompt.methodName - val responseType = prompt.responseType - request.prompt = """ - Method: $methodName - Response Type: - ${responseType.replace("\n", "\n ")} - Request: - { - ${argList.entries.joinToString(",\n", transform = { (argName, argValue) -> - """"$argName": $argValue""" - }).replace("\n", "\n ")} - } - Response: - {""".trimIndent() - request.max_tokens = maxTokens - request.temperature = temperature - val completion = api.complete(request, model).firstChoice.get().toString() - if(verbose) println(completion) - return "{$completion" - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/GPTProxyBase.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/GPTProxyBase.kt deleted file mode 100644 index f595e6c5..00000000 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/GPTProxyBase.kt +++ /dev/null @@ -1,210 +0,0 @@ -package com.github.simiacryptus.aicoder.openai.proxy - -import com.fasterxml.jackson.core.JsonParseException -import com.fasterxml.jackson.core.json.JsonReadFeature -import com.fasterxml.jackson.databind.ObjectMapper -import com.github.simiacryptus.aicoder.util.StringTools.indentJoin -import java.io.BufferedWriter -import java.io.File -import java.io.FileWriter -import java.lang.reflect.* - -abstract class GPTProxyBase( - apiLogFile: String -) { - abstract fun complete(prompt: ProxyRequest, vararg examples: ProxyRecord): String - - fun create(clazz: Class): T { - return Proxy.newProxyInstance(clazz.classLoader, arrayOf(clazz)) { proxy, method, args -> - if (method.name == "toString") return@newProxyInstance clazz.simpleName - val prompt = ProxyRequest( - method.name, - typeToName(method.returnType), - (args ?: arrayOf()).zip(method.parameters) - .filter> { (arg: Any?, _) -> arg != null } - .map, Pair> { (arg, param) -> - param.name to toJson(arg!!) - }.toMap()) - for (retry in 0 until 3) { - val result = complete(prompt, *examples[method.name]?.toTypedArray() ?: arrayOf()) - writeToJsonLog(ProxyRecord(prompt.methodName, prompt.argList, result)) - try { - return@newProxyInstance fromJson(result, method.returnType) - } catch (e: JsonParseException) { - println("Failed to parse response: $result") - println("Retrying...") - } - } - } as T - } - - private val apiLog = openApiLog(apiLogFile) - private val examples = HashMap>() - private fun loadExamples(file: File = File("api.examples.json")) : List { - if (!file.exists()) return listOf() - val json = file.readText() - return fromJson(json, object : ArrayList() {}.javaClass) - } - fun addExamples(file: File) { - examples.putAll(loadExamples(file).groupBy { it.methodName }) - } - - private fun openApiLog(file: String): BufferedWriter { - val writer = BufferedWriter(FileWriter(File(file))) - writer.write("[") - writer.newLine() - writer.flush() - return writer - } - private fun writeToJsonLog(record: ProxyRecord) { - apiLog.write(toJson(record)) - apiLog.write(",") - apiLog.newLine() - apiLog.flush() - } - - open fun toJson(data: Any): String { - return objectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(data) - } - - open fun fromJson(data: String, type: Class): T { - if (data.isNotEmpty()) try { - if (type.isAssignableFrom(String::class.java)) return data as T - return objectMapper().readValue(data, type) - } - catch (e: JsonParseException) { -// log.error("Error parsing JSON", e) - throw e - } - catch (e: Exception) { - log.error("Error parsing JSON", e) - } - return type.getConstructor().newInstance() - } - - open fun objectMapper(): ObjectMapper { - return ObjectMapper() - .enable(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS.mappedFeature()) - .configure(com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - } - - data class ProxyRequest( - val methodName: String = "", - val responseType: String = "", - val argList: Map = mapOf() - ) - - data class ProxyRecord( - val methodName: String = "", - val argList: Map = mapOf(), - val response: String = "" - ) - - companion object { - val log = org.slf4j.LoggerFactory.getLogger(GPTProxyBase::class.java) - - fun typeToName(type: Class<*>?): String { - // Convert a type to API documentation including type name and type structure, recusively expanding child types - if (type == null) { - return "null" - } - if (type.isPrimitive) { - return type.simpleName - } - if (type.isArray) { - return "Array<${typeToName(type.componentType)}>" - } - if (type.isEnum) { - return type.simpleName - } - if (type.isAssignableFrom(List::class.java)) { - if (type.isAssignableFrom(ParameterizedType::class.java)) { - val genericType = (type as ParameterizedType).actualTypeArguments[0] - return "List<${typeToName(genericType as Class<*>)}>" - } else { - return "List" - } - } - if (type.isAssignableFrom(Map::class.java)) { - val keyType = (type as ParameterizedType).actualTypeArguments[0] - val valueType = (type as ParameterizedType).actualTypeArguments[1] - return "Map<${typeToName(keyType as Class<*>)}, ${typeToName(valueType as Class<*>)}>" - } - if (type.getPackage()?.name?.startsWith("java") == true) { - return type.simpleName - } - val typeDescription = typeDescription(type) - return typeDescription.toString() - } - - private fun typeDescription(clazz: Class<*>): TypeDescription { - val apiDocumentation = if (clazz.isArray) { - return TypeDescription("Array<${typeDescription(clazz.componentType)}>") - } else if (clazz.isAssignableFrom(List::class.java)) { - if (clazz.isAssignableFrom(ParameterizedType::class.java)) { - val genericType = (clazz as ParameterizedType).actualTypeArguments[0] - return TypeDescription("List<${typeDescription(genericType as Class<*>)}>") - } else { - return TypeDescription("List") - } - } else if (clazz.isAssignableFrom(Map::class.java)) { - val keyType = (clazz as ParameterizedType).actualTypeArguments[0] - val valueType = (clazz as ParameterizedType).actualTypeArguments[1] - return TypeDescription("Map<${typeDescription(keyType as Class<*>)}, ${typeDescription(valueType as Class<*>)}}>") - } else if (clazz == String::class.java) { - return TypeDescription(clazz.simpleName) - } else { - TypeDescription(clazz.simpleName) - } - if (clazz.isPrimitive) return apiDocumentation - if (clazz.isEnum) return apiDocumentation - - for (field in clazz.declaredFields) { - if (field.name.startsWith("\$")) continue - // Get ParameterizedType for field - val type = field.genericType - if (type is ParameterizedType) { - // Get raw type - val rawType = type.rawType as Class<*> - if (rawType.isAssignableFrom(List::class.java)) { - // Get type of list elements - val elementType = type.actualTypeArguments[0] as Class<*> - apiDocumentation.fields.add( - FieldData( - field.name, - TypeDescription("List<${typeDescription(elementType)}>") - ) - ) - continue - } - } - apiDocumentation.fields.add(FieldData(field.name, typeDescription(field.type))) - } - return apiDocumentation - } - - private fun typeDescription(clazz: Type): TypeDescription { - if (clazz is Class<*>) return typeDescription(clazz) - if (clazz is ParameterizedType) { - val rawType = clazz.rawType as Class<*> - if (rawType.isAssignableFrom(List::class.java)) { - // Get type of list elements - val elementType = clazz.actualTypeArguments[0] as Class<*> - return TypeDescription("List<${typeDescription(elementType)}>") - } - } - return TypeDescription(clazz.typeName) - } - - class TypeDescription(val name: String) { - val fields: ArrayList = ArrayList() - override fun toString(): String { - return if (fields.isEmpty()) name else indentJoin(fields) - } - } - - class FieldData(val name: String, val type: TypeDescription) { - override fun toString(): String = """"$name": $type""" - } - } -} diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/SoftwareProjectAI.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/SoftwareProjectAI.kt deleted file mode 100644 index 0d4005e5..00000000 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/proxy/SoftwareProjectAI.kt +++ /dev/null @@ -1,76 +0,0 @@ -package com.github.simiacryptus.aicoder.openai.proxy - -interface SoftwareProjectAI { - fun newProject(description: String): Project - - data class Project( - val name: String = "", - val description: String = "", - val language: String = "", - val libraries: List = listOf(), - val buildTools: List = listOf(), - ) - - fun getProjectStatements(project: Project): ProjectStatements - - data class ProjectStatements( - val assumptions: List = listOf(), - val requirements: List = listOf(), - val risks: List = listOf(), - ) - - fun buildProjectDesign(project: Project, requirements: ProjectStatements): ProjectDesign - - data class ProjectDesign( - val designDetails: List = listOf(), - val tests: List = listOf(), - ) - - fun buildProjectFileSpecifications(project: Project, requirements: ProjectStatements, design: ProjectDesign, recursive: Boolean = true): FileList - - data class FileList( - val files: List = listOf(), - ) - - data class FileSpecification( - val location: FilePath = FilePath(), - val description: String = "", - val requires: List = listOf(), - val publicProperties: List = listOf(), - val publicMethodSignatures: List = listOf(), - ) - - data class FilePath( - val path: String = "", - val name: String = "", - val extension: String = "", - ) { - override fun toString(): String { - return "${path.trimEnd('/')}/$name.$extension" - } - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (javaClass != other?.javaClass) return false - other as FilePath - if (path != other.path) return false - if (name != other.name) return false - if (extension != other.extension) return false - return true - } - override fun hashCode(): Int { - var result = path.hashCode() - result = 31 * result + name.hashCode() - result = 31 * result + extension.hashCode() - return result - } - - } - - fun implement(project: Project, imports: List, specification: FileSpecification): SourceCode - - data class SourceCode( - val language: String = "", - val code: String = "", - ) - -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest.kt index d62541ab..a97c9ab4 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest.kt @@ -1,6 +1,6 @@ package com.github.simiacryptus.aicoder.openai.translate -import com.github.simiacryptus.aicoder.openai.core.CompletionRequest +import com.github.simiacryptus.openai.CompletionRequest import java.util.* diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest_XML.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest_XML.kt index 1de7f486..4db50ec4 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest_XML.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest_XML.kt @@ -1,7 +1,7 @@ package com.github.simiacryptus.aicoder.openai.translate import com.github.simiacryptus.aicoder.config.AppSettingsState -import com.github.simiacryptus.aicoder.openai.core.CompletionRequest +import com.github.simiacryptus.openai.CompletionRequest import java.util.* import java.util.stream.Collectors diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/CompletionRequestWithModel.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/CompletionRequestWithModel.kt index 38ea5907..86c2c033 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/CompletionRequestWithModel.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/CompletionRequestWithModel.kt @@ -1,7 +1,7 @@ package com.github.simiacryptus.aicoder.openai.ui import com.github.simiacryptus.aicoder.config.AppSettingsState -import com.github.simiacryptus.aicoder.openai.core.CompletionRequest +import com.github.simiacryptus.openai.CompletionRequest class CompletionRequestWithModel : CompletionRequest { var model: String diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/InteractiveCompletionRequest.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/InteractiveCompletionRequest.kt index c0d47cbf..2b05e1c9 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/InteractiveCompletionRequest.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/InteractiveCompletionRequest.kt @@ -3,8 +3,8 @@ package com.github.simiacryptus.aicoder.openai.ui import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.config.Name import com.github.simiacryptus.aicoder.openai.async.AsyncAPI -import com.github.simiacryptus.aicoder.openai.core.CompletionRequest import com.github.simiacryptus.aicoder.util.UITools +import com.github.simiacryptus.openai.CompletionRequest import com.google.common.util.concurrent.FutureCallback import com.google.common.util.concurrent.Futures import com.intellij.ui.components.JBScrollPane diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/InteractiveEditRequest.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/InteractiveEditRequest.kt index bf8030c5..9f9db04d 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/InteractiveEditRequest.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/InteractiveEditRequest.kt @@ -2,8 +2,8 @@ package com.github.simiacryptus.aicoder.openai.ui import com.github.simiacryptus.aicoder.config.Name import com.github.simiacryptus.aicoder.openai.async.AsyncAPI -import com.github.simiacryptus.aicoder.openai.core.EditRequest import com.github.simiacryptus.aicoder.util.UITools +import com.github.simiacryptus.openai.EditRequest import com.google.common.util.concurrent.FutureCallback import com.google.common.util.concurrent.Futures import com.intellij.ui.components.JBScrollPane diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/CoreAPIImpl.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAIClientImpl.kt similarity index 77% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/CoreAPIImpl.kt rename to src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAIClientImpl.kt index 88b3881a..30c88a5f 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/CoreAPIImpl.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAIClientImpl.kt @@ -1,11 +1,11 @@ package com.github.simiacryptus.aicoder.openai.ui import com.github.simiacryptus.aicoder.config.AppSettingsState -import com.github.simiacryptus.aicoder.openai.core.CoreAPI +import com.github.simiacryptus.openai.OpenAIClient -class CoreAPIImpl( +class OpenAIClientImpl( private val appSettingsState: AppSettingsState -) : CoreAPI( +) : OpenAIClient( appSettingsState.apiBase, appSettingsState.apiKey, appSettingsState.apiLogLevel diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAI_API.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAI_API.kt index 490234df..631aac5a 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAI_API.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAI_API.kt @@ -1,16 +1,12 @@ package com.github.simiacryptus.aicoder.openai.ui -import com.fasterxml.jackson.databind.node.ObjectNode import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.openai.async.AsyncAPI import com.github.simiacryptus.aicoder.openai.async.AsyncAPI.Companion.map import com.github.simiacryptus.aicoder.openai.async.AsyncAPIImpl -import com.github.simiacryptus.aicoder.openai.core.CompletionRequest -import com.github.simiacryptus.aicoder.openai.core.CompletionResponse -import com.github.simiacryptus.aicoder.openai.core.CoreAPI -import com.github.simiacryptus.aicoder.openai.core.EditRequest import com.github.simiacryptus.aicoder.util.IndentedText -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools +import com.github.simiacryptus.openai.* import com.google.common.util.concurrent.ListenableFuture import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project @@ -29,10 +25,10 @@ object OpenAI_API { private val log = Logger.getInstance(OpenAI_API::class.java) @JvmStatic - val coreAPI: CoreAPI = CoreAPIImpl(settingsState!!) + val openAIClient: OpenAIClient = OpenAIClientImpl(settingsState!!) @JvmStatic - val asyncAPI = AsyncAPIImpl(coreAPI, settingsState!!) + val asyncAPI = AsyncAPIImpl(openAIClient, settingsState!!) private val activeModelUI = WeakHashMap, Any>() @@ -63,16 +59,10 @@ object OpenAI_API { activeModelUI[comboBox] = Any() AsyncAPI.onSuccess( engines - ) { engines: ObjectNode -> - val data = engines["data"] - val items = - arrayOfNulls(data.size()) - for (i in 0 until data.size()) { - items[i] = data[i]["id"].asText() - } - Arrays.sort(items) + ) { engines: Array -> + Arrays.sort(engines) activeModelUI.keys.forEach(Consumer { ui: ComboBox -> - Arrays.stream(items).forEach(ui::addItem) + Arrays.stream(engines).forEach(ui::addItem) }) } return comboBox!! @@ -99,6 +89,14 @@ object OpenAI_API { return map(complete(project, request)) { it.firstChoice.map(filter).orElse("") } } + fun getChat( + project: Project?, + request: ChatRequest, + filter: (CharSequence) -> CharSequence = { it } + ): ListenableFuture { + return map(chat(project, request)) { it.response.map(filter).orElse("") } + } + fun edit(project: Project?, request: EditRequest, indent: CharSequence): ListenableFuture { return edit(project, request, filterStringResult(indent)) } @@ -118,12 +116,9 @@ object OpenAI_API { } return settings } - private val engines: ListenableFuture - get() = AsyncAPI.pool.submit { - coreAPI.mapper.readValue( - coreAPI.get(settingsState!!.apiBase + "/engines"), - ObjectNode::class.java - ) + private val engines: ListenableFuture> + get() = AsyncAPI.pool.submit> { + openAIClient.getEngines() } fun complete( @@ -141,6 +136,17 @@ object OpenAI_API { return edit(project, request, settingsState!!) } + fun chat( + project: Project?, + chatRequest: ChatRequest + ): ListenableFuture { + val settings = settingsState + val withModel = chatRequest.uiIntercept() + //withModel.fixup(settings!!) + val newRequest = ChatRequest(withModel) + return asyncAPI.chat(project, newRequest, settings) + } + private fun edit( project: Project?, editRequest: EditRequest, @@ -169,6 +175,6 @@ object OpenAI_API { } } - fun text_to_speech(wavAudio: ByteArray, prompt: String = ""): String = coreAPI.text_to_speech(wavAudio, prompt) + fun text_to_speech(wavAudio: ByteArray, prompt: String = ""): String = openAIClient.dictate(wavAudio, prompt) } diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/util/BlockComment.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/util/BlockComment.kt index 872d6d55..4b71c32d 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/util/BlockComment.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/util/BlockComment.kt @@ -2,6 +2,7 @@ package com.github.simiacryptus.aicoder.util import com.github.simiacryptus.aicoder.util.TextBlock.Companion.DELIMITER import com.github.simiacryptus.aicoder.util.TextBlock.Companion.TAB_REPLACEMENT +import com.github.simiacryptus.util.StringTools import java.util.* import java.util.stream.Collectors diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/util/IndentedText.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/util/IndentedText.kt index 09971da5..962e02c3 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/util/IndentedText.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/util/IndentedText.kt @@ -1,8 +1,8 @@ package com.github.simiacryptus.aicoder.util -import com.github.simiacryptus.aicoder.util.StringTools.getWhitespacePrefix -import com.github.simiacryptus.aicoder.util.StringTools.getWhitespacePrefix2 -import com.github.simiacryptus.aicoder.util.StringTools.stripPrefix +import com.github.simiacryptus.util.StringTools.getWhitespacePrefix +import com.github.simiacryptus.util.StringTools.getWhitespacePrefix2 +import com.github.simiacryptus.util.StringTools.stripPrefix import java.util.* import java.util.stream.Collectors diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/util/LineComment.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/util/LineComment.kt index ae9346d5..0f78c55f 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/util/LineComment.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/util/LineComment.kt @@ -1,8 +1,8 @@ package com.github.simiacryptus.aicoder.util -import com.github.simiacryptus.aicoder.util.StringTools.getWhitespacePrefix -import com.github.simiacryptus.aicoder.util.StringTools.stripPrefix -import com.github.simiacryptus.aicoder.util.StringTools.trimPrefix +import com.github.simiacryptus.util.StringTools.getWhitespacePrefix +import com.github.simiacryptus.util.StringTools.stripPrefix +import com.github.simiacryptus.util.StringTools.trimPrefix import java.util.* import java.util.function.Function import java.util.stream.Collectors diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/util/StyleUtil.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/util/StyleUtil.kt index 384447a7..4d138f1f 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/util/StyleUtil.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/util/StyleUtil.kt @@ -4,7 +4,7 @@ import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.openai.async.AsyncAPI import com.github.simiacryptus.aicoder.openai.async.AsyncAPI.Companion.map import com.github.simiacryptus.aicoder.openai.ui.OpenAI_API.getCompletion -import com.github.simiacryptus.aicoder.util.StringTools.lineWrapping +import com.github.simiacryptus.util.StringTools.lineWrapping import com.google.common.util.concurrent.ListenableFuture import com.intellij.openapi.diagnostic.Logger import java.util.* diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/util/UITools.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/util/UITools.kt index 4107cefd..a86a5c0a 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/util/UITools.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/util/UITools.kt @@ -6,10 +6,11 @@ import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.config.Name import com.github.simiacryptus.aicoder.openai.* import com.github.simiacryptus.aicoder.openai.async.AsyncAPI -import com.github.simiacryptus.aicoder.openai.core.CompletionRequest -import com.github.simiacryptus.aicoder.openai.core.EditRequest -import com.github.simiacryptus.aicoder.openai.core.ModerationException import com.github.simiacryptus.aicoder.openai.ui.OpenAI_API +import com.github.simiacryptus.openai.ChatRequest +import com.github.simiacryptus.openai.CompletionRequest +import com.github.simiacryptus.openai.EditRequest +import com.github.simiacryptus.openai.ModerationException import com.google.common.util.concurrent.FutureCallback import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture @@ -20,11 +21,12 @@ import com.intellij.openapi.command.WriteCommandAction import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.editor.Caret import com.intellij.openapi.editor.Document -import com.intellij.openapi.fileChooser.FileChooser -import com.intellij.openapi.fileChooser.FileChooserDescriptorFactory import com.intellij.openapi.fileEditor.FileDocumentManager import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.ProgressManager +import com.intellij.openapi.progress.Task +import com.intellij.openapi.progress.util.AbstractProgressIndicatorBase +import com.intellij.openapi.project.Project import com.intellij.openapi.ui.ComboBox import com.intellij.openapi.ui.DialogWrapper import com.intellij.openapi.util.TextRange @@ -69,6 +71,15 @@ object UITools { redoableRequest(request, indent, event, { x: CharSequence -> x }, action) } + fun redoableRequest( + request: ChatRequest, + indent: CharSequence, + event: AnActionEvent, + action: Function + ) { + redoableRequest(request, indent, event, { x: CharSequence -> x }, action) + } + fun startProgress(): ProgressIndicator? { if (1 == 1) return null if (AppSettingsState.instance.suppressProgress) return null @@ -143,6 +154,42 @@ object UITools { }, AsyncAPI.pool) } + fun redoableRequest( + request: ChatRequest, + indent: CharSequence, + event: AnActionEvent, + transformCompletion: Function, + action: Function, + resultFuture: ListenableFuture = OpenAI_API.getChat(event.project!!, request), + progressIndicator: ProgressIndicator? = startProgress() + ) { + Futures.addCallback(resultFuture, object : FutureCallback { + override fun onSuccess(result: CharSequence?) { + progressIndicator?.cancel() + val actionFn = AtomicReference() + WriteCommandAction.runWriteCommandAction(event.project) { + actionFn.set( + action.apply( + transformCompletion.apply( + result.toString() + ) + ) + ) + } + if (null != actionFn.get()) { + val undo = getRetry(request, indent, event, action, actionFn.get()!!, transformCompletion) + val document = event.getRequiredData(CommonDataKeys.EDITOR).document + retry[document] = undo + } + } + + override fun onFailure(t: Throwable) { + progressIndicator?.cancel() + handle(t) + } + }, AsyncAPI.pool) + } + /** * Get a retry Runnable for the given [CompletionRequest]. * @@ -198,6 +245,46 @@ object UITools { } } + fun getRetry( + request: ChatRequest, + indent: CharSequence?, + event: AnActionEvent, + action: Function, + undo: Runnable, + transformCompletion: Function + ): Runnable { + val document = Objects.requireNonNull(event.getData(CommonDataKeys.EDITOR))!!.document + return Runnable { + val progressIndicator = startProgress() + Futures.addCallback( + OpenAI_API.getChat(event.project!!, request) { it }, + object : FutureCallback { + override fun onSuccess(result: CharSequence?) { + progressIndicator?.cancel() + WriteCommandAction.runWriteCommandAction(event.project) { undo?.run() } + val nextUndo = AtomicReference() + WriteCommandAction.runWriteCommandAction(event.project) { + nextUndo.set( + action.apply( + transformCompletion.apply(result.toString()) + ) + ) + } + retry[document] = + getRetry(request, indent, event, action, nextUndo.get()!!, transformCompletion) + } + + override fun onFailure(t: Throwable) { + progressIndicator?.cancel() + handle(t) + } + }, + AsyncAPI.pool + ) + } + } + + fun redoableRequest( request: EditRequest, indent: CharSequence, @@ -926,7 +1013,12 @@ object UITools { return formBuilder.addComponentFillVertically(JPanel(), 0).panel } - fun showDialog(e: AnActionEvent, uiClass: Class, configClass: Class, onComplete: (C) -> Unit) { + fun showDialog( + e: AnActionEvent, + uiClass: Class, + configClass: Class, + onComplete: (C) -> Unit + ) { val project = e.project val component = uiClass.getConstructor().newInstance() val config = configClass.getConstructor().newInstance() @@ -969,4 +1061,58 @@ object UITools { return project?.baseDir } + fun isInterruptedException(e: Throwable?): Boolean { + if (e is InterruptedException) return true + return if (e!!.cause != null && e.cause !== e) isInterruptedException( e.cause ) else false + } + + fun run( + project : Project?, + title: String, + canBeCancelled: Boolean, + retries: Int = 3, + suppressProgress: Boolean = false, + task: (ProgressIndicator) -> T + ): T + { + return run(object : Task.WithResult(project, title, canBeCancelled) { + override fun compute(indicator: ProgressIndicator): T { + return task(indicator) + } + }, retries, suppressProgress) + } + + fun run(task: Task.WithResult, retries: Int = 3, suppressProgress: Boolean = false): T { + return try { + if (!suppressProgress) { + ProgressManager.getInstance().run(task) + } else { + task.run(AbstractProgressIndicatorBase()) + task.result + } + } catch (e: RuntimeException) { + if (isInterruptedException(e)) throw e + if (retries > 0) { + AsyncAPI.log.warn("Retrying request", e) + run(task = task, retries - 1) + } else { + throw e + } + } catch (e: InterruptedException) { + throw RuntimeException(e) + } catch (e: Exception) { + if (isInterruptedException(e)) throw RuntimeException(e) + if (retries > 0) { + AsyncAPI.log.warn("Retrying request", e) + try { + Thread.sleep(15000) + } catch (ex: InterruptedException) { + Thread.currentThread().interrupt() + } + run(task = task, retries - 1) + } else { + throw RuntimeException(e) + } + } + } } diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/util/psi/PsiUtil.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/util/psi/PsiUtil.kt index 45ed42f6..0bd9ece0 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/util/psi/PsiUtil.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/util/psi/PsiUtil.kt @@ -1,6 +1,6 @@ package com.github.simiacryptus.aicoder.util.psi -import com.github.simiacryptus.aicoder.util.StringTools +import com.github.simiacryptus.util.StringTools import com.intellij.lang.Language import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ApiError.kt b/src/main/kotlin/com/github/simiacryptus/openai/ApiError.kt similarity index 86% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ApiError.kt rename to src/main/kotlin/com/github/simiacryptus/openai/ApiError.kt index 8e417cf3..ae3a33f7 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ApiError.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/ApiError.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai class ApiError { @Suppress("unused") diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatChoice.kt b/src/main/kotlin/com/github/simiacryptus/openai/ChatChoice.kt similarity index 76% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatChoice.kt rename to src/main/kotlin/com/github/simiacryptus/openai/ChatChoice.kt index 8f135643..4ad40771 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatChoice.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/ChatChoice.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai class ChatChoice @Suppress("unused") constructor() { var message: ChatMessage? = null diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatMessage.kt b/src/main/kotlin/com/github/simiacryptus/openai/ChatMessage.kt similarity index 81% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatMessage.kt rename to src/main/kotlin/com/github/simiacryptus/openai/ChatMessage.kt index 018376b5..d6a367ea 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatMessage.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/ChatMessage.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai class ChatMessage { enum class Role { diff --git a/src/main/kotlin/com/github/simiacryptus/openai/ChatRequest.kt b/src/main/kotlin/com/github/simiacryptus/openai/ChatRequest.kt new file mode 100644 index 00000000..4c4c4198 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/ChatRequest.kt @@ -0,0 +1,28 @@ +package com.github.simiacryptus.openai + +import com.github.simiacryptus.aicoder.config.AppSettingsState + +class ChatRequest @Suppress("unused") constructor() { + fun uiIntercept(): ChatRequest { + return this + } + + constructor(settingsState: AppSettingsState) : this() { + model = (settingsState.model_chat) + temperature = (settingsState.temperature) + } + + constructor(request: ChatRequest) : this() { + model = (request.model) + temperature = (request.temperature) + max_tokens = (request.max_tokens) + stop = (request.stop) + messages = (request.messages) + } + + var messages = arrayOf() + var model: String? = null + var temperature = 0.0 + var max_tokens = 0 + var stop: Array? = null +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatResponse.kt b/src/main/kotlin/com/github/simiacryptus/openai/ChatResponse.kt similarity index 90% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatResponse.kt rename to src/main/kotlin/com/github/simiacryptus/openai/ChatResponse.kt index 9243412b..1bffee55 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ChatResponse.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/ChatResponse.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai import java.util.* diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CompletionChoice.kt b/src/main/kotlin/com/github/simiacryptus/openai/CompletionChoice.kt similarity index 87% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CompletionChoice.kt rename to src/main/kotlin/com/github/simiacryptus/openai/CompletionChoice.kt index c3f3a705..584fc730 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CompletionChoice.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/CompletionChoice.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai class CompletionChoice { var text: String? = null diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CompletionRequest.kt b/src/main/kotlin/com/github/simiacryptus/openai/CompletionRequest.kt similarity index 95% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CompletionRequest.kt rename to src/main/kotlin/com/github/simiacryptus/openai/CompletionRequest.kt index 6b6fa4f1..5d982d83 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CompletionRequest.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/CompletionRequest.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.openai.ui.CompletionRequestWithModel diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CompletionResponse.kt b/src/main/kotlin/com/github/simiacryptus/openai/CompletionResponse.kt similarity index 91% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CompletionResponse.kt rename to src/main/kotlin/com/github/simiacryptus/openai/CompletionResponse.kt index 8c3241a1..1b8bd04b 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/CompletionResponse.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/CompletionResponse.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai import java.util.* diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/EditRequest.kt b/src/main/kotlin/com/github/simiacryptus/openai/EditRequest.kt similarity index 95% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/EditRequest.kt rename to src/main/kotlin/com/github/simiacryptus/openai/EditRequest.kt index 6169982b..e22cff26 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/EditRequest.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/EditRequest.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.openai.ui.InteractiveEditRequest diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/Engine.kt b/src/main/kotlin/com/github/simiacryptus/openai/Engine.kt similarity index 90% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/Engine.kt rename to src/main/kotlin/com/github/simiacryptus/openai/Engine.kt index de1c9ab7..adb7e830 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/Engine.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/Engine.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai class Engine { @Suppress("unused") diff --git a/src/main/kotlin/com/github/simiacryptus/openai/HttpClientManager.kt b/src/main/kotlin/com/github/simiacryptus/openai/HttpClientManager.kt new file mode 100644 index 00000000..d4e193f1 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/HttpClientManager.kt @@ -0,0 +1,138 @@ +package com.github.simiacryptus.openai + +import com.google.common.util.concurrent.ListeningExecutorService +import com.google.common.util.concurrent.ListeningScheduledExecutorService +import com.google.common.util.concurrent.MoreExecutors +import com.google.common.util.concurrent.ThreadFactoryBuilder +import com.intellij.openapi.diagnostic.Logger +import org.apache.http.impl.client.CloseableHttpClient +import org.apache.http.impl.client.HttpClientBuilder +import java.io.IOException +import java.time.Duration +import java.util.* +import java.util.concurrent.* +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.math.pow + +@Suppress("MemberVisibilityCanBePrivate") +open class HttpClientManager { + + companion object { + val log = Logger.getInstance(OpenAIClient::class.java) + val threadFactory: ThreadFactory = ThreadFactoryBuilder().setNameFormat("API Thread %d").build() + val scheduledPool: ListeningScheduledExecutorService = + MoreExecutors.listeningDecorator(ScheduledThreadPoolExecutor(4, threadFactory)) + val workPool: ListeningExecutorService = + MoreExecutors.listeningDecorator( + ThreadPoolExecutor( + 1, 32, + 0, TimeUnit.MILLISECONDS, LinkedBlockingQueue(), threadFactory + ) + ) + + fun withPool(fn: () -> T): T = workPool.submit(Callable { + return@Callable fn() + }).get() + + fun withExpBackoffRetry(retryCount: Int = 3, sleepScale: Long = 1000L, fn: () -> T): T { + var lastException: Exception? = null + for (i in 0 until retryCount) { + try { + return fn() + } catch (e: Exception) { + lastException = e + log.info("Request failed; retrying ($i/$retryCount): " + e.message) + Thread.sleep(sleepScale * 2.0.pow(i.toDouble()).toLong()) + } + } + throw lastException!! + } + + } + + protected val clients: MutableMap = WeakHashMap() + fun getClient(thread: Thread = Thread.currentThread()): CloseableHttpClient = + if (thread in clients) clients[thread]!! + else synchronized(clients) { + val client = HttpClientBuilder.create().build() + clients[thread] = client + client + } + + fun closeClient(thread: Thread) { + try { + synchronized(clients) { + clients[thread] + }?.close() + } catch (e: IOException) { + log.info("Error closing client: " + e.message) + } + } + + fun withCancellationMonitor(fn: () -> T, cancelCheck: () -> Boolean = { Thread.currentThread().isInterrupted }): T { + val thread = Thread.currentThread() + val isCompleted = AtomicBoolean(false) + val start = Date() + val future = scheduledPool.scheduleAtFixedRate({ + if (cancelCheck()) { + log.info("Request cancelled at ${Date()} (started $start); closing client for thread $thread") + closeClient(thread, isCompleted) + } + }, 0, 10, TimeUnit.MILLISECONDS) + try { + return fn() + } finally { + isCompleted.set(true) + future.cancel(false) + } + } + + fun withTimeout(duration: Duration, fn: () -> T): T { + val thread = Thread.currentThread() + val isCompleted = AtomicBoolean(false) + val start = Date() + val future = scheduledPool.schedule({ + log.info("Request timed out after $duration at ${Date()} (started $start); closing client for thread $thread") + closeClient(thread, isCompleted) + }, duration.toMillis(), TimeUnit.MILLISECONDS) + try { + return fn() + } finally { + isCompleted.set(true) + future.cancel(false) + } + } + + private fun closeClient(thread: Thread, isCompleted: AtomicBoolean) { + closeClient(thread) + Thread.sleep(10) + while (isCompleted.get()) { + Thread.sleep(5000) + if (isCompleted.get()) break + log.info("Request still not completed; thread stack: \n\t${thread.stackTrace.joinToString { "\n\t" }}\nkilling thread $thread") + @Suppress("DEPRECATION") + thread.stop() + } + } + + fun withReliability(requestTimeoutSeconds: Long = 180, retryCount: Int = 3, fn: () -> T): T = + withExpBackoffRetry(retryCount) { +// withPool { +// } + withTimeout(Duration.ofSeconds(requestTimeoutSeconds)) { + withCancellationMonitor(fn) + } + } + + fun withPerformanceLogging(fn: () -> T):T { + val start = Date() + try { + return fn() + } finally { + log.debug("Request completed in ${Date().time - start.time}ms") + } + } + + + +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/LogProbs.kt b/src/main/kotlin/com/github/simiacryptus/openai/LogProbs.kt similarity index 90% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/LogProbs.kt rename to src/main/kotlin/com/github/simiacryptus/openai/LogProbs.kt index 28339670..04d84141 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/LogProbs.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/LogProbs.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai import com.fasterxml.jackson.databind.node.ObjectNode diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ModelMaxException.kt b/src/main/kotlin/com/github/simiacryptus/openai/ModelMaxException.kt similarity index 80% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ModelMaxException.kt rename to src/main/kotlin/com/github/simiacryptus/openai/ModelMaxException.kt index 154283f0..19605b51 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ModelMaxException.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/ModelMaxException.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai import java.io.IOException diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ModerationException.kt b/src/main/kotlin/com/github/simiacryptus/openai/ModerationException.kt similarity index 54% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ModerationException.kt rename to src/main/kotlin/com/github/simiacryptus/openai/ModerationException.kt index e9bde0f0..661b47d6 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/ModerationException.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/ModerationException.kt @@ -1,3 +1,3 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai class ModerationException(message: String?) : Exception(message) \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/openai/OpenAIClient.kt b/src/main/kotlin/com/github/simiacryptus/openai/OpenAIClient.kt new file mode 100644 index 00000000..bef77aa3 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/OpenAIClient.kt @@ -0,0 +1,472 @@ +package com.github.simiacryptus.openai + +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.core.JsonProcessingException +import com.fasterxml.jackson.databind.MapperFeature +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.SerializationFeature +import com.fasterxml.jackson.databind.node.ObjectNode +import com.github.simiacryptus.aicoder.openai.ui.OpenAI_API +import com.github.simiacryptus.aicoder.util.UITools +import com.github.simiacryptus.util.StringTools +import com.google.gson.Gson +import com.google.gson.JsonObject +import com.intellij.openapi.diagnostic.Logger +import com.jetbrains.rd.util.LogLevel +import org.apache.http.HttpResponse +import org.apache.http.client.methods.HttpGet +import org.apache.http.client.methods.HttpPost +import org.apache.http.client.methods.HttpRequestBase +import org.apache.http.entity.ContentType +import org.apache.http.entity.StringEntity +import org.apache.http.entity.mime.HttpMultipartMode +import org.apache.http.entity.mime.MultipartEntityBuilder +import org.apache.http.util.EntityUtils +import java.awt.image.BufferedImage +import java.io.IOException +import java.net.URL +import java.nio.charset.Charset +import java.util.* +import java.util.Map +import java.util.regex.Pattern +import javax.imageio.ImageIO +import kotlin.collections.ArrayList +import kotlin.collections.List +import kotlin.collections.first +import kotlin.collections.set + +@Suppress("unused") +open class OpenAIClient( + private val apiBase: String, + var key: String, + private val logLevel: LogLevel +) : HttpClientManager() { + + fun getEngines(): Array { + val engines = mapper.readValue( + get(OpenAI_API.settingsState!!.apiBase + "/engines"), + ObjectNode::class.java + ) + val data = engines["data"] + val items = + arrayOfNulls(data.size()) + for (i in 0 until data.size()) { + items[i] = data[i]["id"].asText() + } + Arrays.sort(items) + return items + } + + private fun logComplete(completionResult: CharSequence) { + log( + logLevel, String.format( + "Chat Completion:\n\t%s", + completionResult.toString().replace("\n", "\n\t") + ) + ) + } + + private fun logStart(completionRequest: CompletionRequest) { + if (completionRequest.suffix == null) { + log( + logLevel, String.format( + "Text Completion Request\nPrefix:\n\t%s\n", + completionRequest.prompt.replace("\n", "\n\t") + ) + ) + } else { + log( + logLevel, String.format( + "Text Completion Request\nPrefix:\n\t%s\nSuffix:\n\t%s\n", + completionRequest.prompt.replace("\n", "\n\t"), + completionRequest.suffix!!.replace("\n", "\n\t") + ) + ) + } + } + + @Throws(IOException::class, InterruptedException::class) + protected fun post(url: String, json: String): String { + return post(jsonRequest(url, json)) + } + + private fun post( + request: HttpPost + ): String { + val client = getClient() + try { + client.use { httpClient -> + synchronized(clients) { + clients[Thread.currentThread()] = httpClient + } + val response: HttpResponse = httpClient.execute(request) + val entity = response.entity + return EntityUtils.toString(entity) + } + } finally { + synchronized(clients) { + clients.remove(Thread.currentThread()) + } + } + } + + private fun jsonRequest(url: String, json: String): HttpPost { + val request = HttpPost(url) + request.addHeader("Content-Type", "application/json") + request.addHeader("Accept", "application/json") + authorize(request) + request.entity = StringEntity(json) + return request + } + + @Throws(IOException::class) + protected fun authorize(request: HttpRequestBase) { + var apiKey: CharSequence = key + if (apiKey.length == 0) { + synchronized(OpenAI_API.javaClass) { + apiKey = key + if (apiKey.length == 0) { + apiKey = UITools.queryAPIKey()!! + key = apiKey.toString() + } + } + } + request.addHeader("Authorization", "Bearer $apiKey") + } + + /** + * Gets the response from the given URL. + * + * @param url The URL to GET the response from. + * @return The response from the given URL. + * @throws IOException If an I/O error occurs. + */ + @Throws(IOException::class) + protected operator fun get(url: String?): String { + val client = getClient() + val request = HttpGet(url) + request.addHeader("Content-Type", "application/json") + request.addHeader("Accept", "application/json") + authorize(request) + client.use { httpClient -> + val response: HttpResponse = httpClient.execute(request) + val entity = response.entity + return EntityUtils.toString(entity) + } + } + + fun dictate(wavAudio: ByteArray, prompt: String = ""): String = withReliability { + withPerformanceLogging { + val url = apiBase + "/audio/transcriptions" + val request = HttpPost(url) + request.addHeader("Accept", "application/json") + authorize(request) + val entity = MultipartEntityBuilder.create() + entity.setMode(HttpMultipartMode.RFC6532) + entity.addBinaryBody("file", wavAudio, ContentType.create("audio/x-wav"), "audio.wav") + entity.addTextBody("model", "whisper-1") + if (!prompt.isEmpty()) entity.addTextBody("prompt", prompt) + request.entity = entity.build() + val response = post(request) + val jsonObject = Gson().fromJson(response, JsonObject::class.java) + if (jsonObject.has("error")) { + val errorObject = jsonObject.getAsJsonObject("error") + throw RuntimeException(IOException(errorObject["message"].asString)) + } + jsonObject.get("text").asString!! + } + } + + fun render(prompt: String = "", resolution: Int = 1024, count: Int = 1): List = withReliability { + withPerformanceLogging { + val url = apiBase + "/images/generations" + val request = HttpPost(url) + request.addHeader("Accept", "application/json") + request.addHeader("Content-Type", "application/json") + authorize(request) + val jsonObject = JsonObject() + jsonObject.addProperty("prompt", prompt) + jsonObject.addProperty("n", count) + jsonObject.addProperty("size", "${resolution}x$resolution") + request.entity = StringEntity(jsonObject.toString()) + val response = post(request) + val jsonObject2 = Gson().fromJson(response, JsonObject::class.java) + if (jsonObject2.has("error")) { + val errorObject = jsonObject2.getAsJsonObject("error") + throw RuntimeException(IOException(errorObject["message"].asString)) + } + val dataArray = jsonObject2.getAsJsonArray("data") + val images = ArrayList() + for (i in 0 until dataArray.size()) { + images.add(ImageIO.read(URL(dataArray[i].asJsonObject.get("url").asString))) + } + images + } } + + @Throws(IOException::class) + private fun processCompletionResponse(result: String): CompletionResponse { + checkError(result) + val response = mapper.readValue( + result, + CompletionResponse::class.java + ) + if (response.usage != null) { + incrementTokens(response.usage!!.total_tokens) + } + return response + } + + @Throws(IOException::class) + protected fun processChatResponse(result: String): ChatResponse { + checkError(result) + val response = mapper.readValue( + result, + ChatResponse::class.java + ) + if (response.usage != null) { + incrementTokens(response.usage!!.total_tokens) + } + return response + } + + private fun checkError(result: String) { + try { + val jsonObject = Gson().fromJson( + result, + JsonObject::class.java + ) + if (jsonObject.has("error")) { + val errorObject = jsonObject.getAsJsonObject("error") + val errorMessage = errorObject["message"].asString + if (errorMessage.startsWith("That model is currently overloaded with other requests.")) { + throw RequestOverloadException(errorMessage) + } + maxTokenErrorMessage.find { it.matcher(errorMessage).matches() }?.let { + val matcher = it.matcher(errorMessage) + if (matcher.find()) { + val modelMax = matcher.group(1).toInt() + val request = matcher.group(2).toInt() + val messages = matcher.group(3).toInt() + val completion = matcher.group(4).toInt() + throw ModelMaxException(modelMax, request, messages, completion) + } + } + throw IOException(errorMessage) + } + } catch (e: com.google.gson.JsonSyntaxException) { + throw IOException("Invalid JSON response: $result") + } + } + + class RequestOverloadException(message: String = "That model is currently overloaded with other requests.") : + IOException(message) + + open fun incrementTokens(totalTokens: Int) {} + + companion object { + val log = Logger.getInstance(OpenAIClient::class.java) + + fun log(level: LogLevel, msg: String) { + val message = msg.trim { it <= ' ' }.replace("\n", "\n\t") + when (level) { + LogLevel.Error -> log.error(message) + LogLevel.Warn -> log.warn(message) + LogLevel.Info -> log.info(message) + else -> log.debug(message) + } + } + + val mapper: ObjectMapper + get() { + val mapper = ObjectMapper() + mapper + .enable(SerializationFeature.INDENT_OUTPUT) + .enable(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS) + .enable(MapperFeature.USE_STD_BEAN_NAMING) + .setSerializationInclusion(JsonInclude.Include.NON_NULL) + .activateDefaultTyping(mapper.polymorphicTypeValidator) + return mapper + } + val allowedCharset = Charset.forName("ASCII") + private val maxTokenErrorMessage = listOf( + Pattern.compile( + """This model's maximum context length is (\d+) tokens. However, you requested (\d+) tokens \((\d+) in the messages, (\d+) in the completion\).*""" + ), + // This model's maximum context length is 4097 tokens, however you requested 80052 tokens (52 in your prompt; 80000 for the completion). Please reduce your prompt; or completion length. + Pattern.compile( + """This model's maximum context length is (\d+) tokens, however you requested (\d+) tokens \((\d+) in your prompt; (\d+) for the completion\).*""" + ) + ) + + } + + fun complete( + completionRequest: CompletionRequest, + model: String + ): CompletionResponse = withReliability { + withPerformanceLogging { + logStart(completionRequest) + val completionResponse = try { + val request: String = + StringTools.restrictCharacterSet( + mapper.writeValueAsString(completionRequest), + allowedCharset + ) + val result = + post(apiBase + "/engines/" + model + "/completions", request) + processCompletionResponse(result) + } catch (e: ModelMaxException) { + completionRequest.max_tokens = (e.modelMax - e.messages) - 1 + val request: String = + StringTools.restrictCharacterSet( + mapper.writeValueAsString(completionRequest), + allowedCharset + ) + val result = + post(apiBase + "/engines/" + model + "/completions", request) + processCompletionResponse(result) + } + val completionResult = StringTools.stripPrefix( + completionResponse.firstChoice.orElse("").toString().trim { it <= ' ' }, + completionRequest.prompt.trim { it <= ' ' }) + logComplete(completionResult) + completionResponse + } + } + + fun chat( + completionRequest: ChatRequest + ): ChatResponse = withReliability { + withPerformanceLogging { + logStart(completionRequest) + val url = apiBase + "/chat/completions" + val completionResponse = try { + processChatResponse( + post( + url, StringTools.restrictCharacterSet( + mapper.writeValueAsString(completionRequest), + allowedCharset + ) + ) + ) + } catch (e: ModelMaxException) { + completionRequest.max_tokens = (e.modelMax - e.messages) - 1 + processChatResponse( + post( + url, StringTools.restrictCharacterSet( + mapper.writeValueAsString(completionRequest), + allowedCharset + ) + ) + ) + } + logComplete(completionResponse.choices.first().message!!.content!!.trim { it <= ' ' }) + completionResponse + } + } + + private fun logStart(completionRequest: ChatRequest) { + log( + logLevel, String.format( + "Chat Request\nPrefix:\n\t%s\n", + mapper.writeValueAsString(completionRequest).replace("\n", "\n\t") + ) + ) + } + + fun moderate(text: String) = withReliability { + withPerformanceLogging { + val body: String = try { + mapper.writeValueAsString( + Map.of( + "input", + StringTools.restrictCharacterSet(text, allowedCharset) + ) + ) + } catch (e: JsonProcessingException) { + throw RuntimeException(e) + } + val result: String = try { + this.post(apiBase + "/moderations", body) + } catch (e: IOException) { + throw RuntimeException(e) + } catch (e: InterruptedException) { + throw RuntimeException(e) + } + val jsonObject = + Gson().fromJson( + result, + JsonObject::class.java + ) + if (jsonObject.has("error")) { + val errorObject = jsonObject.getAsJsonObject("error") + throw RuntimeException(IOException(errorObject["message"].asString)) + } + val moderationResult = + jsonObject.getAsJsonArray("results")[0].asJsonObject + log( + LogLevel.Debug, + String.format( + "Moderation Request\nText:\n%s\n\nResult:\n%s", + text.replace("\n", "\n\t"), + result + ) + ) + if (moderationResult["flagged"].asBoolean) { + val categoriesObj = + moderationResult["categories"].asJsonObject + throw RuntimeException( + ModerationException( + "Moderation flagged this request due to " + categoriesObj.keySet() + .stream().filter { c: String? -> + categoriesObj[c].asBoolean + }.reduce { a: String, b: String -> "$a, $b" } + .orElse("???") + ) + ) + } + } + } + + fun edit( + editRequest: EditRequest + ): CompletionResponse = withReliability { + withPerformanceLogging { + logStart(editRequest, logLevel) + val request: String = + StringTools.restrictCharacterSet( + OpenAIClient.mapper.writeValueAsString(editRequest), + allowedCharset + ) + val result = post(apiBase + "/edits", request) + val completionResponse = processCompletionResponse(result) + logComplete( + completionResponse.firstChoice.orElse("").toString().trim { it <= ' ' } + ) + completionResponse + } + } + + private fun logStart( + editRequest: EditRequest, + level: LogLevel + ) { + if (editRequest.input == null) { + log( + level, String.format( + "Text Edit Request\nInstruction:\n\t%s\n", + editRequest.instruction.replace("\n", "\n\t") + ) + ) + } else { + log( + level, String.format( + "Text Edit Request\nInstruction:\n\t%s\nInput:\n\t%s\n", + editRequest.instruction.replace("\n", "\n\t"), + editRequest.input!!.replace("\n", "\n\t") + ) + ) + } + } + +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/Response.kt b/src/main/kotlin/com/github/simiacryptus/openai/Response.kt similarity index 83% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/Response.kt rename to src/main/kotlin/com/github/simiacryptus/openai/Response.kt index db648da5..65570819 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/Response.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/Response.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai class Response { @Suppress("unused") diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/Usage.kt b/src/main/kotlin/com/github/simiacryptus/openai/Usage.kt similarity index 76% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/Usage.kt rename to src/main/kotlin/com/github/simiacryptus/openai/Usage.kt index 41ce403a..e0b748b6 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/core/Usage.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/Usage.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.openai.core +package com.github.simiacryptus.openai class Usage @Suppress("unused") constructor() { @Suppress("unused") diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/ChatProxy.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/ChatProxy.kt new file mode 100644 index 00000000..161fae41 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/proxy/ChatProxy.kt @@ -0,0 +1,109 @@ +package com.github.simiacryptus.openai.proxy + +import com.github.simiacryptus.openai.ChatMessage +import com.github.simiacryptus.openai.ChatRequest +import com.github.simiacryptus.openai.OpenAIClient +import com.jetbrains.rd.util.LogLevel +import java.util.concurrent.atomic.AtomicInteger + +@Suppress("MemberVisibilityCanBePrivate") +class ChatProxy( + apiKey: String, + var model: String = "gpt-3.5-turbo", + var maxTokens: Int = 3500, + var temperature: Double = 0.7, + var verbose: Boolean = false, + private val moderated: Boolean = true, + base: String = "https://api.openai.com/v1", + apiLog: String? = null, + logLevel: LogLevel +) : GPTProxyBase(apiLog, 3) { + val api: OpenAIClient + val totalPrefixLength = AtomicInteger(0) + val totalSuffixLength = AtomicInteger(0) + val totalInputLength = AtomicInteger(0) + val totalOutputLength = AtomicInteger(0) + + init { + api = OpenAIClient(base, apiKey, logLevel) + } + + override fun complete(prompt: ProxyRequest, vararg examples: ProxyRecord): String { + if (verbose) println(prompt) + val request = ChatRequest() + request.messages = ( + listOf( + ChatMessage( + ChatMessage.Role.system, """ + |You are a JSON-RPC Service + |Responses are expected to be a single JSON object without explaining text. + |All input arguments are optional + |You will respond to the following method: + | + |${prompt.apiYaml} + |""".trimMargin().trim() + ) + ) + + examples.flatMap { + listOf( + ChatMessage( + ChatMessage.Role.user, + argsToString(it.argList) + ), + ChatMessage( + ChatMessage.Role.assistant, + it.response + ) + ) + } + + listOf( + ChatMessage( + ChatMessage.Role.user, + argsToString(prompt.argList) + ) + ) + ).toTypedArray() + request.model = model + request.max_tokens = maxTokens + request.temperature = temperature + val json = toJson(request) + if (moderated) api.moderate(json) + totalInputLength.addAndGet(json.length) + + val completion = api.chat(request).response.get().toString() + if (verbose) println(completion) + totalOutputLength.addAndGet(completion.length) + val trimPrefix = trimPrefix(completion) + val trimSuffix = trimSuffix(trimPrefix.first) + totalPrefixLength.addAndGet(trimPrefix.second.length) + totalSuffixLength.addAndGet(trimSuffix.second.length) + return trimSuffix.first + } + + companion object { + private fun trimPrefix(completion: String): Pair { + val start = completion.indexOf('{').coerceAtMost(completion.indexOf('[')) + if (start < 0) { + return completion to "" + } else { + val substring = completion.substring(start) + return substring to completion.substring(0, start) + } + } + + private fun trimSuffix(completion: String): Pair { + val end = completion.lastIndexOf('}').coerceAtLeast(completion.lastIndexOf(']')) + if (end < 0) { + return completion to "" + } else { + val substring = completion.substring(0, end + 1) + return substring to completion.substring(end + 1) + } + } + + private fun argsToString(argList: Map) = + "{" + argList.entries.joinToString(",\n", transform = { (argName, argValue) -> + """"$argName": $argValue""" + }) + "}" + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/CompletionProxy.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/CompletionProxy.kt new file mode 100644 index 00000000..c1a7c5c7 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/proxy/CompletionProxy.kt @@ -0,0 +1,47 @@ +package com.github.simiacryptus.openai.proxy + +import com.github.simiacryptus.openai.CompletionRequest +import com.github.simiacryptus.openai.OpenAIClient +import com.jetbrains.rd.util.LogLevel + +class CompletionProxy( + apiKey: String, + private val model: String = "text-davinci-003", + private val maxTokens: Int = 4000, + private val temperature: Double = 0.7, + private val verbose: Boolean = false, + private val moderated: Boolean = true, + base: String = "https://api.openai.com/v1", + apiLog: String +) : GPTProxyBase(apiLog, 3) { + val api: OpenAIClient + + init { + api = OpenAIClient(base, apiKey, LogLevel.Debug) + } + + override fun complete(prompt: ProxyRequest, vararg examples: ProxyRecord): String { + if(verbose) println(prompt) + val request = CompletionRequest() + request.prompt = """ + |Method: ${prompt.methodName} + |Response Type: + | ${prompt.apiYaml.replace("\n", "\n ")} + |Request: + | { + | ${ + prompt.argList.entries.joinToString(",\n", transform = { (argName, argValue) -> + """"$argName": $argValue""" + }).replace("\n", "\n ") + } + | } + |Response: + | {""".trim().trimIndent() + request.max_tokens = maxTokens + request.temperature = temperature + if (moderated) api.moderate(toJson(request)) + val completion = api.complete(request, model).firstChoice.get().toString() + if(verbose) println(completion) + return "{$completion" + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/Description.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/Description.kt new file mode 100644 index 00000000..3c6643df --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/proxy/Description.kt @@ -0,0 +1,4 @@ +package com.github.simiacryptus.openai.proxy + +@Retention(AnnotationRetention.RUNTIME) +annotation class Description(val value: String) \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/GPTProxyBase.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/GPTProxyBase.kt new file mode 100644 index 00000000..b05b0445 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/proxy/GPTProxyBase.kt @@ -0,0 +1,240 @@ +package com.github.simiacryptus.openai.proxy + +import com.fasterxml.jackson.core.json.JsonReadFeature +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.kotlin.isKotlinClass +import com.google.common.reflect.TypeToken +import java.io.BufferedWriter +import java.io.File +import java.io.FileWriter +import java.lang.reflect.* +import kotlin.reflect.KProperty1 +import kotlin.reflect.full.memberProperties +import kotlin.reflect.jvm.javaType + + +abstract class GPTProxyBase( + apiLogFile: String?, + private val deserializerRetries: Int = 5 +) { + abstract fun complete(prompt: ProxyRequest, vararg examples: ProxyRecord): String + + fun create(clazz: Class): T { + return Proxy.newProxyInstance(clazz.classLoader, arrayOf(clazz)) { proxy, method, args -> + if (method.name == "toString") return@newProxyInstance clazz.simpleName + val type = method.genericReturnType + val typeString = method.toYaml().trimIndent() + val prompt = ProxyRequest( + method.name, + typeString, + (args ?: arrayOf()).zip(method.parameters) + .filter> { (arg: Any?, _) -> arg != null } + .map, Pair> { (arg, param) -> + param.name to toJson(arg!!) + }.toMap() + ) + + var lastException: Exception? = null + for (retry in 0 until deserializerRetries) { + var result = complete(prompt, *examples[method.name]?.toTypedArray() ?: arrayOf()) + // If the requested `type` is a list, check that result is a list + if (type is ParameterizedType && List::class.java.isAssignableFrom(type.rawType as Class<*>) && !result.startsWith( + "[" + ) + ) { + result = "[$result]" + } + writeToJsonLog(ProxyRecord(method.name, prompt.argList, result)) + try { + val obj = fromJson(type, result) + if (obj is ValidatedObject && !obj.validate()) { + log.warn("Invalid response: $result") + continue + } + return@newProxyInstance obj + } catch (e: Exception) { + log.warn("Failed to parse response: $result", e) + lastException = e + log.info("Retry $retry of $deserializerRetries") + } + } + throw RuntimeException("Failed to parse response", lastException) + } as T + } + + + private val apiLog = apiLogFile?.let { openApiLog(it) } + private val examples = HashMap>() + private fun loadExamples(file: File = File("api.examples.json")): List { + if (!file.exists()) return listOf() + val json = file.readText() + return fromJson(object : ArrayList() {}.javaClass, json) + } + + fun addExamples(file: File) { + examples.putAll(loadExamples(file).groupBy { it.methodName }) + } + + private fun openApiLog(file: String): BufferedWriter { + val writer = BufferedWriter(FileWriter(File(file))) + writer.write("[") + writer.newLine() + writer.flush() + return writer + } + + private fun writeToJsonLog(record: ProxyRecord) { + if (apiLog != null) { + apiLog.write(toJson(record)) + apiLog.write(",") + apiLog.newLine() + apiLog.flush() + } + } + + open fun toJson(data: Any): String { + return objectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(data) + } + + open fun fromJson(type: Type, data: String): T { + if (type is Class<*> && type.isAssignableFrom(String::class.java)) return data as T + return objectMapper().readValue(data, objectMapper().typeFactory.constructType(type)) as T + } + + open fun objectMapper(): ObjectMapper { + return ObjectMapper() + .enable(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS.mappedFeature()) + .configure(com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + } + + data class ProxyRequest( + val methodName: String = "", + val apiYaml: String = "", + val argList: Map = mapOf() + ) + + data class ProxyRecord( + val methodName: String = "", + val argList: Map = mapOf(), + val response: String = "" + ) + + companion object { + val log = org.slf4j.LoggerFactory.getLogger(GPTProxyBase::class.java) + + fun Parameter.toYaml(): String { + val description = getAnnotation(Description::class.java)?.value + val yaml = if (description != null) { + """ + |- name: ${this.name} + | description: $description + | ${this.parameterizedType.toYaml().replace("\n", "\n ")} + |""".trimMargin().trim() + } else { + """ + |- name: ${this.name} + | ${this.parameterizedType.toYaml().replace("\n", "\n ")} + |""".trimMargin().trim() + } + return yaml + } + + + fun Type.toYaml(): String { + val typeName = this.typeName.substringAfterLast('.').replace('$', '.').toLowerCase() + val yaml = if (typeName in setOf("boolean", "integer", "number", "string")) { + "type: $typeName" + } else if (this is ParameterizedType && List::class.java.isAssignableFrom(this.rawType as Class<*>)) { + """ + |type: array + |items: + | ${this.actualTypeArguments[0].toYaml().replace("\n", "\n ")} + |""".trimMargin() + } else if (this.isArray) { + """ + |type: array + |items: + | ${this.componentType?.toYaml()?.replace("\n", "\n ")} + |""".trimMargin() + } else { + val rawType = TypeToken.of(this).rawType + val declaredFieldYaml = rawType.declaredFields.map { + """ + |${it.name}: + | ${it.genericType.toYaml().replace("\n", "\n ")} + """.trimMargin().trim() + }.toTypedArray() + val propertiesYaml = if (rawType.isKotlinClass() && rawType.kotlin.isData) { + rawType.kotlin.memberProperties.map { + val allAnnotations = + getAllAnnotations(rawType, it) + val description = allAnnotations.find { x -> x is Description } as? Description + // Find annotation on the kotlin data class constructor parameter + val yaml = if (description != null) { + """ + |${it.name}: + | description: ${description.value} + | ${it.returnType.javaType.toYaml().replace("\n", "\n ")} + """.trimMargin().trim() + } else { + """ + |${it.name}: + | ${it.returnType.javaType.toYaml().replace("\n", "\n ")} + """.trimMargin().trim() + } + yaml + }.toTypedArray() + } else { + arrayOf() + } + val fieldsYaml = (declaredFieldYaml.toList() + propertiesYaml.toList()).distinct().joinToString("\n") + """ + |type: object + |properties: + | ${fieldsYaml.replace("\n", "\n ")} + """.trimMargin() + } + return yaml + } + + private fun getAllAnnotations( + rawType: Class, + property: KProperty1 + ) = property.annotations + (rawType.kotlin.constructors.first().parameters.find { x -> x.name == property.name }?.annotations + ?: listOf()) + + fun Method.toYaml(): String { + val parameterYaml = parameters.map { it.toYaml() }.toTypedArray().joinToString("").trim() + val returnTypeYaml = genericReturnType.toYaml().trim() + val responseYaml = """ + |responses: + | application/json: + | schema: + | ${returnTypeYaml.replace("\n", "\n ")} + """.trimMargin().trim() + val yaml = """ + |operationId: ${"${declaringClass.simpleName}.$name"} + |parameters: + | ${parameterYaml.replace("\n", "\n ")} + |$responseYaml + """.trimMargin() + return yaml + } + + val Type.isArray: Boolean + get() { + return this is Class<*> && this.isArray + } + + val Type.componentType: Type? + get() { + return when (this) { + is Class<*> -> if (this.isArray) this.componentType else null + is ParameterizedType -> this.actualTypeArguments.firstOrNull() + else -> null + } + } + } + +} + diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/ValidatedObject.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/ValidatedObject.kt new file mode 100644 index 00000000..b12b5818 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/proxy/ValidatedObject.kt @@ -0,0 +1,26 @@ +package com.github.simiacryptus.openai.proxy + +import kotlin.reflect.full.memberProperties + +interface ValidatedObject { + fun validate(): Boolean = validateFields(this) + + companion object { + fun validateFields(obj: Any): Boolean { + obj.javaClass.declaredFields.forEach { field -> + field.isAccessible = true + val value = field.get(obj) + if (value is ValidatedObject && !value.validate()) { + return false + } + } + obj.javaClass.kotlin.memberProperties.forEach { property -> + val value = property.getter.call(obj) + if (value is ValidatedObject && !value.validate()) { + return false + } + } + return true + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/util/StringTools.kt b/src/main/kotlin/com/github/simiacryptus/util/StringTools.kt similarity index 97% rename from src/main/kotlin/com/github/simiacryptus/aicoder/util/StringTools.kt rename to src/main/kotlin/com/github/simiacryptus/util/StringTools.kt index 925b8b09..659e9e6e 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/util/StringTools.kt +++ b/src/main/kotlin/com/github/simiacryptus/util/StringTools.kt @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.util +package com.github.simiacryptus.util import java.nio.charset.Charset import java.util.* diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index 19371b3b..ed1daff0 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -133,6 +133,16 @@ + + + + + + diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/ProxyPlay.ws.kts b/src/test/kotlin/com/github/simiacryptus/aicoder/ProxyPlay.ws.kts deleted file mode 100644 index 96699fbd..00000000 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/ProxyPlay.ws.kts +++ /dev/null @@ -1,22 +0,0 @@ -import com.github.simiacryptus.aicoder.proxy.ProxyTest -import com.github.simiacryptus.aicoder.openai.proxy.ChatProxy -import com.github.simiacryptus.aicoder.openai.proxy.CompletionProxy -import com.intellij.openapi.util.io.FileUtil -import java.io.File - -val keyFile = File("C:\\Users\\andre\\code\\all-projects\\openai.key") -val chatProxy = ChatProxy(apiKey = FileUtil.loadFile(keyFile).trim(), apiLog = "api.log.json") -val completionProxy = CompletionProxy( - apiKey = FileUtil.loadFile(keyFile).trim(), - apiLog = FileUtil.loadFile(keyFile).trim() -) - -println(completionProxy.api.getEngines().joinToString("\n")) -val statement = "The meaning of life is to live a life of meaning." -val proxyFactory = chatProxy -val proxy = proxyFactory.create(ProxyTest.EssayAPI::class.java) -val essayOutline = proxy.essayOutline( - ProxyTest.EssayAPI.Thesis(statement), - "5000 words" -) -println(essayOutline.introduction!!.thesis.statement) diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AlternateHistorySimulator.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AlternateHistorySimulator.kt new file mode 100644 index 00000000..49ccb87c --- /dev/null +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AlternateHistorySimulator.kt @@ -0,0 +1,91 @@ +package com.github.simiacryptus.aicoder.proxy + +import org.junit.Test + +/** + * Simulate world nations interacting via events over time, + * and describe the changing world over time as a series of world news articles. + * (e.g. an alternate history of Europe from 1900 to 1950) + * + * Include: + * - events resulting from interactions with each other + * - random events + * - events caused by the passage of time + * - a description of the effects of each event + * - model political factions within each country, including the policies of each and how much control each has + * + * Optimize to reduce the total size (in serialized bytes) of each api call while ensuring each call has all the needed information to generate the response + */ +class AlternateHistorySimulator : GenerationReportBase(){ + @Test + fun generateAlternateHistory() { + runReport("AlternateHistory", AlternateHistory::class) { api, logJson, out -> + val initialWorldState = api.setupWorld( + startYear = 1900, + endYear = 1950, + nations = listOf("United Kingdom", "France", "Germany", "Russia", "Austria-Hungary", "Italy", "Ottoman Empire") + ) + logJson(initialWorldState) + var currentWorldState = initialWorldState + + for (year in initialWorldState.startYear..initialWorldState.endYear) { + val events = api.generateYearlyEvents(currentWorldState, year) + logJson(events) + val updatedWorldState = api.updateWorldState(currentWorldState, events.events) + logJson(updatedWorldState) + + out("Year: $year") + events.events.forEach { event -> + val eventDescription = api.describeEvent(event, updatedWorldState) + out("- $eventDescription") + } + currentWorldState = updatedWorldState + } + } + } + + interface AlternateHistory { + fun setupWorld( + startYear: Int, + endYear: Int, + nations: List + ): WorldState + + data class WorldState( + val startYear: Int = 0, + val endYear: Int = 0, + val nations: List = listOf(), + ) + + data class Nation( + val name: String = "", + val politicalFactions: List = listOf(), + val territory: List = listOf(), + val economy: String = "", + val military: String = "", + val diplomacy: Map = mapOf(), + val technology: String = "", + ) + + data class PoliticalFaction( + val name: String = "", + val policies: List = listOf(), + val influence: Int = 0, + ) + + fun generateYearlyEvents(worldState: WorldState, year: Int): Events + + data class Events( + val events: List = listOf(), + ) + fun updateWorldState(worldState: WorldState, events: List): WorldState + fun describeEvent(event: Event, worldState: WorldState): String + + data class Event( + val type: String = "", + val nation: String = "", + val details: Map = mapOf(), + ) + } + +} \ No newline at end of file diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoDevelop.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoDevelop.kt index aec5851d..ec22cb1f 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoDevelop.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoDevelop.kt @@ -1,9 +1,10 @@ package com.github.simiacryptus.aicoder.proxy -import com.github.simiacryptus.aicoder.openai.proxy.SoftwareProjectAI +import com.github.simiacryptus.aicoder.SoftwareProjectAI +import com.github.simiacryptus.aicoder.SoftwareProjectAI.Companion.parallelImplement +import com.github.simiacryptus.aicoder.SoftwareProjectAI.Companion.write import org.junit.Test -import java.util.zip.ZipEntry -import java.util.zip.ZipOutputStream +import java.util.* /** * AutoDevelop takes a software project description and generates a software project with all the necessary files. @@ -17,85 +18,208 @@ class AutoDevelop : GenerationReportBase() { } } + @Test + fun testMethodImplementation() { + runReport("SoftwareProject_Impl", SoftwareProjectAI::class) { api, logJson, out -> + val sourceCode = api.implementComponentSpecification( + proxy.fromJson( + SoftwareProjectAI.Project::class.java, + "{\r\n \"name\" : \"SupportAliasBot\",\r\n \"description\" : \"Slack bot to monitor a support alias\",\r\n \"language\" : \"Kotlin\",\r\n \"features\" : [ \"Record all requests tagging an alias in a database\", \"Send a message to a Slack channel when requests are tagged with a specific label\" ],\r\n \"libraries\" : [ \"Gradle\", \"Spring\" ],\r\n \"buildTools\" : [ \"Gradle\" ]\r\n}" + ), + proxy.fromJson( + SoftwareProjectAI.ComponentDetails::class.java, + "{\r\n \"description\" : \"Main class for the SupportAliasBot\",\r\n \"requires\" : [ ],\r\n \"publicProperties\" : [ ],\r\n \"publicMethodSignatures\" : [ \"fun main(args: Array): Unit\" ],\r\n \"language\" : \"Kotlin\",\r\n \"location\" : {\r\n \"file\" : \"src/main/kotlin/com/example/supportaliasbot/SupportAliasBot.kt\"\r\n }\r\n}" + ), + listOf(), + proxy.fromJson( + SoftwareProjectAI.CodeSpecification::class.java, + "{\r\n \"name\" : \"SupportAliasBot\",\r\n \"description\" : \"Slack bot to monitor a support alias\",\r\n \"language\" : \"Kotlin\",\r\n \"features\" : [ \"Record all requests tagging an alias in a database\", \"Send a message to a Slack channel when requests are tagged with a specific label\" ],\r\n \"libraries\" : [ \"Gradle\", \"Spring\" ],\r\n \"buildTools\" : [ \"Gradle\" ]\r\n}" + ), + ) + out(""" + |```${sourceCode.language} + |${sourceCode.code} + |``` + |""".trimMargin()) + } + } + private fun report( api: SoftwareProjectAI, - logJson: (Any) -> Unit, - out: (Any) -> Unit + logJson: (Any?) -> Unit, + out: (Any?) -> Unit ) { - val project = api.newProject( - """ - | - |Slack bot to monitor a support alias and automatically respond to common questions - | - |Language: Kotlin - | - """.trimMargin().trim() - ) - logJson(project) + val drafts = 2 + val threads = 7 + proxy.temperature = 0.5 + + @Suppress("JoinDeclarationAndAssignment") + val description: String + description = """ + | + |Slack bot to monitor a support alias + |All requests tagging an alias are recorded in a database + |When requests are tagged with a specific label, the bot will send a message to a slack channel + |Fully implement all functions + |Do not comment code + |Include documentation and build scripts + | + |Language: Kotlin + |Frameworks: Gradle, Spring + | + """.trimMargin() + """ + | + |Create a website where users can upload stories, share them, and rate them + | + |Fully implement all functions + |Do not comment code + |Include documentation and build scripts + | + |Language: Kotlin + |Frameworks: Gradle, Spring + | + """.trimMargin() out( """ - | - |# ${project.name} - | - |${project.description} - | - |Language: ${project.language} - | - |Libraries: ${project.libraries.joinToString(", ")} - | - |Build Tools: ${project.buildTools.joinToString(", ")} - | - |""".trimMargin() + | + |# Software Project Development Report + | + |## Description + | + |``` + |${description.trim()}} + |``` + | + |""".trimMargin() ) - val requirements = api.getProjectStatements(project) - logJson(requirements) - val projectDesign = api.buildProjectDesign(project, requirements) - logJson(projectDesign) - val files = api.buildProjectFileSpecifications(project, requirements, projectDesign) - logJson(files) + var project: SoftwareProjectAI.Project? = null + var requirements: SoftwareProjectAI.ProjectStatements? = null + var projectDesign: SoftwareProjectAI.ProjectDesign? = null + var components: Map>? = null + var documents: Map>? = + null + var tests: Map>? = null + var implementations: Map? = null + try { + project = api.newProject(description.trim()) + out( + """ + | + |Project Name: ${project.name} + | + |Description: ${project.description} + | + |Language: ${project.language} + | + |Libraries: ${project.libraries?.joinToString(", ")} + | + |Build Tools: ${project.buildTools?.joinToString(", ")} + | + |""".trimMargin() + ) + logJson(project) + requirements = api.getProjectStatements(description.trim(), project) + out( + """ + | + |## Requirements + | + |""".trimMargin() + ) + logJson(requirements) - val zipArchiveFile = outputDir.resolve("projects/${project.name}.zip") - zipArchiveFile.parentFile.mkdirs() - out( - """ + projectDesign = api.buildProjectDesign(project, requirements) + out( + """ + | + |## Design + | + |""".trimMargin() + ) + logJson(projectDesign) + components = + projectDesign.components?.map { + it to (api.buildComponentFileSpecifications( + project, + requirements, + it + )) + }?.toMap() + out( + """ + | + |## Components + | + |""".trimMargin() + ) + logJson(components) + + documents = + projectDesign.documents?.map { + it to (api.buildDocumentationFileSpecifications( + project, + requirements, + it + ) + ) + }?.toMap() + out( + """ + | + |## Documents + | + |""".trimMargin() + ) + logJson(documents) + tests = projectDesign.tests?.map { + it to (api.buildTestFileSpecifications(project, requirements, it)) + }?.toMap() + out( + """ + | + |## Tests + | + |""".trimMargin() + ) + logJson(tests) + implementations = parallelImplement(api, project, components, documents, tests, drafts, threads) + } catch (e: Exception) { + e.printStackTrace() + } + + if (implementations != null) { + val name = (project?.name ?: UUID.randomUUID().toString()).replace(Regex("[^a-zA-Z0-9]"), "_") + val relative = "projects/$name.zip" + val zipArchiveFile = outputDir.resolve(relative) + zipArchiveFile.parentFile.mkdirs() + write(zipArchiveFile, implementations) + out( + """ + | + |## Project Files + | + |[Download]($relative) + | + |""".trimMargin() + ) + implementations.toList().sortedBy { it.first.file }.forEach { (file, sourceCodes) -> + out( + """ | - |## Project Files + |### ${file.file} | - |[Download](projects/${project.name}.zip) + |```${sourceCodes!!.language?.lowercase()} + |${sourceCodes.code} + |``` | |""".trimMargin() - ) - ZipOutputStream(zipArchiveFile.outputStream()).use { zip -> - for (file in files.files) { - } - for (file in files.files) { - val sourceCode = api.implement( - project, - files.files.map { it.location }.filter { file.requires.contains(it) }.toList(), - file - ) - zip.putNextEntry(ZipEntry(file.location.toString())) - zip.write(sourceCode.code.toByteArray()) - zip.closeEntry() - out( - """ - | - |## ${file.location.name}.${file.location.extension} - | - |${file.description} - | - |```${sourceCode.language} - |${sourceCode.code} - |``` - | - |""".trimMargin() ) } } - } - + } } \ No newline at end of file diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoNews.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoNews.kt index 051f2eab..6f9154b3 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoNews.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoNews.kt @@ -8,6 +8,81 @@ import org.junit.Test * cover a story, and generate a report. */ class AutoNews : GenerationReportBase() { + @Test + fun newsWebsite() { + runReport("News", News::class) { api, logJson, out -> + + val publication = News.Publication( + description = "A humorous celebration of the absurdity of modern life", + tags = listOf("satire", "funny", "irish", "st patrick's day"), + name = "St Patty's Journal", + publishing_date = "2023-03-16 (St Patrick's Day)", + ) + val categories = listOf( + "politics", + "science", + "technology", + "business", + "finance", + ) + logJson(publication) + out( + """ + | + |# ${publication.name} + | + |${publication.description} + | + |Tags: ${publication.tags.joinToString(", ")} + | + |""".trimMargin() + ) + for (category in categories) { + out( + """ + | + |## ${category.capitalize()} + | + |""".trimMargin() + ) + try { + val stories = api.getStories(publication, category) + logJson(stories) + for (story in stories.stories) { + try { + val article = api.coverStory(publication, story) + logJson(article) + out( + """ + | + |### ${article.title} + | + |![${story.image!!.detailedCaption}](${ + writeImage( + proxy.api.render( + story.image.detailedCaption, + resolution = 512 + )[0] + ) + }) + | + |${article.content.joinToString("\n\n")} + | + |Keywords: ${article.keywords} + | + |""".trimMargin() + ) + } catch (e: Throwable) { + e.printStackTrace() + } + } + } catch (e: Throwable) { + e.printStackTrace() + } + } + } + } + interface News { fun getPublication(publicationName: String): Publication @@ -22,7 +97,7 @@ class AutoNews : GenerationReportBase() { fun getStories( publication: Publication, category: String, - storyCount: Int = 3, + storyCount: Int = 5, funny: Boolean = true ): StoryList @@ -60,82 +135,4 @@ class AutoNews : GenerationReportBase() { ) } - - @Test - fun newsWebsite() { - runReport("News", News::class) { api, logJson, out -> - val categories = listOf( - "politics", - "science", - "technology", - "entertainment", - "sports", - "business", - "health", - "travel", - "food", - "weather", - "fashion", - "lifestyle", - "finance", - "education", - "environment", - "religion", - ) - val publication = News.Publication( - description = "A humorous exploration of the unknown", - tags = listOf("satire", "funny", "science fiction", "future"), - name = "Beyond Imagination", - publishing_date = "2023-07-04", - ) - logJson(publication) - out( - """ - | - |# ${publication.name} - | - |${publication.description} - | - |Tags: ${publication.tags.joinToString(", ")} - | - |""".trimMargin() - ) - for (category in categories) { - out( - """ - | - |## ${category.capitalize()} - | - |""".trimMargin() - ) - val stories = api.getStories(publication, category) - logJson(stories) - for (story in stories.stories) { - val article = api.coverStory(publication, story) - logJson(article) - out( - """ - | - |### ${article.title} - | - |![${story.image!!.detailedCaption}](${ - writeImage( - proxy.api.text_to_image( - story.image.detailedCaption, - resolution = 512 - )[0] - ) - }) - | - |${article.content.joinToString("\n\n")} - | - |Keywords: ${article.keywords} - | - |""".trimMargin() - ) - } - } - } - } } - diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ChildrensStory.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ChildrensStory.kt new file mode 100644 index 00000000..d2958f86 --- /dev/null +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ChildrensStory.kt @@ -0,0 +1,166 @@ +package com.github.simiacryptus.aicoder.proxy + +import org.junit.Test + +/** + * ChildrenStory creates an illustrated story for children. + * It first generates a plot, consisting of a protagonist, a setting, and a conflict. + * It teaches a lesson about the theme. + * It then generates a cast of characters, and some plot twists. + * Finally, it generates a series of illustrations, one for each page, along with the text for each page. + */ +class ChildrensStory : GenerationReportBase(){ + interface Story { + + fun generatePlot(theme: String): Plot + + data class Plot( + val protagonist: Character = Character(), + val setting: Setting = Setting(), + val conflict: Conflict = Conflict(), + val theme: String = "", + ) + + data class Character( + val name: String = "", + val age: Int = 0, + val gender: String = "", + val description: String = "", + ) + + data class Setting( + val location: String = "", + val time: String = "", + val description: String = "", + ) + + data class Conflict( + val type: String = "", + val description: String = "", + ) + + fun generateCharacters(plot: Plot): CharacterList + + data class CharacterList( + val characters: List = listOf(), + ) + + fun generateTwists(plot: Plot): TwistList + + data class TwistList( + val twists: List = listOf(), + ) + + data class Twist( + val type: String = "", + val description: String = "", + ) + + fun generateIllustrations(plot: Plot): IllustrationList + + data class IllustrationList( + val illustrations: List = listOf(), + ) + + data class Illustration( + val image: ImageDescription = ImageDescription(), + val text: String = "", + ) + + data class ImageDescription( + val style: String = "", + val subject: String = "", + val background: String = "", + val detailedCaption: String = "", + ) + + } + + @Test + fun childrenStory() { + runReport("Children Story", Story::class) { api, logJson, out -> + val theme = "friendship" + val plot = api.generatePlot(theme) + logJson(plot) + out( + """ + | + |# ${plot.protagonist.name} and the ${plot.theme.capitalize()} + | + |${plot.protagonist.name} is a ${plot.protagonist.age} year old ${plot.protagonist.gender} living in ${plot.setting.location}. + | + |${plot.protagonist.description} + | + |${plot.setting.description} + | + |${plot.protagonist.name} is faced with a ${plot.conflict.type} - ${plot.conflict.description} + | + |""".trimMargin() + ) + val characters = api.generateCharacters(plot).characters + logJson(characters) + out( + """ + | + |## Characters + | + |${characters.joinToString("\n\n") { + """ + | + |### ${it.name} + | + |${it.description} + | + |""".trimMargin() + }} + | + |""".trimMargin() + ) + val twists = api.generateTwists(plot).twists + logJson(twists) + out( + """ + | + |## Plot Twists + | + |${twists.joinToString("\n\n") { + """ + | + |### ${it.type} + | + |${it.description} + | + |""".trimMargin() + }} + | + |""".trimMargin() + ) + val illustrations = api.generateIllustrations(plot).illustrations + logJson(illustrations) + out( + """ + | + |## Story + | + |${illustrations.joinToString("\n\n") { + """ + | + |### ${it.text} + | + |![${it.image.detailedCaption}](${ + writeImage( + proxy.api.render( + it.image.detailedCaption, + resolution = 512 + )[0] + ) + }) + | + |""".trimMargin() + }} + | + |""".trimMargin() + ) + } + } +} \ No newline at end of file diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ComicBook.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ComicBook.kt index c2210d4f..8359b59f 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ComicBook.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ComicBook.kt @@ -84,7 +84,7 @@ class ComicBook: GenerationReportBase() { | |![${character.image!!.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( character.image.detailedCaption, resolution = 512 )[0] @@ -132,7 +132,7 @@ class ComicBook: GenerationReportBase() { | |![${page.image!!.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( page.image.detailedCaption, resolution = 512 )[0] diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/DebateJudge.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/DebateSimulator.kt similarity index 61% rename from src/test/kotlin/com/github/simiacryptus/aicoder/proxy/DebateJudge.kt rename to src/test/kotlin/com/github/simiacryptus/aicoder/proxy/DebateSimulator.kt index 7cc84815..fd79c638 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/DebateJudge.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/DebateSimulator.kt @@ -1,12 +1,53 @@ package com.github.simiacryptus.aicoder.proxy import org.junit.Test -import java.io.BufferedWriter -import java.io.File -import java.time.LocalDateTime -import java.time.format.DateTimeFormatter -class DebateJudge { +/** + * Simulate several participants debating a topic with a moderator + */ +class DebateSimulator : GenerationReportBase() { + + @Test + fun judgeDebate() { + runReport("Debate", Debate::class) { api, logJson, out -> + val debate = api.newRandomDebate( + topic = "What is the best way to solve a problem?", + participantNames = listOf("Socrates", "Buddha", "Jesus", "Confucius", "Nietzsche"), + ) + logJson(debate) + for (question in debate.questions) { + val argument = api.poseQuestion(debate, question) + val dialog = listOf(argument).toMutableList() + logJson(argument) + api.writeArgumentText(debate, argument).let { spokenText -> + out( + """ + | + |${spokenText.speaker}: ${spokenText.text} + | + |""".trimMargin() + ) + } + debate.participants.map { it.name }.filter { it != argument.speaker }.shuffled().forEach { speaker -> + val rebuttal = + api.rebuttal(debate, question, speaker, Debate.DebateArguments(dialog.takeLast(1))) + logJson(rebuttal) + api.writeArgumentText(debate, rebuttal).let { spokenText -> + out( + """ + | + |${spokenText.speaker}: ${spokenText.text} + | + |""".trimMargin() + ) + } + dialog.add(rebuttal) + } + logJson(api.judgeDebate(debate, Debate.DebateArguments(dialog))) + } + } + } + interface Debate { fun newRandomDebate( participantNames: List = listOf(""), @@ -26,6 +67,14 @@ class DebateJudge { val writingStyle: String = "", ) + /** + * + * A data class representing the judgement of a debate. + * + * @property winner the team judged to have won the debate + * @property reasoning the judge's reasoning for the judgement + * @property pointsAwarded the number of points awarded to the winner + */ data class DebateJudgement( val winner: String = "", val reasoning: String = "", val pointsAwarded: Int = 0 ) @@ -54,75 +103,5 @@ class DebateJudge { ) } - @Test - fun judgeDebate() { - if (!ProxyTest.keyFile.exists()) return - val outputDir = File(".") - val markdownOutputFile = File( - outputDir, "Debate_${LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd_HH-mm-ss"))}.md" - ) - BufferedWriter(markdownOutputFile.writer()).use { writer -> - fun out(s: Any) { - println(s.toString()) - writer.write(s.toString()) - writer.newLine() - } - - val proxy = ProxyTest.chatProxy() - fun logJson( - obj: Any - ) { - out( - """ - |```json - |${proxy.toJson(obj)} - |``` - |""".trimMargin() - ) - } - - val file = File("debate.examples.json") - if (file.exists()) proxy.addExamples(file) - val debateApi = proxy.create(Debate::class.java) - val debate = debateApi.newRandomDebate( -// topic = "What is the secret to a happy life?", - topic = "What is the best way to solve a problem?", - participantNames = listOf("Socrates", "Buddha", "Jesus", "Confucius", "Nietzsche"), - ) - logJson(debate) - for (question in debate.questions) { - val argument = debateApi.poseQuestion(debate, question) - val dialog = listOf(argument).toMutableList() - logJson(argument) - debateApi.writeArgumentText(debate, argument).let { spokenText -> - out( - """ - | - |${spokenText.speaker}: ${spokenText.text} - | - |""".trimMargin() - ) - } - debate.participants.map { it.name }.filter { it != argument.speaker }.shuffled().forEach { speaker -> - val rebuttal = - debateApi.rebuttal(debate, question, speaker, Debate.DebateArguments(dialog.takeLast(1))) - logJson(rebuttal) - debateApi.writeArgumentText(debate, rebuttal).let { spokenText -> - out( - """ - | - |${spokenText.speaker}: ${spokenText.text} - | - |""".trimMargin() - ) - } - dialog.add(rebuttal) - } - logJson(debateApi.judgeDebate(debate, Debate.DebateArguments(dialog))) - } - } - } - } - diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/FamilyGuyWriter.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/FamilyGuyWriter.kt new file mode 100644 index 00000000..55368e00 --- /dev/null +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/FamilyGuyWriter.kt @@ -0,0 +1,146 @@ +package com.github.simiacryptus.aicoder.proxy + +import org.junit.Test + +/** + * Write an episode of "Family Guy", printing the output as a screenplay in markdown format + * + * Include: + * - an overall plot with conflict and character development + * - entertaining dialog with images + * - random events + * - humourous cutscenes, with an image each + * + * Optimize to reduce the total size (in serialized bytes) of each api call while ensuring each call has all the needed information to generate the response + */ +class FamilyGuyWriter : GenerationReportBase(){ + @Test + fun writeEpisode() { + runReport("FamilyGuyEpisode", FamilyGuy::class) { api, logJson, out -> + val episodeInfo = api.generateEpisodeInfo() + logJson(episodeInfo) + + out( + """ + |# Family Guy + |## Episode Title: ${episodeInfo.title} + |### Written by: ${episodeInfo.writer} + |""".trimMargin() + ) + + val acts = api.generateActs(episodeInfo) + logJson(acts) + + for ((actIndex, act) in acts.acts.withIndex()) { + out("\n## Act ${actIndex + 1}\n") + val scenes = api.generateScenes(episodeInfo, act) + logJson(scenes) + + for ((sceneIndex, scene) in scenes.scenes.withIndex()) { + out("\n### Scene ${sceneIndex + 1}: ${scene.location}\n") + val dialogues = api.generateDialogues(episodeInfo, act, scene) + logJson(dialogues) + + for (dialogue in dialogues.dialogues) { + val character = dialogue.character.capitalize() + val text = dialogue.text + out("${character}: $text\n") + + if (dialogue.hasImage) { + val imageCaption = api.generateImageCaption(dialogue) + logJson(imageCaption) + + out( + """ + |![${imageCaption.caption}](${ + writeImage( + proxy.api.render( + imageCaption.caption, + resolution = 512 + )[0] + ) + }) + |""".trimMargin() + ) + } + } + + if (scene.hasCutaway) { + val cutaway = api.generateCutaway(episodeInfo, act, scene) + logJson(cutaway) + + out("\n*Cutaway: ${cutaway.description}*\n") + out( + """ + |![${cutaway.imageCaption}](${ + writeImage( + proxy.api.render( + cutaway.imageCaption, + resolution = 512 + )[0] + ) + }) + |""".trimMargin() + ) + } + } + } + } + } + + interface FamilyGuy { + + fun generateEpisodeInfo(): EpisodeInfo + + data class EpisodeInfo( + val title: String = "", + val writer: String = "", + ) + + fun generateActs(episodeInfo: EpisodeInfo): ActList + + data class ActList( + val acts: List = listOf(), + ) + + data class Act( + val id: Int = 0, + ) + + fun generateScenes(episodeInfo: EpisodeInfo, act: Act): SceneList + + data class SceneList( + val scenes: List = listOf(), + ) + + data class Scene( + val location: String = "", + val hasCutaway: Boolean = false, + ) + + fun generateDialogues(episodeInfo: EpisodeInfo, act: Act, scene: Scene): DialogueList + + data class DialogueList( + val dialogues: List = listOf(), + ) + + data class Dialogue( + val character: String = "", + val text: String = "", + val hasImage: Boolean = false, + ) + + fun generateImageCaption(dialogue: Dialogue): ImageCaption + + data class ImageCaption( + val caption: String = "", + ) + + fun generateCutaway(episodeInfo: EpisodeInfo, act: Act, scene: Scene): Cutaway + + data class Cutaway( + val description: String = "", + val imageCaption: String = "", + ) + } +} \ No newline at end of file diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/GenerationReportBase.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/GenerationReportBase.kt index 4f53450a..aa5674ee 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/GenerationReportBase.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/GenerationReportBase.kt @@ -15,20 +15,23 @@ open class GenerationReportBase { }.log.json" ) val outputDir = File("../intellij-aicoder-docs") - fun runReport(prefix: String, kClass: KClass, fn: (T, (Any) -> Unit, (Any) -> Unit) -> Unit) { + fun runReport(prefix: String, kClass: KClass, fn: (T, (Any?) -> Unit, (Any?) -> Unit) -> Unit) { if (!ProxyTest.keyFile.exists()) return val markdownOutputFile = File( outputDir, "${prefix}_${LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd_HH-mm-ss"))}.md" ) BufferedWriter(markdownOutputFile.writer()).use { writer -> - fun out(s: Any) { + fun out(s: Any?) { + if (null == s) return println(s.toString()) writer.write(s.toString()) writer.newLine() + writer.flush() } - fun logJson(obj: Any) { + fun logJson(obj: Any?) { + if (null == obj) return out( """ |```json diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ImageTest.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ImageTest.kt index ccb938f8..7c4fc49f 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ImageTest.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ImageTest.kt @@ -7,7 +7,7 @@ import java.io.File class ImageTest: GenerationReportBase() { @Test fun imageGenerationTest() { - val image = proxy.api.text_to_image("Hello World")[0] + val image = proxy.api.render("Hello World")[0] // Write the image to a file val file = File.createTempFile("image", ".png") javax.imageio.ImageIO.write(image, "png", file) diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/InternationalEventsSimulator.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/InternationalEventsSimulator.kt new file mode 100644 index 00000000..a12ef90a --- /dev/null +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/InternationalEventsSimulator.kt @@ -0,0 +1,73 @@ +package com.github.simiacryptus.aicoder.proxy + +import org.junit.Test + +/** + * Simulate world nations interacting via news events over time and describe the changing relationships and motivations and causes of each event + */ +class InternationalEventsSimulator : GenerationReportBase() { + interface InternationalEvents { + fun generateInitialWorldState( + nationNames: List = listOf(), + initialRelations: List = listOf(), + ): WorldState + + data class WorldState( + val nations: List = listOf(), + val relations: List = listOf(), + val events: List = listOf(), + ) + + data class Nation( + val name: String = "", + val government: String = "", + val economy: String = "", + val military: String = "", + val culture: String = "", + ) + + data class NationRelation( + val nation1: String = "", + val nation2: String = "", + val relationStatus: String = "", + ) + + data class Event( + val involvedNations: List = listOf(), + val eventType: String = "", + val cause: String = "", + val effect: String = "", + val date: String = "", + ) + + fun generateNewEvent(worldState: WorldState): Event + fun updateWorldState(worldState: WorldState, newEvent: Event): WorldState + + fun describeEvent(event: Event): String + } + + @Test + fun simulateInternationalEvents() { + runReport("International Events", InternationalEvents::class) { api, logJson, out -> + val initialWorldState = api.generateInitialWorldState( + nationNames = listOf("USA", "China", "Russia", "UK", "France", "Germany", "Japan", "India", "Brazil", "Canada"), + ) + logJson(initialWorldState) + var worldState = initialWorldState + + for (i in 1..10) { + val newEvent = api.generateNewEvent(worldState) + logJson(newEvent) + out( + """ + | + |Event $i: ${api.describeEvent(newEvent)} + | + |""".trimMargin() + ) + worldState = api.updateWorldState(worldState, newEvent) + logJson(worldState) + } + } + } +} \ No newline at end of file diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ProxyTest.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ProxyTest.kt index 6457a993..eb9f0a0b 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ProxyTest.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ProxyTest.kt @@ -1,22 +1,31 @@ package com.github.simiacryptus.aicoder.proxy -import com.github.simiacryptus.aicoder.openai.proxy.ChatProxy -import com.github.simiacryptus.aicoder.openai.proxy.CompletionProxy +import com.github.simiacryptus.openai.proxy.ChatProxy +import com.github.simiacryptus.openai.proxy.CompletionProxy import com.intellij.openapi.util.io.FileUtil +import com.jetbrains.rd.util.LogLevel import org.junit.Test import java.io.File +import java.time.LocalDateTime +import java.time.format.DateTimeFormatter class ProxyTest { companion object { val keyFile = File("C:\\Users\\andre\\code\\all-projects\\openai.key") - fun chatProxy(apiLog: String = "api.log.json"): ChatProxy = ChatProxy( - apiKey = FileUtil.loadFile(keyFile).trim(), - apiLog = apiLog - ) + fun chatProxy(apiLog: String = "api.${ + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd_HH-mm-ss")) + }.log.json"): ChatProxy = ChatProxy( + apiKey = FileUtil.loadFile(keyFile).trim(), + apiLog = apiLog, + logLevel = LogLevel.Warn, + model = "gpt-3.5-turbo-0301", + //model = "gpt-4-0314", + maxTokens = 8912 + ) fun completionProxy(apiLog: String = "api.log.json"): CompletionProxy = CompletionProxy( - apiKey = FileUtil.loadFile(keyFile).trim(), - apiLog = apiLog - ) + apiKey = FileUtil.loadFile(keyFile).trim(), + apiLog = apiLog + ) } interface TestAPI { diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/RecipeBook.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/RecipeBook.kt new file mode 100644 index 00000000..c2bd932c --- /dev/null +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/RecipeBook.kt @@ -0,0 +1,116 @@ +package com.github.simiacryptus.aicoder.proxy + +import org.junit.Test + +/** + * RecipeBook builds a recipe book centered around a theme. + * Recipes are organized by main ingredient. + * Each recipe has a title, a list of ingredients, and a list of steps. + * Also included are substitutions for ingredients, a list of related recipes, and cooking tips. + */ +class RecipeBook : GenerationReportBase() { + interface Recipes { + + fun getRecipes( + theme: String, + ingredient: String, + recipeCount: Int = 10 + ): RecipeList + + data class RecipeList( + val recipes: List = listOf(), + val theme: String = "", + val ingredient: String = "", + ) + + data class Recipe( + val title: String = "", + val ingredients: List = listOf(), + val steps: List = listOf(), + val substitutions: List = listOf(), + val relatedRecipes: List = listOf(), + val cookingTips: List = listOf(), + val image: ImageDescription? = null, + ) + + data class ImageDescription( + val style: String = "", + val subject: String = "", + val background: String = "", + val detailedCaption: String = "", + ) + + } + + @Test + fun recipeBook() { + runReport("Recipes", Recipes::class) { api, logJson, out -> + val theme = "Italian" + val ingredients = listOf( + "beef", + "chicken", + "pasta", + "potatoes", + ) + out( + """ + | + |# ${theme.capitalize()} Recipes + | + |""".trimMargin() + ) + for (ingredient in ingredients) { + out( + """ + | + |## Recipes with $ingredient + | + |""".trimMargin() + ) + try { + val recipes = api.getRecipes(theme, ingredient) + logJson(recipes) + for (recipe in recipes.recipes) { + out( + """ + | + |### ${recipe.title} + | + |![${recipe.image!!.detailedCaption}](${ + writeImage( + proxy.api.render( + recipe.image.detailedCaption, + resolution = 512 + )[0] + ) + }) + |Ingredients: + | + |${recipe.ingredients.joinToString("\n")} + | + |Steps: + | + |${recipe.steps.joinToString("\n")} + | + |Substitutions: + | + |${recipe.substitutions.joinToString("\n")} + | + |Related Recipes: + | + |${recipe.relatedRecipes.joinToString("\n")} + | + |Cooking Tips: + | + |${recipe.cookingTips.joinToString("\n")} + | + |""".trimMargin() + ) + } + } catch (e: Throwable) { + e.printStackTrace() + } + } + } + } +} \ No newline at end of file diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/TravelGuide.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/TravelGuide.kt new file mode 100644 index 00000000..4de5de84 --- /dev/null +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/TravelGuide.kt @@ -0,0 +1,60 @@ +package com.github.simiacryptus.aicoder.proxy + +import org.junit.Test + +/** + * TravelGuide builds a travel guide for a given destination. + * The guide includes information about the destination, attractions, restaurants, and activities. + * Also included are tips for getting around, a list of related destinations, and a list of recommended hotels. + */ +class TravelGuide : GenerationReportBase() { + interface Travel { + + fun getDestination(destinationName: String): Destination + + data class Destination( + val name: String = "", + val description: String = "", + val attractions: List = listOf(), + val restaurants: List = listOf(), + val activities: List = listOf(), + val tips: List = listOf(), + val relatedDestinations: List = listOf(), + val recommendedHotels: List = listOf(), + ) + + data class Attraction( + val name: String = "", + val description: String = "", + val image: ImageDescription? = null, + ) + + data class Restaurant( + val name: String = "", + val description: String = "", + val image: ImageDescription? = null, + ) + + data class Activity( + val name: String = "", + val description: String = "", + val image: ImageDescription? = null, + ) + + data class ImageDescription( + val style: String = "", + val subject: String = "", + val background: String = "", + val detailedCaption: String = "", + ) + + } + + @Test + fun travelGuide() { + runReport("Travel", Travel::class) { api, logJson, out -> + + } + } +} + diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/VideoGame.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/VideoGame.kt index 3cd1ee49..cfab3eb8 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/VideoGame.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/VideoGame.kt @@ -90,7 +90,7 @@ class VideoGame : GenerationReportBase() { | |![${character.image!!.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( character.image.detailedCaption, resolution = 512 )[0] @@ -138,7 +138,7 @@ class VideoGame : GenerationReportBase() { | |![${level.image!!.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( level.image.detailedCaption, resolution = 512 )[0]