From 66a99e50ff458d4ee408ba17a8f7ddd3295cc2cd Mon Sep 17 00:00:00 2001 From: Andrew Charneski Date: Mon, 20 Mar 2023 23:34:53 -0400 Subject: [PATCH] 1.0.18 --- CHANGELOG.md | 7 + gradle.properties | 2 +- .../simiacryptus/aicoder/SoftwareProjectAI.kt | 360 ++++++++++++ .../actions/dev/GenerateProjectAction.kt | 30 +- .../aicoder/config/AppSettingsState.kt | 2 + .../aicoder/openai/async/AsyncAPI.kt | 44 +- .../aicoder/openai/async/AsyncAPIImpl.kt | 6 +- .../{CoreAPIImpl.kt => OpenAIClientImpl.kt} | 6 +- .../aicoder/openai/ui/OpenAI_API.kt | 28 +- .../com/github/simiacryptus/openai/CoreAPI.kt | 516 ------------------ .../simiacryptus/openai/HttpClientManager.kt | 138 +++++ .../simiacryptus/openai/OpenAIClient.kt | 472 ++++++++++++++++ .../simiacryptus/openai/proxy/ChatProxy.kt | 38 +- .../openai/proxy/CompletionProxy.kt | 35 +- .../openai/proxy/{Notes.kt => Description.kt} | 2 +- .../simiacryptus/openai/proxy/GPTProxyBase.kt | 317 +++++------ .../openai/proxy/SoftwareProjectAI.kt | 294 ---------- .../openai/proxy/ValidatedObject.kt | 26 + .../simiacryptus/aicoder/proxy/AutoDevelop.kt | 122 +++-- .../simiacryptus/aicoder/proxy/AutoNews.kt | 2 +- .../aicoder/proxy/ChildrensStory.kt | 2 +- .../simiacryptus/aicoder/proxy/ComicBook.kt | 4 +- .../aicoder/proxy/FamilyGuyWriter.kt | 4 +- .../simiacryptus/aicoder/proxy/ImageTest.kt | 2 +- .../simiacryptus/aicoder/proxy/RecipeBook.kt | 2 +- .../simiacryptus/aicoder/proxy/VideoGame.kt | 4 +- 26 files changed, 1312 insertions(+), 1153 deletions(-) create mode 100644 src/main/kotlin/com/github/simiacryptus/aicoder/SoftwareProjectAI.kt rename src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/{CoreAPIImpl.kt => OpenAIClientImpl.kt} (77%) delete mode 100644 src/main/kotlin/com/github/simiacryptus/openai/CoreAPI.kt create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/HttpClientManager.kt create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/OpenAIClient.kt rename src/main/kotlin/com/github/simiacryptus/openai/proxy/{Notes.kt => Description.kt} (63%) delete mode 100644 src/main/kotlin/com/github/simiacryptus/openai/proxy/SoftwareProjectAI.kt create mode 100644 src/main/kotlin/com/github/simiacryptus/openai/proxy/ValidatedObject.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b7cdae3..558ddc8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ ### Added - +## [1.0.18] + +### Improved +- API stability and performance +- Various bug fixes +- Max tokens handling + ## [1.0.17] ### Added diff --git a/gradle.properties b/gradle.properties index 445fae3c..0cbaaee5 100644 --- a/gradle.properties +++ b/gradle.properties @@ -4,7 +4,7 @@ pluginGroup = com.github.simiacryptus pluginName = intellij-aicoder pluginRepositoryUrl = https://github.com/SimiaCryptus/intellij-aicoder # SemVer format -> https://semver.org -pluginVersion = 1.0.17 +pluginVersion = 1.0.18 # Supported build number ranges and IntelliJ Platform versions -> https://plugins.jetbrains.com/docs/intellij/build-number-ranges.html pluginSinceBuild = 203 diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/SoftwareProjectAI.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/SoftwareProjectAI.kt new file mode 100644 index 00000000..a93a5cac --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/SoftwareProjectAI.kt @@ -0,0 +1,360 @@ +package com.github.simiacryptus.aicoder + +import com.github.simiacryptus.openai.proxy.Description +import com.github.simiacryptus.openai.proxy.ValidatedObject +import java.io.File +import java.util.* +import java.util.concurrent.Callable +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Executors +import java.util.concurrent.Future +import java.util.concurrent.atomic.AtomicInteger +import java.util.zip.ZipEntry +import java.util.zip.ZipOutputStream + +interface SoftwareProjectAI { + + fun newProject(description: String): Project + + data class Project( + val name: String? = "", + val description: String? = "", + val language: String? = "", + val features: List? = listOf(), + val libraries: List? = listOf(), + val buildTools: List? = listOf(), + ) : ValidatedObject + + fun getProjectStatements(description: String, project: Project): ProjectStatements + + data class ProjectStatements( + val assumptions: List? = listOf(), + val designPatterns: List? = listOf(), + val requirements: List? = listOf(), + val risks: List? = listOf(), + ) : ValidatedObject + + fun buildProjectDesign(project: Project, requirements: ProjectStatements): ProjectDesign + + data class ProjectDesign( + @Description("Major components e.g. 'core', 'ui', 'api', etc.") + val components: List? = listOf(), + @Description("Documentation files e.g. README.md, LICENSE, etc.") + val documents: List? = listOf(), + @Description("Individual test cases") + val tests: List? = listOf(), + ) : ValidatedObject + + data class ComponentDetails( + val name: String? = "", + val description: String? = "", + val features: List? = listOf(), + ) : ValidatedObject + + data class TestDetails( + val name: String? = "", + val steps: List? = listOf(), + val expectations: List? = listOf(), + ) : ValidatedObject + + data class DocumentationDetails( + val name: String? = "", + val description: String? = "", + val sections: List? = listOf(), + ) : ValidatedObject + + fun buildProjectFileSpecifications( + project: Project, + requirements: ProjectStatements, + design: ProjectDesign, + recursive: Boolean = true + ): List + + fun buildComponentFileSpecifications( + project: Project, + requirements: ProjectStatements, + design: ComponentDetails, + recursive: Boolean = true + ): List + + fun buildTestFileSpecifications( + project: Project, + requirements: ProjectStatements, + design: TestDetails, + recursive: Boolean = true + ): List + + fun buildDocumentationFileSpecifications( + project: Project, + requirements: ProjectStatements, + design: DocumentationDetails, + recursive: Boolean = true + ): List + + data class CodeSpecification( + val description: String? = "", + val requires: List? = listOf(), + val publicProperties: List? = listOf(), + val publicMethodSignatures: List? = listOf(), + val language: String? = "", + val location: FilePath? = FilePath(), + ) : ValidatedObject + + data class DocumentSpecification( + val description: String? = "", + val requires: List? = listOf(), + val sections: List? = listOf(), + val language: String? = "", + val location: FilePath? = FilePath(), + ) : ValidatedObject + + data class TestSpecification( + val description: String? = "", + val requires: List? = listOf(), + val steps: List? = listOf(), + val expectations: List? = listOf(), + val language: String? = "", + val location: FilePath? = FilePath(), + ) : ValidatedObject + + data class FilePath( + @Description("File name relative to project root, e.g. src/main/java/Foo.java") + val file: String? = "", + ) : ValidatedObject { + override fun toString(): String { + return file ?: "" + } + + override fun validate(): Boolean { + if (file?.isBlank() != false) return false + return super.validate() + } + } + + fun implementComponentSpecification( + project: Project, + component: ComponentDetails, + imports: List, + specification: CodeSpecification, + ): SourceCode + + + fun implementTestSpecification( + project: Project, + specification: TestSpecification, + test: TestDetails, + imports: List, + specificationAgain: TestSpecification, + ): SourceCode + + + fun implementDocumentationSpecification( + project: Project, + specification: DocumentSpecification, + documentation: DocumentationDetails, + imports: List, + specificationAgain: DocumentSpecification, + ): SourceCode + + data class SourceCode( + @Description("language of the code, e.g. \"java\" or \"kotlin\"") + val language: String? = "", + @Description("Fully implemented source code") + val code: String? = "", + ) : ValidatedObject + + companion object { + val log = org.slf4j.LoggerFactory.getLogger(SoftwareProjectAI::class.java) + fun parallelImplement( + api: SoftwareProjectAI, + project: Project, + components: Map>?, + documents: Map>?, + tests: Map>?, + drafts: Int, + threads: Int + ): Map = parallelImplementWithAlternates( + api, + project, + components ?: mapOf(), + documents ?: mapOf(), + tests ?: mapOf(), + drafts, + threads + ).mapValues { it.value.maxByOrNull { it.code?.length ?: 0 } } + + fun parallelImplementWithAlternates( + api: SoftwareProjectAI, + project: Project, + components: Map>, + documents: Map>, + tests: Map>, + drafts: Int, + threads: Int, + progress: (Double) -> Unit = {} + ): Map> { + val threadPool = Executors.newFixedThreadPool(threads) + try { + val totalDrafts = (components + tests + documents).values.sumOf { it.size } * drafts + val currentDraft = AtomicInteger(0) + val fileImplCache = ConcurrentHashMap>>>() + val normalizeFileName: (String?) -> String = { + it?.trimStart('/', '.') ?: "" + } + + // Build Components + fun buildCodeSpec( + component: ComponentDetails, + files: List, + file: CodeSpecification + ): List>> { + if (file.location == null) { + return emptyList() + } + return fileImplCache.getOrPut(normalizeFileName(file.location.file)) { + (0 until drafts).map { _ -> + threadPool.submit(Callable { + val implement = api.implementComponentSpecification( + project, + component, + files.filter { file.requires?.contains(it.location) ?: false }.toList(), + file.copy(requires = listOf()) + ) + (currentDraft.incrementAndGet().toDouble() / totalDrafts) + .also { progress(it) } + .also { log.info("Progress: $it") } + file.location to implement + }) + } + } + } + + fun buildComponentDetails( + component: ComponentDetails, + files: List + ): List>> { + return files.flatMap(fun(file: CodeSpecification): List>> { + return buildCodeSpec(component, files, file) + }).toTypedArray().toList() + } + + val componentFutures = components.flatMap { (component, files) -> + buildComponentDetails(component, files) + }.toTypedArray() + + // Build Documents + fun buildDocumentSpec( + documentation: DocumentationDetails, + files: List, + file: DocumentSpecification + ): List>> { + if (file.location == null) { + return emptyList() + } + return fileImplCache.getOrPut(normalizeFileName(file.location.file)) { + (0 until drafts).map { _ -> + threadPool.submit(Callable { + val implement = api.implementDocumentationSpecification( + project, + file.copy(requires = listOf()), + documentation, + files.filter { file.requires?.contains(it.location) ?: false }.toList(), + file.copy(requires = listOf()) + ) + (currentDraft.incrementAndGet().toDouble() / totalDrafts) + .also { progress(it) } + .also { log.info("Progress: $it") } + file.location to implement + }) + } + } + } + + fun buildDocumentDetails( + documentation: DocumentationDetails, + files: List + ): List>> { + return files.flatMap(fun(file: DocumentSpecification): List>> { + return buildDocumentSpec(documentation, files, file) + }).toTypedArray().toList() + } + + val documentFutures = documents.flatMap { (documentation, files) -> + buildDocumentDetails(documentation, files) + }.toTypedArray() + + // Build Tests + fun buildTestSpec( + test: TestDetails, + files: List, + file: TestSpecification + ): List>> { + if (file.location == null) { + return emptyList() + } + return fileImplCache.getOrPut(normalizeFileName(file.location.file)) { + (0 until drafts).map { _ -> + threadPool.submit(Callable { + val implement = api.implementTestSpecification( + project, + file.copy(requires = listOf()), + test, + files.filter { file.requires?.contains(it.location) ?: false }.toList(), + file.copy(requires = listOf()) + ) + (currentDraft.incrementAndGet().toDouble() / totalDrafts) + .also { progress(it) } + .also { log.info("Progress: $it") } + file.location to implement + }) + } + } + } + + fun buildTestDetails( + test: TestDetails, + files: List + ): List>> { + return files.flatMap(fun(file: TestSpecification): List>> { + return buildTestSpec(test, files, file) + }).toTypedArray().toList() + } + + val testFutures = tests.flatMap { (test, files) -> + buildTestDetails(test, files) + }.toTypedArray() + + return (getAll(componentFutures) + getAll(documentFutures) + getAll(testFutures)).mapValues { + it.value.map { it.second }.sortedBy { it.code?.length ?: 0 } + } + } finally { + threadPool.shutdown() + } + } + + private fun getAll(testFutures: Array>>) = + testFutures.map { + try { + Optional.ofNullable(it.get()) + } catch (e: Throwable) { + Optional.empty() + } + }.filter { !it.isEmpty }.map { it.get() }.groupBy { it.first } + + fun write( + zipArchiveFile: File, + implementations: Map + ) { + ZipOutputStream(zipArchiveFile.outputStream()).use { zip -> + implementations.forEach { (file, sourceCodes) -> + zip.putNextEntry(ZipEntry(file.toString())) + zip.write(sourceCodes!!.code?.toByteArray() ?: byteArrayOf()) + zip.closeEntry() + } + } + } + } + + +} + diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/dev/GenerateProjectAction.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/dev/GenerateProjectAction.kt index c35cbae7..83a445e9 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/actions/dev/GenerateProjectAction.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/actions/dev/GenerateProjectAction.kt @@ -3,7 +3,7 @@ package com.github.simiacryptus.aicoder.actions.dev import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.config.Name import com.github.simiacryptus.aicoder.util.UITools -import com.github.simiacryptus.openai.proxy.SoftwareProjectAI +import com.github.simiacryptus.aicoder.SoftwareProjectAI import com.github.simiacryptus.openai.proxy.ChatProxy import com.intellij.openapi.actionSystem.AnAction import com.intellij.openapi.actionSystem.AnActionEvent @@ -89,26 +89,26 @@ class GenerateProjectAction : AnAction() { UITools.run( e.project, "Specifying Components", true ) { - projectDesign.components.map { it to api.buildComponentFileSpecifications(project, requirements, it) } - .toMap() + projectDesign.components?.map { it to api.buildComponentFileSpecifications(project, requirements, it) } + ?.toMap() } val documents = UITools.run( e.project, "Specifying Documents", true ) { - projectDesign.documents.map { + projectDesign.documents?.map { it to api.buildDocumentationFileSpecifications( project, requirements, it ) - }.toMap() + }?.toMap() } val tests = UITools.run( e.project, "Specifying Tests", true - ) { projectDesign.tests.map { it to api.buildTestFileSpecifications(project, requirements, it) }.toMap() } + ) { projectDesign.tests?.map { it to api.buildTestFileSpecifications(project, requirements, it) }?.toMap() } val sourceCodeMap = UITools.run( e.project, "Implementing Files", true @@ -116,9 +116,9 @@ class GenerateProjectAction : AnAction() { SoftwareProjectAI.parallelImplementWithAlternates( api, project, - components, - documents, - tests, + components ?: emptyMap(), + documents ?: emptyMap(), + tests ?: emptyMap(), config.drafts, AppSettingsState.instance.apiThreads ) { progress -> @@ -129,16 +129,16 @@ class GenerateProjectAction : AnAction() { UITools.run(e.project, "Writing Files", false) { val outputDir = File(selectedFolder.canonicalPath!!) sourceCodeMap.forEach { (file, sourceCode) -> - val relative = file.fullFilePathName - .trimEnd('/') - .trimStart('/', '.') + val relative = file.file + ?.trimEnd('/') + ?.trimStart('/', '.') ?: "" if (File(relative).isRooted) { log.warn("Invalid path: $relative") } else { val outFile = outputDir.resolve(relative) outFile.parentFile.mkdirs() - val best = sourceCode.maxByOrNull { it.code.length }!! - outFile.writeText(best.code) + val best = sourceCode.maxByOrNull { it.code?.length ?: 0 }!! + outFile.writeText(best?.code ?: "") log.debug("Wrote ${outFile.canonicalPath} (Resolved from $relative)") if (config.saveAlternates) for ((index, alternate) in sourceCode.filter { it != best }.withIndex()) { @@ -147,7 +147,7 @@ class GenerateProjectAction : AnAction() { relative + ".${index + 1}" ) outFileAlternate.parentFile.mkdirs() - outFileAlternate.writeText(alternate.code) + outFileAlternate.writeText(alternate?.code ?: "") log.debug("Wrote ${outFileAlternate.canonicalPath} (Resolved from $relative)") } } diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsState.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsState.kt index 2375a2ff..dc7e15b2 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsState.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/config/AppSettingsState.kt @@ -85,6 +85,7 @@ class AppSettingsState : PersistentStateComponent { if (apiKey != that.apiKey) return false if (model_completion != that.model_completion) return false if (model_edit != that.model_edit) return false + if (model_chat != that.model_chat) return false if (translationRequestTemplate != that.translationRequestTemplate) return false if (apiLogLevel != that.apiLogLevel) return false if (devActions != that.devActions) return false @@ -97,6 +98,7 @@ class AppSettingsState : PersistentStateComponent { apiKey, model_completion, model_edit, + model_chat, maxTokens, temperature, translationRequestTemplate, diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPI.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPI.kt index 73d1260f..4983183e 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPI.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPI.kt @@ -22,7 +22,7 @@ import java.util.function.Consumer import java.util.stream.Collectors open class AsyncAPI( - val coreAPI: CoreAPI, + val openAIClient: OpenAIClient, private val suppressProgress: Boolean = false ) { @@ -43,33 +43,7 @@ open class AsyncAPI( ) { override fun compute(indicator: ProgressIndicator): CompletionResponse { try { - if (editRequest.input == null) { - log( - settings.apiLogLevel, String.format( - "Text Edit Request\nInstruction:\n\t%s\n", - editRequest.instruction.replace("\n", "\n\t") - ) - ) - } else { - log( - settings.apiLogLevel, String.format( - "Text Edit Request\nInstruction:\n\t%s\nInput:\n\t%s\n", - editRequest.instruction.replace("\n", "\n\t"), - editRequest.input!!.replace("\n", "\n\t") - ) - ) - } - val request: String = - StringTools.restrictCharacterSet( - CoreAPI.mapper.writeValueAsString(editRequest), - allowedCharset - ) - val result = coreAPI.post(settings.apiBase + "/edits", request) - val completionResponse = coreAPI.processCompletionResponse(result) - coreAPI.logComplete( - completionResponse.firstChoice.orElse("").toString().trim { it <= ' ' } - ) - return completionResponse + return openAIClient.edit(editRequest) } catch (e: IOException) { throw RuntimeException(e) } catch (e: InterruptedException) { @@ -123,7 +97,7 @@ open class AsyncAPI( ) threadRef.getAndSet(Thread.currentThread()) try { - return coreAPI.complete(completionRequest, model) + return openAIClient.complete(completionRequest, model) } catch (e: IOException) { log.error(e) throw RuntimeException(e) @@ -160,7 +134,7 @@ open class AsyncAPI( task = object : Task.WithResult, Exception?>(project, "Moderation", false) { override fun compute(indicator: ProgressIndicator): ListenableFuture<*> { return pool.submit { - coreAPI.moderate(text) + openAIClient.moderate(text) } } }, @@ -186,7 +160,7 @@ open class AsyncAPI( newRequest.max_tokens = settings!!.maxTokens newRequest.temperature = settings.temperature newRequest.model = settings.model_chat - return coreAPI.chat(newRequest) + return openAIClient.chat(newRequest) } catch (e: IOException) { throw RuntimeException(e) } catch (e: InterruptedException) { @@ -206,10 +180,10 @@ open class AsyncAPI( fun log(level: LogLevel, msg: String) { val message = msg.trim { it <= ' ' }.replace("\n", "\n\t") when (level) { - LogLevel.Error -> CoreAPI.log.error(message) - LogLevel.Warn -> CoreAPI.log.warn(message) - LogLevel.Info -> CoreAPI.log.info(message) - else -> CoreAPI.log.debug(message) + LogLevel.Error -> OpenAIClient.log.error(message) + LogLevel.Warn -> OpenAIClient.log.warn(message) + LogLevel.Info -> OpenAIClient.log.info(message) + else -> OpenAIClient.log.debug(message) } } diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPIImpl.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPIImpl.kt index 456d0bd9..1822d51d 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPIImpl.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/async/AsyncAPIImpl.kt @@ -2,12 +2,12 @@ package com.github.simiacryptus.aicoder.openai.async import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.openai.ui.OpenAI_API -import com.github.simiacryptus.openai.CoreAPI +import com.github.simiacryptus.openai.OpenAIClient import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.progress.ProgressIndicator import com.jetbrains.rd.util.AtomicReference -class AsyncAPIImpl(core: CoreAPI, appSettingsState: AppSettingsState) : AsyncAPI( +class AsyncAPIImpl(core: OpenAIClient, appSettingsState: AppSettingsState) : AsyncAPI( core, appSettingsState.suppressProgress ) { @@ -17,7 +17,7 @@ class AsyncAPIImpl(core: CoreAPI, appSettingsState: AppSettingsState) : AsyncAPI val thread = threadRef.get() if (null != thread) { thread.interrupt() - coreAPI.closeClient(thread) + openAIClient.closeClient(thread) } } } diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/CoreAPIImpl.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAIClientImpl.kt similarity index 77% rename from src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/CoreAPIImpl.kt rename to src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAIClientImpl.kt index f61c11e7..30c88a5f 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/CoreAPIImpl.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAIClientImpl.kt @@ -1,11 +1,11 @@ package com.github.simiacryptus.aicoder.openai.ui import com.github.simiacryptus.aicoder.config.AppSettingsState -import com.github.simiacryptus.openai.CoreAPI +import com.github.simiacryptus.openai.OpenAIClient -class CoreAPIImpl( +class OpenAIClientImpl( private val appSettingsState: AppSettingsState -) : CoreAPI( +) : OpenAIClient( appSettingsState.apiBase, appSettingsState.apiKey, appSettingsState.apiLogLevel diff --git a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAI_API.kt b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAI_API.kt index ff493c5b..631aac5a 100644 --- a/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAI_API.kt +++ b/src/main/kotlin/com/github/simiacryptus/aicoder/openai/ui/OpenAI_API.kt @@ -1,6 +1,5 @@ package com.github.simiacryptus.aicoder.openai.ui -import com.fasterxml.jackson.databind.node.ObjectNode import com.github.simiacryptus.aicoder.config.AppSettingsState import com.github.simiacryptus.aicoder.openai.async.AsyncAPI import com.github.simiacryptus.aicoder.openai.async.AsyncAPI.Companion.map @@ -26,10 +25,10 @@ object OpenAI_API { private val log = Logger.getInstance(OpenAI_API::class.java) @JvmStatic - val coreAPI: CoreAPI = CoreAPIImpl(settingsState!!) + val openAIClient: OpenAIClient = OpenAIClientImpl(settingsState!!) @JvmStatic - val asyncAPI = AsyncAPIImpl(coreAPI, settingsState!!) + val asyncAPI = AsyncAPIImpl(openAIClient, settingsState!!) private val activeModelUI = WeakHashMap, Any>() @@ -60,16 +59,10 @@ object OpenAI_API { activeModelUI[comboBox] = Any() AsyncAPI.onSuccess( engines - ) { engines: ObjectNode -> - val data = engines["data"] - val items = - arrayOfNulls(data.size()) - for (i in 0 until data.size()) { - items[i] = data[i]["id"].asText() - } - Arrays.sort(items) + ) { engines: Array -> + Arrays.sort(engines) activeModelUI.keys.forEach(Consumer { ui: ComboBox -> - Arrays.stream(items).forEach(ui::addItem) + Arrays.stream(engines).forEach(ui::addItem) }) } return comboBox!! @@ -123,12 +116,9 @@ object OpenAI_API { } return settings } - private val engines: ListenableFuture - get() = AsyncAPI.pool.submit { - CoreAPI.mapper.readValue( - coreAPI.get(settingsState!!.apiBase + "/engines"), - ObjectNode::class.java - ) + private val engines: ListenableFuture> + get() = AsyncAPI.pool.submit> { + openAIClient.getEngines() } fun complete( @@ -185,6 +175,6 @@ object OpenAI_API { } } - fun text_to_speech(wavAudio: ByteArray, prompt: String = ""): String = coreAPI.text_to_speech(wavAudio, prompt) + fun text_to_speech(wavAudio: ByteArray, prompt: String = ""): String = openAIClient.dictate(wavAudio, prompt) } diff --git a/src/main/kotlin/com/github/simiacryptus/openai/CoreAPI.kt b/src/main/kotlin/com/github/simiacryptus/openai/CoreAPI.kt deleted file mode 100644 index 6fb08c21..00000000 --- a/src/main/kotlin/com/github/simiacryptus/openai/CoreAPI.kt +++ /dev/null @@ -1,516 +0,0 @@ -package com.github.simiacryptus.openai - -import com.fasterxml.jackson.annotation.JsonInclude -import com.fasterxml.jackson.core.JsonProcessingException -import com.fasterxml.jackson.databind.MapperFeature -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.databind.SerializationFeature -import com.fasterxml.jackson.databind.node.ObjectNode -import com.github.simiacryptus.aicoder.openai.async.AsyncAPI -import com.github.simiacryptus.aicoder.openai.ui.OpenAI_API -import com.github.simiacryptus.util.StringTools -import com.github.simiacryptus.aicoder.util.UITools -import com.google.common.util.concurrent.ListeningScheduledExecutorService -import com.google.common.util.concurrent.MoreExecutors -import com.google.common.util.concurrent.ThreadFactoryBuilder -import com.google.gson.Gson -import com.google.gson.JsonObject -import com.intellij.openapi.diagnostic.Logger -import com.jetbrains.rd.util.LogLevel -import org.apache.http.HttpResponse -import org.apache.http.client.methods.HttpGet -import org.apache.http.client.methods.HttpPost -import org.apache.http.client.methods.HttpRequestBase -import org.apache.http.entity.ContentType -import org.apache.http.entity.StringEntity -import org.apache.http.entity.mime.HttpMultipartMode -import org.apache.http.entity.mime.MultipartEntityBuilder -import org.apache.http.impl.client.CloseableHttpClient -import org.apache.http.impl.client.HttpClientBuilder -import org.apache.http.util.EntityUtils -import java.awt.image.BufferedImage -import java.io.IOException -import java.net.URL -import java.nio.charset.Charset -import java.time.Duration -import java.util.* -import java.util.Map -import java.util.concurrent.ScheduledThreadPoolExecutor -import java.util.concurrent.ThreadFactory -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicBoolean -import java.util.regex.Pattern -import javax.imageio.ImageIO -import kotlin.collections.ArrayList -import kotlin.collections.List -import kotlin.collections.MutableMap -import kotlin.collections.first -import kotlin.collections.set - -@Suppress("unused") -open class CoreAPI( - val apiBase: String, - var key: String, - val logLevel: LogLevel -) { - - - fun getEngines(): Array { - val engines = mapper.readValue( - get(OpenAI_API.settingsState!!.apiBase + "/engines"), - ObjectNode::class.java - ) - val data = engines["data"] - val items = - arrayOfNulls(data.size()) - for (i in 0 until data.size()) { - items[i] = data[i]["id"].asText() - } - Arrays.sort(items) - return items - } - - private val clients: MutableMap = WeakHashMap() - - @Throws(IOException::class, InterruptedException::class) - fun post(url: String, body: String): String { - return post(url, body, 3) - } - - - fun logComplete(completionResult: CharSequence) { - log( - logLevel, String.format( - "Chat Completion:\n\t%s", - completionResult.toString().replace("\n", "\n\t") - ) - ) - } - - fun logStart(completionRequest: CompletionRequest) { - if (completionRequest.suffix == null) { - log( - logLevel, String.format( - "Text Completion Request\nPrefix:\n\t%s\n", - completionRequest.prompt.replace("\n", "\n\t") - ) - ) - } else { - log( - logLevel, String.format( - "Text Completion Request\nPrefix:\n\t%s\nSuffix:\n\t%s\n", - completionRequest.prompt.replace("\n", "\n\t"), - completionRequest.suffix!!.replace("\n", "\n\t") - ) - ) - } - } - - @Throws(IOException::class, InterruptedException::class) - fun post(url: String, json: String, retries: Int): String { - return post(jsonRequest(url, json), retries) - } - - fun post( - request: HttpPost, - retries: Int - ): String { - try { - val client = getClient() - try { - client.use { httpClient -> - synchronized(clients) { - clients[Thread.currentThread()] = httpClient - } - val response: HttpResponse = httpClient.execute(request) - val entity = response.entity - return EntityUtils.toString(entity) - } - } finally { - synchronized(clients) { - clients.remove(Thread.currentThread()) - } - } - } catch (e: IOException) { - if (retries > 0) { - log.warn("Error posting request, retrying in 15 seconds", e) - Thread.sleep(15000) - return post(request, retries - 1) - } - throw e - } - } - - fun jsonRequest(url: String, json: String): HttpPost { - val request = HttpPost(url) - request.addHeader("Content-Type", "application/json") - request.addHeader("Accept", "application/json") - authorize(request) - request.entity = StringEntity(json) - return request - } - - @Throws(IOException::class) - fun authorize(request: HttpRequestBase) { - var apiKey: CharSequence = key - if (apiKey.length == 0) { - synchronized(OpenAI_API.javaClass) { - apiKey = key - if (apiKey.length == 0) { - apiKey = UITools.queryAPIKey()!! - key = apiKey.toString() - } - } - } - request.addHeader("Authorization", "Bearer $apiKey") - } - - /** - * Gets the response from the given URL. - * - * @param url The URL to GET the response from. - * @return The response from the given URL. - * @throws IOException If an I/O error occurs. - */ - @Throws(IOException::class) - operator fun get(url: String?): String { - val client = getClient() - val request = HttpGet(url) - request.addHeader("Content-Type", "application/json") - request.addHeader("Accept", "application/json") - authorize(request) - client.use { httpClient -> - val response: HttpResponse = httpClient.execute(request) - val entity = response.entity - return EntityUtils.toString(entity) - } - } - - fun getClient(thread: Thread = Thread.currentThread()): CloseableHttpClient = - if (thread in clients) clients[thread]!! - else synchronized(clients) { - val client = HttpClientBuilder.create().build() - clients.put(thread, client) - client - } - - fun text_to_speech(wavAudio: ByteArray, prompt: String = ""): String { - val url = apiBase + "/audio/transcriptions" - val request = HttpPost(url) - request.addHeader("Accept", "application/json") - authorize(request) - val entity = MultipartEntityBuilder.create() - entity.setMode(HttpMultipartMode.RFC6532) - entity.addBinaryBody("file", wavAudio, ContentType.create("audio/x-wav"), "audio.wav") - entity.addTextBody("model", "whisper-1") - if (!prompt.isEmpty()) entity.addTextBody("prompt", prompt) - request.entity = entity.build() - val response = post(request, 3) - val jsonObject = Gson().fromJson(response, JsonObject::class.java) - if (jsonObject.has("error")) { - val errorObject = jsonObject.getAsJsonObject("error") - throw RuntimeException(IOException(errorObject["message"].asString)) - } - return jsonObject.get("text").asString!! - } - - fun text_to_image(prompt: String = "", resolution: Int = 1024, count: Int = 1): List { - val url = apiBase + "/images/generations" - val request = HttpPost(url) - request.addHeader("Accept", "application/json") - request.addHeader("Content-Type", "application/json") - authorize(request) - val jsonObject = JsonObject() - jsonObject.addProperty("prompt", prompt) - jsonObject.addProperty("n", count) - jsonObject.addProperty("size", "${resolution}x$resolution") - request.entity = StringEntity(jsonObject.toString()) - val response = post(request, 3) - val jsonObject2 = Gson().fromJson(response, JsonObject::class.java) - if (jsonObject2.has("error")) { - val errorObject = jsonObject2.getAsJsonObject("error") - throw RuntimeException(IOException(errorObject["message"].asString)) - } - val dataArray = jsonObject2.getAsJsonArray("data") - val images = ArrayList() - for (i in 0 until dataArray.size()) { - images.add(ImageIO.read(URL(dataArray[i].asJsonObject.get("url").asString))) - } - return images - } - - @Throws(IOException::class) - fun processCompletionResponse(result: String): CompletionResponse { - checkError(result) - val response = mapper.readValue( - result, - CompletionResponse::class.java - ) - if (response.usage != null) { - incrementTokens(response.usage!!.total_tokens) - } - return response - } - - @Throws(IOException::class) - fun processChatResponse(result: String): ChatResponse { - checkError(result) - val response = mapper.readValue( - result, - ChatResponse::class.java - ) - if (response.usage != null) { - incrementTokens(response.usage!!.total_tokens) - } - return response - } - - private fun checkError(result: String) { - try { - val jsonObject = Gson().fromJson( - result, - JsonObject::class.java - ) - if (jsonObject.has("error")) { - val errorObject = jsonObject.getAsJsonObject("error") - val errorMessage = errorObject["message"].asString - if (errorMessage.startsWith("That model is currently overloaded with other requests.")) { - throw RequestOverloadException(errorMessage) - } - val matcher = maxTokenErrorMessage.matcher(errorMessage) - if (matcher.find()) { - val modelMax = matcher.group(1).toInt() - val request = matcher.group(2).toInt() - val messages = matcher.group(3).toInt() - val completion = matcher.group(4).toInt() - throw ModelMaxException(modelMax, request, messages, completion) - } - throw IOException(errorMessage) - } - } catch (e: com.google.gson.JsonSyntaxException) { - throw IOException("Invalid JSON response: $result") - } - } - - class RequestOverloadException(message: String = "That model is currently overloaded with other requests.") : - IOException(message) - - open fun incrementTokens(totalTokens: Int) {} - - companion object { - val log = Logger.getInstance(CoreAPI::class.java) - - fun log(level: LogLevel, msg: String) { - val message = msg.trim { it <= ' ' }.replace("\n", "\n\t") - when (level) { - LogLevel.Error -> log.error(message) - LogLevel.Warn -> log.warn(message) - LogLevel.Info -> log.info(message) - else -> log.debug(message) - } - } - - val mapper: ObjectMapper - get() { - val mapper = ObjectMapper() - mapper - .enable(SerializationFeature.INDENT_OUTPUT) - .enable(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS) - .enable(MapperFeature.USE_STD_BEAN_NAMING) - .setSerializationInclusion(JsonInclude.Include.NON_NULL) - .activateDefaultTyping(mapper.polymorphicTypeValidator) - return mapper - } - val allowedCharset = Charset.forName("ASCII") - private val maxTokenErrorMessage = Pattern.compile( - """This model's maximum context length is (\d+) tokens. However, you requested (\d+) tokens \((\d+) in the messages, (\d+) in the completion\). Please reduce the length of the messages or completion.""" - ) - - private val threadFactory: ThreadFactory = ThreadFactoryBuilder().setNameFormat("API Thread %d").build() - private val scheduledPool: ListeningScheduledExecutorService = - MoreExecutors.listeningDecorator(ScheduledThreadPoolExecutor(1, threadFactory)) - } - - fun withCancellationMonitor(fn: () -> T, cancelCheck: () -> Boolean) : T { - val thread = Thread.currentThread() - val isCompleted = AtomicBoolean(false) - val future = scheduledPool.scheduleAtFixedRate({ - if (cancelCheck()) { - log.warn("Request cancelled") - closeClient(thread) - } - while(true) { - Thread.sleep(1000) - if(!isCompleted.get()) { - log.warn("Request still not completed; killing thread $thread") - thread.stop() - } else { - break - } - } - }, 0, 10, TimeUnit.MILLISECONDS) - try { - return fn() - } finally { - isCompleted.set(true) - future.cancel(false) - } - } - - fun withTimeout(duration: Duration, fn: () -> T) : T { - val thread = Thread.currentThread() - val isCompleted = AtomicBoolean(false) - val future = scheduledPool.schedule({ - log.warn("Request timed out after $duration; closing client for thread $thread") - closeClient(thread) - while(true) { - Thread.sleep(1000) - if(!isCompleted.get()) { - log.warn("Request still not completed; killing thread $thread") - thread.stop() - } else { - break - } - } - }, duration.toMillis(), TimeUnit.MILLISECONDS) - try { - return fn() - } finally { - isCompleted.set(true) - future.cancel(false) - } - } - - fun complete( - completionRequest: CompletionRequest, - model: String - ): CompletionResponse { - logStart(completionRequest) - val completionResponse = try { - val request: String = - StringTools.restrictCharacterSet( - mapper.writeValueAsString(completionRequest), - allowedCharset - ) - val result = - post(apiBase + "/engines/" + model + "/completions", request) - processCompletionResponse(result) - } catch (e: ModelMaxException) { - completionRequest.max_tokens = (e.modelMax - e.messages) - 1 - val request: String = - StringTools.restrictCharacterSet( - mapper.writeValueAsString(completionRequest), - allowedCharset - ) - val result = - post(apiBase + "/engines/" + model + "/completions", request) - processCompletionResponse(result) - } - val completionResult = StringTools.stripPrefix( - completionResponse.firstChoice.orElse("").toString().trim { it <= ' ' }, - completionRequest.prompt.trim { it <= ' ' }) - logComplete(completionResult) - return completionResponse - } - - fun chat( - completionRequest: ChatRequest - ): ChatResponse { - logStart(completionRequest) - val url = apiBase + "/chat/completions" - val completionResponse = try { - processChatResponse( - post( - url, StringTools.restrictCharacterSet( - mapper.writeValueAsString(completionRequest), - allowedCharset - ) - ) - ) - } catch (e: ModelMaxException) { - completionRequest.max_tokens = (e.modelMax - e.messages) - 1 - processChatResponse( - post( - url, StringTools.restrictCharacterSet( - mapper.writeValueAsString(completionRequest), - allowedCharset - ) - ) - ) - } - logComplete(completionResponse.choices.first().message!!.content!!.trim { it <= ' ' }) - return completionResponse - } - - private fun logStart(completionRequest: ChatRequest) { - log( - logLevel, String.format( - "Chat Request\nPrefix:\n\t%s\n", - mapper.writeValueAsString(completionRequest).replace("\n", "\n\t") - ) - ) - } - - fun moderate(text: String) { - val body: String = try { - mapper.writeValueAsString( - Map.of( - "input", - StringTools.restrictCharacterSet(text, allowedCharset) - ) - ) - } catch (e: JsonProcessingException) { - throw RuntimeException(e) - } - val result: String = try { - this.post(apiBase + "/moderations", body) - } catch (e: IOException) { - throw RuntimeException(e) - } catch (e: InterruptedException) { - throw RuntimeException(e) - } - val jsonObject = - Gson().fromJson( - result, - JsonObject::class.java - ) - if (jsonObject.has("error")) { - val errorObject = jsonObject.getAsJsonObject("error") - throw RuntimeException(IOException(errorObject["message"].asString)) - } - val moderationResult = - jsonObject.getAsJsonArray("results")[0].asJsonObject - AsyncAPI.log( - LogLevel.Debug, - String.format( - "Moderation Request\nText:\n%s\n\nResult:\n%s", - text.replace("\n", "\n\t"), - result - ) - ) - if (moderationResult["flagged"].asBoolean) { - val categoriesObj = - moderationResult["categories"].asJsonObject - throw RuntimeException( - ModerationException( - "Moderation flagged this request due to " + categoriesObj.keySet() - .stream().filter { c: String? -> - categoriesObj[c].asBoolean - }.reduce { a: String, b: String -> "$a, $b" } - .orElse("???") - ) - ) - } - } - - fun closeClient(thread: Thread) { - try { - synchronized(clients) { - clients[thread] - }?.close() - } catch (e: IOException) { - log.warn("Error closing client: " + e.message) - } - } - -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/openai/HttpClientManager.kt b/src/main/kotlin/com/github/simiacryptus/openai/HttpClientManager.kt new file mode 100644 index 00000000..d4e193f1 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/HttpClientManager.kt @@ -0,0 +1,138 @@ +package com.github.simiacryptus.openai + +import com.google.common.util.concurrent.ListeningExecutorService +import com.google.common.util.concurrent.ListeningScheduledExecutorService +import com.google.common.util.concurrent.MoreExecutors +import com.google.common.util.concurrent.ThreadFactoryBuilder +import com.intellij.openapi.diagnostic.Logger +import org.apache.http.impl.client.CloseableHttpClient +import org.apache.http.impl.client.HttpClientBuilder +import java.io.IOException +import java.time.Duration +import java.util.* +import java.util.concurrent.* +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.math.pow + +@Suppress("MemberVisibilityCanBePrivate") +open class HttpClientManager { + + companion object { + val log = Logger.getInstance(OpenAIClient::class.java) + val threadFactory: ThreadFactory = ThreadFactoryBuilder().setNameFormat("API Thread %d").build() + val scheduledPool: ListeningScheduledExecutorService = + MoreExecutors.listeningDecorator(ScheduledThreadPoolExecutor(4, threadFactory)) + val workPool: ListeningExecutorService = + MoreExecutors.listeningDecorator( + ThreadPoolExecutor( + 1, 32, + 0, TimeUnit.MILLISECONDS, LinkedBlockingQueue(), threadFactory + ) + ) + + fun withPool(fn: () -> T): T = workPool.submit(Callable { + return@Callable fn() + }).get() + + fun withExpBackoffRetry(retryCount: Int = 3, sleepScale: Long = 1000L, fn: () -> T): T { + var lastException: Exception? = null + for (i in 0 until retryCount) { + try { + return fn() + } catch (e: Exception) { + lastException = e + log.info("Request failed; retrying ($i/$retryCount): " + e.message) + Thread.sleep(sleepScale * 2.0.pow(i.toDouble()).toLong()) + } + } + throw lastException!! + } + + } + + protected val clients: MutableMap = WeakHashMap() + fun getClient(thread: Thread = Thread.currentThread()): CloseableHttpClient = + if (thread in clients) clients[thread]!! + else synchronized(clients) { + val client = HttpClientBuilder.create().build() + clients[thread] = client + client + } + + fun closeClient(thread: Thread) { + try { + synchronized(clients) { + clients[thread] + }?.close() + } catch (e: IOException) { + log.info("Error closing client: " + e.message) + } + } + + fun withCancellationMonitor(fn: () -> T, cancelCheck: () -> Boolean = { Thread.currentThread().isInterrupted }): T { + val thread = Thread.currentThread() + val isCompleted = AtomicBoolean(false) + val start = Date() + val future = scheduledPool.scheduleAtFixedRate({ + if (cancelCheck()) { + log.info("Request cancelled at ${Date()} (started $start); closing client for thread $thread") + closeClient(thread, isCompleted) + } + }, 0, 10, TimeUnit.MILLISECONDS) + try { + return fn() + } finally { + isCompleted.set(true) + future.cancel(false) + } + } + + fun withTimeout(duration: Duration, fn: () -> T): T { + val thread = Thread.currentThread() + val isCompleted = AtomicBoolean(false) + val start = Date() + val future = scheduledPool.schedule({ + log.info("Request timed out after $duration at ${Date()} (started $start); closing client for thread $thread") + closeClient(thread, isCompleted) + }, duration.toMillis(), TimeUnit.MILLISECONDS) + try { + return fn() + } finally { + isCompleted.set(true) + future.cancel(false) + } + } + + private fun closeClient(thread: Thread, isCompleted: AtomicBoolean) { + closeClient(thread) + Thread.sleep(10) + while (isCompleted.get()) { + Thread.sleep(5000) + if (isCompleted.get()) break + log.info("Request still not completed; thread stack: \n\t${thread.stackTrace.joinToString { "\n\t" }}\nkilling thread $thread") + @Suppress("DEPRECATION") + thread.stop() + } + } + + fun withReliability(requestTimeoutSeconds: Long = 180, retryCount: Int = 3, fn: () -> T): T = + withExpBackoffRetry(retryCount) { +// withPool { +// } + withTimeout(Duration.ofSeconds(requestTimeoutSeconds)) { + withCancellationMonitor(fn) + } + } + + fun withPerformanceLogging(fn: () -> T):T { + val start = Date() + try { + return fn() + } finally { + log.debug("Request completed in ${Date().time - start.time}ms") + } + } + + + +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/openai/OpenAIClient.kt b/src/main/kotlin/com/github/simiacryptus/openai/OpenAIClient.kt new file mode 100644 index 00000000..bef77aa3 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/OpenAIClient.kt @@ -0,0 +1,472 @@ +package com.github.simiacryptus.openai + +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.core.JsonProcessingException +import com.fasterxml.jackson.databind.MapperFeature +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.SerializationFeature +import com.fasterxml.jackson.databind.node.ObjectNode +import com.github.simiacryptus.aicoder.openai.ui.OpenAI_API +import com.github.simiacryptus.aicoder.util.UITools +import com.github.simiacryptus.util.StringTools +import com.google.gson.Gson +import com.google.gson.JsonObject +import com.intellij.openapi.diagnostic.Logger +import com.jetbrains.rd.util.LogLevel +import org.apache.http.HttpResponse +import org.apache.http.client.methods.HttpGet +import org.apache.http.client.methods.HttpPost +import org.apache.http.client.methods.HttpRequestBase +import org.apache.http.entity.ContentType +import org.apache.http.entity.StringEntity +import org.apache.http.entity.mime.HttpMultipartMode +import org.apache.http.entity.mime.MultipartEntityBuilder +import org.apache.http.util.EntityUtils +import java.awt.image.BufferedImage +import java.io.IOException +import java.net.URL +import java.nio.charset.Charset +import java.util.* +import java.util.Map +import java.util.regex.Pattern +import javax.imageio.ImageIO +import kotlin.collections.ArrayList +import kotlin.collections.List +import kotlin.collections.first +import kotlin.collections.set + +@Suppress("unused") +open class OpenAIClient( + private val apiBase: String, + var key: String, + private val logLevel: LogLevel +) : HttpClientManager() { + + fun getEngines(): Array { + val engines = mapper.readValue( + get(OpenAI_API.settingsState!!.apiBase + "/engines"), + ObjectNode::class.java + ) + val data = engines["data"] + val items = + arrayOfNulls(data.size()) + for (i in 0 until data.size()) { + items[i] = data[i]["id"].asText() + } + Arrays.sort(items) + return items + } + + private fun logComplete(completionResult: CharSequence) { + log( + logLevel, String.format( + "Chat Completion:\n\t%s", + completionResult.toString().replace("\n", "\n\t") + ) + ) + } + + private fun logStart(completionRequest: CompletionRequest) { + if (completionRequest.suffix == null) { + log( + logLevel, String.format( + "Text Completion Request\nPrefix:\n\t%s\n", + completionRequest.prompt.replace("\n", "\n\t") + ) + ) + } else { + log( + logLevel, String.format( + "Text Completion Request\nPrefix:\n\t%s\nSuffix:\n\t%s\n", + completionRequest.prompt.replace("\n", "\n\t"), + completionRequest.suffix!!.replace("\n", "\n\t") + ) + ) + } + } + + @Throws(IOException::class, InterruptedException::class) + protected fun post(url: String, json: String): String { + return post(jsonRequest(url, json)) + } + + private fun post( + request: HttpPost + ): String { + val client = getClient() + try { + client.use { httpClient -> + synchronized(clients) { + clients[Thread.currentThread()] = httpClient + } + val response: HttpResponse = httpClient.execute(request) + val entity = response.entity + return EntityUtils.toString(entity) + } + } finally { + synchronized(clients) { + clients.remove(Thread.currentThread()) + } + } + } + + private fun jsonRequest(url: String, json: String): HttpPost { + val request = HttpPost(url) + request.addHeader("Content-Type", "application/json") + request.addHeader("Accept", "application/json") + authorize(request) + request.entity = StringEntity(json) + return request + } + + @Throws(IOException::class) + protected fun authorize(request: HttpRequestBase) { + var apiKey: CharSequence = key + if (apiKey.length == 0) { + synchronized(OpenAI_API.javaClass) { + apiKey = key + if (apiKey.length == 0) { + apiKey = UITools.queryAPIKey()!! + key = apiKey.toString() + } + } + } + request.addHeader("Authorization", "Bearer $apiKey") + } + + /** + * Gets the response from the given URL. + * + * @param url The URL to GET the response from. + * @return The response from the given URL. + * @throws IOException If an I/O error occurs. + */ + @Throws(IOException::class) + protected operator fun get(url: String?): String { + val client = getClient() + val request = HttpGet(url) + request.addHeader("Content-Type", "application/json") + request.addHeader("Accept", "application/json") + authorize(request) + client.use { httpClient -> + val response: HttpResponse = httpClient.execute(request) + val entity = response.entity + return EntityUtils.toString(entity) + } + } + + fun dictate(wavAudio: ByteArray, prompt: String = ""): String = withReliability { + withPerformanceLogging { + val url = apiBase + "/audio/transcriptions" + val request = HttpPost(url) + request.addHeader("Accept", "application/json") + authorize(request) + val entity = MultipartEntityBuilder.create() + entity.setMode(HttpMultipartMode.RFC6532) + entity.addBinaryBody("file", wavAudio, ContentType.create("audio/x-wav"), "audio.wav") + entity.addTextBody("model", "whisper-1") + if (!prompt.isEmpty()) entity.addTextBody("prompt", prompt) + request.entity = entity.build() + val response = post(request) + val jsonObject = Gson().fromJson(response, JsonObject::class.java) + if (jsonObject.has("error")) { + val errorObject = jsonObject.getAsJsonObject("error") + throw RuntimeException(IOException(errorObject["message"].asString)) + } + jsonObject.get("text").asString!! + } + } + + fun render(prompt: String = "", resolution: Int = 1024, count: Int = 1): List = withReliability { + withPerformanceLogging { + val url = apiBase + "/images/generations" + val request = HttpPost(url) + request.addHeader("Accept", "application/json") + request.addHeader("Content-Type", "application/json") + authorize(request) + val jsonObject = JsonObject() + jsonObject.addProperty("prompt", prompt) + jsonObject.addProperty("n", count) + jsonObject.addProperty("size", "${resolution}x$resolution") + request.entity = StringEntity(jsonObject.toString()) + val response = post(request) + val jsonObject2 = Gson().fromJson(response, JsonObject::class.java) + if (jsonObject2.has("error")) { + val errorObject = jsonObject2.getAsJsonObject("error") + throw RuntimeException(IOException(errorObject["message"].asString)) + } + val dataArray = jsonObject2.getAsJsonArray("data") + val images = ArrayList() + for (i in 0 until dataArray.size()) { + images.add(ImageIO.read(URL(dataArray[i].asJsonObject.get("url").asString))) + } + images + } } + + @Throws(IOException::class) + private fun processCompletionResponse(result: String): CompletionResponse { + checkError(result) + val response = mapper.readValue( + result, + CompletionResponse::class.java + ) + if (response.usage != null) { + incrementTokens(response.usage!!.total_tokens) + } + return response + } + + @Throws(IOException::class) + protected fun processChatResponse(result: String): ChatResponse { + checkError(result) + val response = mapper.readValue( + result, + ChatResponse::class.java + ) + if (response.usage != null) { + incrementTokens(response.usage!!.total_tokens) + } + return response + } + + private fun checkError(result: String) { + try { + val jsonObject = Gson().fromJson( + result, + JsonObject::class.java + ) + if (jsonObject.has("error")) { + val errorObject = jsonObject.getAsJsonObject("error") + val errorMessage = errorObject["message"].asString + if (errorMessage.startsWith("That model is currently overloaded with other requests.")) { + throw RequestOverloadException(errorMessage) + } + maxTokenErrorMessage.find { it.matcher(errorMessage).matches() }?.let { + val matcher = it.matcher(errorMessage) + if (matcher.find()) { + val modelMax = matcher.group(1).toInt() + val request = matcher.group(2).toInt() + val messages = matcher.group(3).toInt() + val completion = matcher.group(4).toInt() + throw ModelMaxException(modelMax, request, messages, completion) + } + } + throw IOException(errorMessage) + } + } catch (e: com.google.gson.JsonSyntaxException) { + throw IOException("Invalid JSON response: $result") + } + } + + class RequestOverloadException(message: String = "That model is currently overloaded with other requests.") : + IOException(message) + + open fun incrementTokens(totalTokens: Int) {} + + companion object { + val log = Logger.getInstance(OpenAIClient::class.java) + + fun log(level: LogLevel, msg: String) { + val message = msg.trim { it <= ' ' }.replace("\n", "\n\t") + when (level) { + LogLevel.Error -> log.error(message) + LogLevel.Warn -> log.warn(message) + LogLevel.Info -> log.info(message) + else -> log.debug(message) + } + } + + val mapper: ObjectMapper + get() { + val mapper = ObjectMapper() + mapper + .enable(SerializationFeature.INDENT_OUTPUT) + .enable(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS) + .enable(MapperFeature.USE_STD_BEAN_NAMING) + .setSerializationInclusion(JsonInclude.Include.NON_NULL) + .activateDefaultTyping(mapper.polymorphicTypeValidator) + return mapper + } + val allowedCharset = Charset.forName("ASCII") + private val maxTokenErrorMessage = listOf( + Pattern.compile( + """This model's maximum context length is (\d+) tokens. However, you requested (\d+) tokens \((\d+) in the messages, (\d+) in the completion\).*""" + ), + // This model's maximum context length is 4097 tokens, however you requested 80052 tokens (52 in your prompt; 80000 for the completion). Please reduce your prompt; or completion length. + Pattern.compile( + """This model's maximum context length is (\d+) tokens, however you requested (\d+) tokens \((\d+) in your prompt; (\d+) for the completion\).*""" + ) + ) + + } + + fun complete( + completionRequest: CompletionRequest, + model: String + ): CompletionResponse = withReliability { + withPerformanceLogging { + logStart(completionRequest) + val completionResponse = try { + val request: String = + StringTools.restrictCharacterSet( + mapper.writeValueAsString(completionRequest), + allowedCharset + ) + val result = + post(apiBase + "/engines/" + model + "/completions", request) + processCompletionResponse(result) + } catch (e: ModelMaxException) { + completionRequest.max_tokens = (e.modelMax - e.messages) - 1 + val request: String = + StringTools.restrictCharacterSet( + mapper.writeValueAsString(completionRequest), + allowedCharset + ) + val result = + post(apiBase + "/engines/" + model + "/completions", request) + processCompletionResponse(result) + } + val completionResult = StringTools.stripPrefix( + completionResponse.firstChoice.orElse("").toString().trim { it <= ' ' }, + completionRequest.prompt.trim { it <= ' ' }) + logComplete(completionResult) + completionResponse + } + } + + fun chat( + completionRequest: ChatRequest + ): ChatResponse = withReliability { + withPerformanceLogging { + logStart(completionRequest) + val url = apiBase + "/chat/completions" + val completionResponse = try { + processChatResponse( + post( + url, StringTools.restrictCharacterSet( + mapper.writeValueAsString(completionRequest), + allowedCharset + ) + ) + ) + } catch (e: ModelMaxException) { + completionRequest.max_tokens = (e.modelMax - e.messages) - 1 + processChatResponse( + post( + url, StringTools.restrictCharacterSet( + mapper.writeValueAsString(completionRequest), + allowedCharset + ) + ) + ) + } + logComplete(completionResponse.choices.first().message!!.content!!.trim { it <= ' ' }) + completionResponse + } + } + + private fun logStart(completionRequest: ChatRequest) { + log( + logLevel, String.format( + "Chat Request\nPrefix:\n\t%s\n", + mapper.writeValueAsString(completionRequest).replace("\n", "\n\t") + ) + ) + } + + fun moderate(text: String) = withReliability { + withPerformanceLogging { + val body: String = try { + mapper.writeValueAsString( + Map.of( + "input", + StringTools.restrictCharacterSet(text, allowedCharset) + ) + ) + } catch (e: JsonProcessingException) { + throw RuntimeException(e) + } + val result: String = try { + this.post(apiBase + "/moderations", body) + } catch (e: IOException) { + throw RuntimeException(e) + } catch (e: InterruptedException) { + throw RuntimeException(e) + } + val jsonObject = + Gson().fromJson( + result, + JsonObject::class.java + ) + if (jsonObject.has("error")) { + val errorObject = jsonObject.getAsJsonObject("error") + throw RuntimeException(IOException(errorObject["message"].asString)) + } + val moderationResult = + jsonObject.getAsJsonArray("results")[0].asJsonObject + log( + LogLevel.Debug, + String.format( + "Moderation Request\nText:\n%s\n\nResult:\n%s", + text.replace("\n", "\n\t"), + result + ) + ) + if (moderationResult["flagged"].asBoolean) { + val categoriesObj = + moderationResult["categories"].asJsonObject + throw RuntimeException( + ModerationException( + "Moderation flagged this request due to " + categoriesObj.keySet() + .stream().filter { c: String? -> + categoriesObj[c].asBoolean + }.reduce { a: String, b: String -> "$a, $b" } + .orElse("???") + ) + ) + } + } + } + + fun edit( + editRequest: EditRequest + ): CompletionResponse = withReliability { + withPerformanceLogging { + logStart(editRequest, logLevel) + val request: String = + StringTools.restrictCharacterSet( + OpenAIClient.mapper.writeValueAsString(editRequest), + allowedCharset + ) + val result = post(apiBase + "/edits", request) + val completionResponse = processCompletionResponse(result) + logComplete( + completionResponse.firstChoice.orElse("").toString().trim { it <= ' ' } + ) + completionResponse + } + } + + private fun logStart( + editRequest: EditRequest, + level: LogLevel + ) { + if (editRequest.input == null) { + log( + level, String.format( + "Text Edit Request\nInstruction:\n\t%s\n", + editRequest.instruction.replace("\n", "\n\t") + ) + ) + } else { + log( + level, String.format( + "Text Edit Request\nInstruction:\n\t%s\nInput:\n\t%s\n", + editRequest.instruction.replace("\n", "\n\t"), + editRequest.input!!.replace("\n", "\n\t") + ) + ) + } + } + +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/ChatProxy.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/ChatProxy.kt index 7bc8dfcd..161fae41 100644 --- a/src/main/kotlin/com/github/simiacryptus/openai/proxy/ChatProxy.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/proxy/ChatProxy.kt @@ -1,32 +1,31 @@ package com.github.simiacryptus.openai.proxy -import com.github.simiacryptus.openai.ChatRequest import com.github.simiacryptus.openai.ChatMessage -import com.github.simiacryptus.openai.CoreAPI +import com.github.simiacryptus.openai.ChatRequest +import com.github.simiacryptus.openai.OpenAIClient import com.jetbrains.rd.util.LogLevel -import java.time.Duration import java.util.concurrent.atomic.AtomicInteger @Suppress("MemberVisibilityCanBePrivate") class ChatProxy( apiKey: String, - private val model: String = "gpt-3.5-turbo", - private val maxTokens: Int = 3500, - private val temperature: Double = 0.7, - private val verbose: Boolean = false, + var model: String = "gpt-3.5-turbo", + var maxTokens: Int = 3500, + var temperature: Double = 0.7, + var verbose: Boolean = false, private val moderated: Boolean = true, base: String = "https://api.openai.com/v1", apiLog: String? = null, logLevel: LogLevel ) : GPTProxyBase(apiLog, 3) { - val api: CoreAPI + val api: OpenAIClient val totalPrefixLength = AtomicInteger(0) val totalSuffixLength = AtomicInteger(0) val totalInputLength = AtomicInteger(0) val totalOutputLength = AtomicInteger(0) init { - api = CoreAPI(base, apiKey, logLevel) + api = OpenAIClient(base, apiKey, logLevel) } override fun complete(prompt: ProxyRequest, vararg examples: ProxyRecord): String { @@ -36,14 +35,12 @@ class ChatProxy( listOf( ChatMessage( ChatMessage.Role.system, """ - |You are a JSON-RPC Service serving the following method: - |${prompt.methodName} - |Requests contain the following arguments: - |${prompt.argList.keys.joinToString("\n ")} - |Responses are of type: - |${prompt.responseType} + |You are a JSON-RPC Service |Responses are expected to be a single JSON object without explaining text. |All input arguments are optional + |You will respond to the following method: + | + |${prompt.apiYaml} |""".trimMargin().trim() ) ) + @@ -69,12 +66,11 @@ class ChatProxy( request.model = model request.max_tokens = maxTokens request.temperature = temperature - val completion = api.withTimeout(Duration.ofMinutes(10)) { - val json = toJson(request) - if (moderated) api.moderate(json) - totalInputLength.addAndGet(json.length) - api.chat(request).response.get().toString() - } + val json = toJson(request) + if (moderated) api.moderate(json) + totalInputLength.addAndGet(json.length) + + val completion = api.chat(request).response.get().toString() if (verbose) println(completion) totalOutputLength.addAndGet(completion.length) val trimPrefix = trimPrefix(completion) diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/CompletionProxy.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/CompletionProxy.kt index 19c0defa..c1a7c5c7 100644 --- a/src/main/kotlin/com/github/simiacryptus/openai/proxy/CompletionProxy.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/proxy/CompletionProxy.kt @@ -1,43 +1,42 @@ package com.github.simiacryptus.openai.proxy import com.github.simiacryptus.openai.CompletionRequest -import com.github.simiacryptus.openai.CoreAPI +import com.github.simiacryptus.openai.OpenAIClient import com.jetbrains.rd.util.LogLevel class CompletionProxy( apiKey: String, private val model: String = "text-davinci-003", - private val maxTokens: Int = 1000, + private val maxTokens: Int = 4000, private val temperature: Double = 0.7, private val verbose: Boolean = false, private val moderated: Boolean = true, base: String = "https://api.openai.com/v1", apiLog: String ) : GPTProxyBase(apiLog, 3) { - val api: CoreAPI + val api: OpenAIClient init { - api = CoreAPI(base, apiKey, LogLevel.Debug) + api = OpenAIClient(base, apiKey, LogLevel.Debug) } override fun complete(prompt: ProxyRequest, vararg examples: ProxyRecord): String { if(verbose) println(prompt) val request = CompletionRequest() - val argList = prompt.argList - val methodName = prompt.methodName - val responseType = prompt.responseType request.prompt = """ - Method: $methodName - Response Type: - ${responseType.replace("\n", "\n ")} - Request: - { - ${argList.entries.joinToString(",\n", transform = { (argName, argValue) -> - """"$argName": $argValue""" - }).replace("\n", "\n ")} - } - Response: - {""".trimIndent() + |Method: ${prompt.methodName} + |Response Type: + | ${prompt.apiYaml.replace("\n", "\n ")} + |Request: + | { + | ${ + prompt.argList.entries.joinToString(",\n", transform = { (argName, argValue) -> + """"$argName": $argValue""" + }).replace("\n", "\n ") + } + | } + |Response: + | {""".trim().trimIndent() request.max_tokens = maxTokens request.temperature = temperature if (moderated) api.moderate(toJson(request)) diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/Notes.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/Description.kt similarity index 63% rename from src/main/kotlin/com/github/simiacryptus/openai/proxy/Notes.kt rename to src/main/kotlin/com/github/simiacryptus/openai/proxy/Description.kt index 5bcf1126..3c6643df 100644 --- a/src/main/kotlin/com/github/simiacryptus/openai/proxy/Notes.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/proxy/Description.kt @@ -1,4 +1,4 @@ package com.github.simiacryptus.openai.proxy @Retention(AnnotationRetention.RUNTIME) -annotation class Notes(val value: String) \ No newline at end of file +annotation class Description(val value: String) \ No newline at end of file diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/GPTProxyBase.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/GPTProxyBase.kt index 6c9a4f36..b05b0445 100644 --- a/src/main/kotlin/com/github/simiacryptus/openai/proxy/GPTProxyBase.kt +++ b/src/main/kotlin/com/github/simiacryptus/openai/proxy/GPTProxyBase.kt @@ -1,25 +1,29 @@ package com.github.simiacryptus.openai.proxy -import com.fasterxml.jackson.core.JsonParseException import com.fasterxml.jackson.core.json.JsonReadFeature import com.fasterxml.jackson.databind.ObjectMapper -import com.github.simiacryptus.util.StringTools.indentJoin +import com.fasterxml.jackson.module.kotlin.isKotlinClass +import com.google.common.reflect.TypeToken import java.io.BufferedWriter import java.io.File import java.io.FileWriter import java.lang.reflect.* -import kotlin.reflect.KClass +import kotlin.reflect.KProperty1 +import kotlin.reflect.full.memberProperties +import kotlin.reflect.jvm.javaType + abstract class GPTProxyBase( apiLogFile: String?, - private val deserializerRetries: Int = 3 + private val deserializerRetries: Int = 5 ) { abstract fun complete(prompt: ProxyRequest, vararg examples: ProxyRecord): String fun create(clazz: Class): T { return Proxy.newProxyInstance(clazz.classLoader, arrayOf(clazz)) { proxy, method, args -> if (method.name == "toString") return@newProxyInstance clazz.simpleName - val typeString = typeToName(method.genericReturnType) + val type = method.genericReturnType + val typeString = method.toYaml().trimIndent() val prompt = ProxyRequest( method.name, typeString, @@ -29,25 +33,42 @@ abstract class GPTProxyBase( param.name to toJson(arg!!) }.toMap() ) + + var lastException: Exception? = null for (retry in 0 until deserializerRetries) { - val result = complete(prompt, *examples[method.name]?.toTypedArray() ?: arrayOf()) - writeToJsonLog(ProxyRecord(prompt.methodName, prompt.argList, result)) + var result = complete(prompt, *examples[method.name]?.toTypedArray() ?: arrayOf()) + // If the requested `type` is a list, check that result is a list + if (type is ParameterizedType && List::class.java.isAssignableFrom(type.rawType as Class<*>) && !result.startsWith( + "[" + ) + ) { + result = "[$result]" + } + writeToJsonLog(ProxyRecord(method.name, prompt.argList, result)) try { - return@newProxyInstance fromJson(result, method.genericReturnType) - } catch (e: JsonParseException) { + val obj = fromJson(type, result) + if (obj is ValidatedObject && !obj.validate()) { + log.warn("Invalid response: $result") + continue + } + return@newProxyInstance obj + } catch (e: Exception) { log.warn("Failed to parse response: $result", e) - log.info("Retrying...") + lastException = e + log.info("Retry $retry of $deserializerRetries") } } + throw RuntimeException("Failed to parse response", lastException) } as T } + private val apiLog = apiLogFile?.let { openApiLog(it) } private val examples = HashMap>() private fun loadExamples(file: File = File("api.examples.json")): List { if (!file.exists()) return listOf() val json = file.readText() - return fromJson(json, object : ArrayList() {}.javaClass) + return fromJson(object : ArrayList() {}.javaClass, json) } fun addExamples(file: File) { @@ -75,39 +96,9 @@ abstract class GPTProxyBase( return objectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(data) } - open fun fromJson(data: String, type: Type): T { - if (data.isNotEmpty()) try { + open fun fromJson(type: Type, data: String): T { if (type is Class<*> && type.isAssignableFrom(String::class.java)) return data as T - return objectMapper().readValue(data, objectMapper().typeFactory.constructType(type)) as T - } catch (e: JsonParseException) { - throw e - } catch (e: Exception) { - log.error("Error parsing JSON", e) - } - if (type is Class<*>) return newInstance(type) as T - return null as T - } - - open fun newInstance(type: Class): T { - if (type.isAssignableFrom(String::class.java)) return "" as T - if (type.isAssignableFrom(Boolean::class.java)) return false as T - if (type.isAssignableFrom(Int::class.java)) return 0 as T - if (type.isAssignableFrom(Long::class.java)) return 0L as T - if (type.isAssignableFrom(Double::class.java)) return 0.0 as T - if (type.isAssignableFrom(Float::class.java)) return 0.0f as T - if (type.isAssignableFrom(Short::class.java)) return 0 as T - if (type.isAssignableFrom(Byte::class.java)) return 0 as T - if (type.isAssignableFrom(Char::class.java)) return 0 as T - if (type.isAssignableFrom(Void::class.java)) return null as T - if (type.isAssignableFrom(Any::class.java)) return null as T - if (type.isAssignableFrom(Unit::class.java)) return null as T - if (type.isAssignableFrom(Nothing::class.java)) return null as T - if (type.isAssignableFrom(List::class.java)) return listOf() as T - if (type.isAssignableFrom(Map::class.java)) return mapOf() as T - if (type.isAssignableFrom(Set::class.java)) return setOf() as T - if (type.isAssignableFrom(Array::class.java)) return arrayOf() as T - if (type.isAssignableFrom(Iterable::class.java)) return listOf() as T - return type.getConstructor().newInstance() + return objectMapper().readValue(data, objectMapper().typeFactory.constructType(type)) as T } open fun objectMapper(): ObjectMapper { @@ -118,7 +109,7 @@ abstract class GPTProxyBase( data class ProxyRequest( val methodName: String = "", - val responseType: String = "", + val apiYaml: String = "", val argList: Map = mapOf() ) @@ -131,160 +122,118 @@ abstract class GPTProxyBase( companion object { val log = org.slf4j.LoggerFactory.getLogger(GPTProxyBase::class.java) - fun typeToName(type: Type?): String { - // Convert a type to API documentation including type name and type structure, recusively expanding child types - if (type == null) { - return "null" - } - if(type is KClass<*>) { - return typeToName(type.java) - } - val javaClass = if (type is Class<*>) { - type - } else if (type is ParameterizedType) { - type.rawType as Class<*> - } else if (type is GenericArrayType) { - type.genericComponentType as Class<*> - } else if (type is TypeVariable<*>) { - type.bounds[0] as Class<*> - } else if (type is WildcardType) { - type.upperBounds[0] as Class<*> - } else if (type is KClass<*>) { - type.java + fun Parameter.toYaml(): String { + val description = getAnnotation(Description::class.java)?.value + val yaml = if (description != null) { + """ + |- name: ${this.name} + | description: $description + | ${this.parameterizedType.toYaml().replace("\n", "\n ")} + |""".trimMargin().trim() } else { - null - } - if (javaClass != null) { - if (javaClass.isPrimitive) { - return javaClass.simpleName - } - if (javaClass.isArray) { - return "Array<${typeToName(javaClass.componentType)}>" - } - if (javaClass.isEnum) { - return javaClass.simpleName - } - if (javaClass.isAssignableFrom(List::class.java)) { - if (type is ParameterizedType) { - val genericType = type.actualTypeArguments[0] - return "List<${typeToName(genericType as Class<*>)}>" - } else { - return "List" - } - } - if (javaClass.isAssignableFrom(Map::class.java)) { - if (type is ParameterizedType) { - val keyType = type.actualTypeArguments[0] - val valueType = type.actualTypeArguments[1] - return "Map<${typeToName(keyType as Class<*>)}, ${typeToName(valueType as Class<*>)}>" - } else { - return "Map" - } - } - if (javaClass.getPackage()?.name?.startsWith("java") == true) { - return javaClass.simpleName - } + """ + |- name: ${this.name} + | ${this.parameterizedType.toYaml().replace("\n", "\n ")} + |""".trimMargin().trim() } - return typeDescription(type).toString() + return yaml } - private fun typeDescription(clazz: Class<*>): TypeDescription { - val apiDocumentation = if (clazz.isArray) { - return TypeDescription("Array<${typeDescription(clazz.componentType)}>") - } else if (clazz.isAssignableFrom(List::class.java)) { - if (clazz.isAssignableFrom(ParameterizedType::class.java)) { - val genericType = (clazz as ParameterizedType).actualTypeArguments[0] - return TypeDescription("List<${typeDescription(genericType as Class<*>)}>") - } else { - return TypeDescription("List") - } - } else if (clazz.isAssignableFrom(Map::class.java)) { - if (clazz.isAssignableFrom(ParameterizedType::class.java)) { - val keyType = (clazz as ParameterizedType).actualTypeArguments[0] - val valueType = (clazz as ParameterizedType).actualTypeArguments[1] - return TypeDescription("Map<${typeDescription(keyType as Class<*>)}, ${typeDescription(valueType as Class<*>)}}>") - } else { - return TypeDescription("Map") - } - } else if (clazz == String::class.java) { - return TypeDescription(clazz.simpleName) - } else { - TypeDescription(clazz.simpleName) - } - if (clazz.isPrimitive) return apiDocumentation - if (clazz.isEnum) return apiDocumentation - for (field in clazz.declaredFields) { - if (field.name.startsWith("\$")) continue - var annotation = getAnnotation(field, clazz, Notes::class) - val notes = if (annotation == null) "" else { - " /* " + annotation.value + " */" - } - // Get ParameterizedType for field - val type = field.genericType - if (type is ParameterizedType) { - // Get raw type - if ((type.rawType as Class<*>).isAssignableFrom(List::class.java)) { - // Get type of list elements - val elementType = type.actualTypeArguments[0] as Class<*> - apiDocumentation.fields.add( - FieldData( - field.name + notes, - TypeDescription("List<${typeDescription(elementType)}>") - ) - ) - continue - } + fun Type.toYaml(): String { + val typeName = this.typeName.substringAfterLast('.').replace('$', '.').toLowerCase() + val yaml = if (typeName in setOf("boolean", "integer", "number", "string")) { + "type: $typeName" + } else if (this is ParameterizedType && List::class.java.isAssignableFrom(this.rawType as Class<*>)) { + """ + |type: array + |items: + | ${this.actualTypeArguments[0].toYaml().replace("\n", "\n ")} + |""".trimMargin() + } else if (this.isArray) { + """ + |type: array + |items: + | ${this.componentType?.toYaml()?.replace("\n", "\n ")} + |""".trimMargin() + } else { + val rawType = TypeToken.of(this).rawType + val declaredFieldYaml = rawType.declaredFields.map { + """ + |${it.name}: + | ${it.genericType.toYaml().replace("\n", "\n ")} + """.trimMargin().trim() + }.toTypedArray() + val propertiesYaml = if (rawType.isKotlinClass() && rawType.kotlin.isData) { + rawType.kotlin.memberProperties.map { + val allAnnotations = + getAllAnnotations(rawType, it) + val description = allAnnotations.find { x -> x is Description } as? Description + // Find annotation on the kotlin data class constructor parameter + val yaml = if (description != null) { + """ + |${it.name}: + | description: ${description.value} + | ${it.returnType.javaType.toYaml().replace("\n", "\n ")} + """.trimMargin().trim() + } else { + """ + |${it.name}: + | ${it.returnType.javaType.toYaml().replace("\n", "\n ")} + """.trimMargin().trim() + } + yaml + }.toTypedArray() + } else { + arrayOf() } - apiDocumentation.fields.add(FieldData(field.name + notes, typeDescription(field.genericType))) + val fieldsYaml = (declaredFieldYaml.toList() + propertiesYaml.toList()).distinct().joinToString("\n") + """ + |type: object + |properties: + | ${fieldsYaml.replace("\n", "\n ")} + """.trimMargin() } - return apiDocumentation + return yaml } - private fun getAnnotation( - field: Field, - clazz: Class<*>, - attributeClass: KClass - ): Notes? { - var annotation = field.getAnnotation(attributeClass.java) - if (annotation != null) return annotation - // If this is a kotlin data class, look for the annotation on the constructor parameter - if (clazz.kotlin.isData) { - val constructor = clazz.kotlin.constructors.first() - val parameter = constructor.parameters.firstOrNull { it.name == field.name } - if (parameter != null) { - val parameterAnnotation = parameter.annotations.firstOrNull { it is Notes } - if (parameterAnnotation != null) { - return parameterAnnotation as Notes - } - } - } - return null + private fun getAllAnnotations( + rawType: Class, + property: KProperty1 + ) = property.annotations + (rawType.kotlin.constructors.first().parameters.find { x -> x.name == property.name }?.annotations + ?: listOf()) + + fun Method.toYaml(): String { + val parameterYaml = parameters.map { it.toYaml() }.toTypedArray().joinToString("").trim() + val returnTypeYaml = genericReturnType.toYaml().trim() + val responseYaml = """ + |responses: + | application/json: + | schema: + | ${returnTypeYaml.replace("\n", "\n ")} + """.trimMargin().trim() + val yaml = """ + |operationId: ${"${declaringClass.simpleName}.$name"} + |parameters: + | ${parameterYaml.replace("\n", "\n ")} + |$responseYaml + """.trimMargin() + return yaml } - private fun typeDescription(clazz: Type): TypeDescription { - if (clazz is Class<*>) return typeDescription(clazz) - if (clazz is ParameterizedType) { - val rawType = clazz.rawType as Class<*> - if (rawType.isAssignableFrom(List::class.java)) { - // Get type of list elements - val elementType = clazz.actualTypeArguments[0] as Class<*> - return TypeDescription("List<${typeDescription(elementType)}>") - } + val Type.isArray: Boolean + get() { + return this is Class<*> && this.isArray } - return TypeDescription(clazz.typeName) - } - class TypeDescription(val name: String) { - val fields: ArrayList = ArrayList() - override fun toString(): String { - return if (fields.isEmpty()) name else indentJoin(fields) + val Type.componentType: Type? + get() { + return when (this) { + is Class<*> -> if (this.isArray) this.componentType else null + is ParameterizedType -> this.actualTypeArguments.firstOrNull() + else -> null + } } - } - - class FieldData(val name: String, val type: TypeDescription) { - override fun toString(): String = """"$name": $type""" - } } } diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/SoftwareProjectAI.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/SoftwareProjectAI.kt deleted file mode 100644 index eb1adf13..00000000 --- a/src/main/kotlin/com/github/simiacryptus/openai/proxy/SoftwareProjectAI.kt +++ /dev/null @@ -1,294 +0,0 @@ -package com.github.simiacryptus.openai.proxy - -import java.io.File -import java.util.concurrent.Callable -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.Executors -import java.util.concurrent.Future -import java.util.concurrent.atomic.AtomicInteger -import java.util.zip.ZipEntry -import java.util.zip.ZipOutputStream - -interface SoftwareProjectAI { - fun newProject(description: String): Project - - data class Project( - val name: String = "", - val description: String = "", - val language: String = "", - val features: List = listOf(), - val libraries: List = listOf(), - val buildTools: List = listOf(), - ) - - fun getProjectStatements(description: String, project: Project): ProjectStatements - - data class ProjectStatements( - val assumptions: List = listOf(), - val designPatterns: List = listOf(), - val requirements: List = listOf(), - val risks: List = listOf(), - ) - - fun buildProjectDesign(project: Project, requirements: ProjectStatements): ProjectDesign - - data class ProjectDesign( - val components: List = listOf(), - val documents: List = listOf(), - val tests: List = listOf(), - ) - - data class ComponentDetails( - val name: String = "", - val description: String = "", - val features: List = listOf(), - ) - - data class TestDetails( - val name: String = "", - val steps: List = listOf(), - val expectations: List = listOf(), - ) - - data class DocumentationDetails( - val name: String = "", - val description: String = "", - val sections: List = listOf(), - ) - - fun buildProjectFileSpecifications( - project: Project, - requirements: ProjectStatements, - design: ProjectDesign, - recursive: Boolean = true - ): List - - fun buildComponentFileSpecifications( - project: Project, - requirements: ProjectStatements, - design: ComponentDetails, - recursive: Boolean = true - ): List - - fun buildTestFileSpecifications( - project: Project, - requirements: ProjectStatements, - design: TestDetails, - recursive: Boolean = true - ): List - - fun buildDocumentationFileSpecifications( - project: Project, - requirements: ProjectStatements, - design: DocumentationDetails, - recursive: Boolean = true - ): List - - data class CodeSpecification( - val description: String = "", - val requires: List = listOf(), - val publicProperties: List = listOf(), - val publicMethodSignatures: List = listOf(), - val language: String = "", - val location: FilePath = FilePath(), - ) - - data class DocumentSpecification( - val description: String = "", - val requires: List = listOf(), - val sections: List = listOf(), - val language: String = "", - val location: FilePath = FilePath(), - ) - - data class TestSpecification( - val description: String = "", - val requires: List = listOf(), - val steps: List = listOf(), - val expectations: List = listOf(), - val language: String = "", - val location: FilePath = FilePath(), - ) - - data class FilePath( - @Notes("e.g. projectRoot/README.md") val fullFilePathName: String = "", - ) { - override fun toString(): String { - return fullFilePathName - } - } - - fun implementComponentSpecification( - project: Project, - specification: CodeSpecification, - component: ComponentDetails, - imports: List, - specificationAgain: CodeSpecification, - ): SourceCode - - - fun implementTestSpecification( - project: Project, - specification: TestSpecification, - test: TestDetails, - imports: List, - specificationAgain: TestSpecification, - ): SourceCode - - - fun implementDocumentationSpecification( - project: Project, - specification: DocumentSpecification, - documentation: DocumentationDetails, - imports: List, - specificationAgain: DocumentSpecification, - ): SourceCode - - data class SourceCode( - val language: String = "", - val code: String = "", - ) - - companion object { - val log = org.slf4j.LoggerFactory.getLogger(SoftwareProjectAI::class.java) - fun parallelImplement( - api: SoftwareProjectAI, - project: Project, - components: Map>?, - documents: Map>?, - tests: Map>?, - drafts: Int, - threads: Int - ): Map = parallelImplementWithAlternates( - api, - project, - components ?: mapOf(), - documents ?: mapOf(), - tests ?: mapOf(), - drafts, - threads - ).mapValues { it.value.maxByOrNull { it.code.length } } - - fun parallelImplementWithAlternates( - api: SoftwareProjectAI, - project: Project, - components: Map>, - documents: Map>, - tests: Map>, - drafts: Int, - threads: Int, - progress: (Double) -> Unit = {} - ): Map> { - val threadPool = Executors.newFixedThreadPool(threads) - try { - val totalDrafts = (components + tests + documents).values.filterNotNull().sumOf { it.size } * drafts - val currentDraft = AtomicInteger(0) - val fileImplCache = ConcurrentHashMap>>>() - val normalizeFileName : (String) -> String = { - it.trimStart('/','.') - } - val componentImpl = components.filter { it.value != null }.flatMap { (component, files) -> - files.flatMap { file -> - fileImplCache.getOrPut(normalizeFileName(file.location.fullFilePathName)) { - (0 until drafts).map { _ -> - threadPool.submit(Callable { - val implement = api.implementComponentSpecification( - project, - file.copy(requires = listOf()), - component, - files.filter { file.requires.contains(it.location) }.toList(), - file.copy(requires = listOf()) - ) - (currentDraft.incrementAndGet().toDouble() / totalDrafts) - .also { progress(it) } - .also { log.info("Progress: $it") } - file.location to implement - }) - } - } - } - }.toTypedArray().map { - try { - it.get() - } catch (e: Throwable) { - null - } - }.filterNotNull().groupBy { it.first } - .mapValues { it.value.map { it.second }.sortedBy { it.code.length } } - val testImpl = tests.filter { it.value != null }.flatMap { (testDetails, files) -> - files.flatMap { file -> - fileImplCache.getOrPut(normalizeFileName(file.location.fullFilePathName)) { - (0 until drafts).map { _ -> - threadPool.submit(Callable { - val implement = api.implementTestSpecification( - project, - file, - testDetails, - files.filter { file.requires.contains(it.location) }.toList(), - file - ) - (currentDraft.incrementAndGet().toDouble() / totalDrafts) - .also { progress(it) } - .also { log.info("Progress: $it") } - file.location to implement - }) - } - } - } - }.toTypedArray().map { - try { - it.get() - } catch (e: Throwable) { - null - } - }.filterNotNull().groupBy { it.first } - .mapValues { it.value.map { it.second }.sortedBy { it.code.length } } - val docImpl = documents.filter { it.value != null }.flatMap { (documentationDetails, files) -> - files.flatMap { file -> - fileImplCache.getOrPut(normalizeFileName(file.location.fullFilePathName)) { - (0 until drafts).map { _ -> - threadPool.submit(Callable { - val implement = api.implementDocumentationSpecification( - project, - file, - documentationDetails, - files.filter { file.requires.contains(it.location) }.toList(), - file, - ) - (currentDraft.incrementAndGet().toDouble() / totalDrafts) - .also { progress(it) } - .also { log.info("Progress: $it") } - file.location to implement - }) - } - } - } - }.toTypedArray().map { - try { - it.get() - } catch (e: Throwable) { - null - } - }.filterNotNull().groupBy { it.first } - .mapValues { it.value.map { it.second }.sortedBy { it.code.length } } - return componentImpl + docImpl + testImpl - } finally { - threadPool.shutdown() - } - } - - fun write( - zipArchiveFile: File, - implementations: Map - ) { - ZipOutputStream(zipArchiveFile.outputStream()).use { zip -> - implementations.forEach { (file, sourceCodes) -> - zip.putNextEntry(ZipEntry(file.toString())) - zip.write(sourceCodes!!.code.toByteArray()) - zip.closeEntry() - } - } - } - } -} - diff --git a/src/main/kotlin/com/github/simiacryptus/openai/proxy/ValidatedObject.kt b/src/main/kotlin/com/github/simiacryptus/openai/proxy/ValidatedObject.kt new file mode 100644 index 00000000..b12b5818 --- /dev/null +++ b/src/main/kotlin/com/github/simiacryptus/openai/proxy/ValidatedObject.kt @@ -0,0 +1,26 @@ +package com.github.simiacryptus.openai.proxy + +import kotlin.reflect.full.memberProperties + +interface ValidatedObject { + fun validate(): Boolean = validateFields(this) + + companion object { + fun validateFields(obj: Any): Boolean { + obj.javaClass.declaredFields.forEach { field -> + field.isAccessible = true + val value = field.get(obj) + if (value is ValidatedObject && !value.validate()) { + return false + } + } + obj.javaClass.kotlin.memberProperties.forEach { property -> + val value = property.getter.call(obj) + if (value is ValidatedObject && !value.validate()) { + return false + } + } + return true + } + } +} \ No newline at end of file diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoDevelop.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoDevelop.kt index 48c8809e..ec22cb1f 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoDevelop.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoDevelop.kt @@ -1,8 +1,8 @@ package com.github.simiacryptus.aicoder.proxy -import com.github.simiacryptus.openai.proxy.SoftwareProjectAI -import com.github.simiacryptus.openai.proxy.SoftwareProjectAI.Companion.parallelImplement -import com.github.simiacryptus.openai.proxy.SoftwareProjectAI.Companion.write +import com.github.simiacryptus.aicoder.SoftwareProjectAI +import com.github.simiacryptus.aicoder.SoftwareProjectAI.Companion.parallelImplement +import com.github.simiacryptus.aicoder.SoftwareProjectAI.Companion.write import org.junit.Test import java.util.* @@ -18,12 +18,43 @@ class AutoDevelop : GenerationReportBase() { } } + @Test + fun testMethodImplementation() { + runReport("SoftwareProject_Impl", SoftwareProjectAI::class) { api, logJson, out -> + val sourceCode = api.implementComponentSpecification( + proxy.fromJson( + SoftwareProjectAI.Project::class.java, + "{\r\n \"name\" : \"SupportAliasBot\",\r\n \"description\" : \"Slack bot to monitor a support alias\",\r\n \"language\" : \"Kotlin\",\r\n \"features\" : [ \"Record all requests tagging an alias in a database\", \"Send a message to a Slack channel when requests are tagged with a specific label\" ],\r\n \"libraries\" : [ \"Gradle\", \"Spring\" ],\r\n \"buildTools\" : [ \"Gradle\" ]\r\n}" + ), + proxy.fromJson( + SoftwareProjectAI.ComponentDetails::class.java, + "{\r\n \"description\" : \"Main class for the SupportAliasBot\",\r\n \"requires\" : [ ],\r\n \"publicProperties\" : [ ],\r\n \"publicMethodSignatures\" : [ \"fun main(args: Array): Unit\" ],\r\n \"language\" : \"Kotlin\",\r\n \"location\" : {\r\n \"file\" : \"src/main/kotlin/com/example/supportaliasbot/SupportAliasBot.kt\"\r\n }\r\n}" + ), + listOf(), + proxy.fromJson( + SoftwareProjectAI.CodeSpecification::class.java, + "{\r\n \"name\" : \"SupportAliasBot\",\r\n \"description\" : \"Slack bot to monitor a support alias\",\r\n \"language\" : \"Kotlin\",\r\n \"features\" : [ \"Record all requests tagging an alias in a database\", \"Send a message to a Slack channel when requests are tagged with a specific label\" ],\r\n \"libraries\" : [ \"Gradle\", \"Spring\" ],\r\n \"buildTools\" : [ \"Gradle\" ]\r\n}" + ), + ) + out(""" + |```${sourceCode.language} + |${sourceCode.code} + |``` + |""".trimMargin()) + } + } + private fun report( api: SoftwareProjectAI, logJson: (Any?) -> Unit, out: (Any?) -> Unit ) { - var description: String + val drafts = 2 + val threads = 7 + proxy.temperature = 0.5 + + @Suppress("JoinDeclarationAndAssignment") + val description: String description = """ | |Slack bot to monitor a support alias @@ -37,7 +68,7 @@ class AutoDevelop : GenerationReportBase() { |Frameworks: Gradle, Spring | """.trimMargin() - description = """ + """ | |Create a website where users can upload stories, share them, and rate them | @@ -49,30 +80,34 @@ class AutoDevelop : GenerationReportBase() { |Frameworks: Gradle, Spring | """.trimMargin() - out(""" + out( + """ | |# Software Project Development Report | |## Description | |``` - |${description.trim() }} + |${description.trim()}} |``` | - |""".trimMargin()) + |""".trimMargin() + ) var project: SoftwareProjectAI.Project? = null var requirements: SoftwareProjectAI.ProjectStatements? = null var projectDesign: SoftwareProjectAI.ProjectDesign? = null var components: Map>? = null - var documents: Map>? = null + var documents: Map>? = + null var tests: Map>? = null var implementations: Map? = null try { project = api.newProject(description.trim()) - out(""" + out( + """ | |Project Name: ${project.name} | @@ -80,64 +115,85 @@ class AutoDevelop : GenerationReportBase() { | |Language: ${project.language} | - |Libraries: ${project.libraries.joinToString(", ")} + |Libraries: ${project.libraries?.joinToString(", ")} | - |Build Tools: ${project.buildTools.joinToString(", ")} + |Build Tools: ${project.buildTools?.joinToString(", ")} | |""".trimMargin() ) logJson(project) requirements = api.getProjectStatements(description.trim(), project) - out(""" + out( + """ | |## Requirements | - |""".trimMargin()) + |""".trimMargin() + ) logJson(requirements) + projectDesign = api.buildProjectDesign(project, requirements) - out(""" + out( + """ | |## Design | - |""".trimMargin()) + |""".trimMargin() + ) logJson(projectDesign) components = - projectDesign.components.map { it to api.buildComponentFileSpecifications(project, requirements, it) } - .toMap() - out(""" + projectDesign.components?.map { + it to (api.buildComponentFileSpecifications( + project, + requirements, + it + )) + }?.toMap() + out( + """ | |## Components | - |""".trimMargin()) + |""".trimMargin() + ) logJson(components) + documents = - projectDesign.documents.map { - it to api.buildDocumentationFileSpecifications( + projectDesign.documents?.map { + it to (api.buildDocumentationFileSpecifications( project, requirements, it ) - }.toMap() - out(""" + ) + }?.toMap() + out( + """ | |## Documents | - |""".trimMargin()) + |""".trimMargin() + ) logJson(documents) - tests = projectDesign.tests.map { it to api.buildTestFileSpecifications(project, requirements, it) }.toMap() - out(""" + tests = projectDesign.tests?.map { + it to (api.buildTestFileSpecifications(project, requirements, it)) + }?.toMap() + out( + """ | |## Tests | - |""".trimMargin()) + |""".trimMargin() + ) logJson(tests) - implementations = parallelImplement(api, project, components, documents, tests, 1, 7) + implementations = parallelImplement(api, project, components, documents, tests, drafts, threads) } catch (e: Exception) { e.printStackTrace() } if (implementations != null) { - val relative = "projects/${project?.name ?: UUID.randomUUID()}.zip" + val name = (project?.name ?: UUID.randomUUID().toString()).replace(Regex("[^a-zA-Z0-9]"), "_") + val relative = "projects/$name.zip" val zipArchiveFile = outputDir.resolve(relative) zipArchiveFile.parentFile.mkdirs() write(zipArchiveFile, implementations) @@ -150,13 +206,13 @@ class AutoDevelop : GenerationReportBase() { | |""".trimMargin() ) - implementations.toList().sortedBy { it.first.fullFilePathName }.forEach { (file, sourceCodes) -> + implementations.toList().sortedBy { it.first.file }.forEach { (file, sourceCodes) -> out( """ | - |### ${file.fullFilePathName} + |### ${file.file} | - |```${sourceCodes!!.language.lowercase()} + |```${sourceCodes!!.language?.lowercase()} |${sourceCodes.code} |``` | diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoNews.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoNews.kt index b8ea37ae..6f9154b3 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoNews.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/AutoNews.kt @@ -59,7 +59,7 @@ class AutoNews : GenerationReportBase() { | |![${story.image!!.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( story.image.detailedCaption, resolution = 512 )[0] diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ChildrensStory.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ChildrensStory.kt index a3b016b2..d2958f86 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ChildrensStory.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ChildrensStory.kt @@ -149,7 +149,7 @@ class ChildrensStory : GenerationReportBase(){ | |![${it.image.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( it.image.detailedCaption, resolution = 512 )[0] diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ComicBook.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ComicBook.kt index c2210d4f..8359b59f 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ComicBook.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ComicBook.kt @@ -84,7 +84,7 @@ class ComicBook: GenerationReportBase() { | |![${character.image!!.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( character.image.detailedCaption, resolution = 512 )[0] @@ -132,7 +132,7 @@ class ComicBook: GenerationReportBase() { | |![${page.image!!.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( page.image.detailedCaption, resolution = 512 )[0] diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/FamilyGuyWriter.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/FamilyGuyWriter.kt index 74eb5996..55368e00 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/FamilyGuyWriter.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/FamilyGuyWriter.kt @@ -54,7 +54,7 @@ class FamilyGuyWriter : GenerationReportBase(){ """ |![${imageCaption.caption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( imageCaption.caption, resolution = 512 )[0] @@ -74,7 +74,7 @@ class FamilyGuyWriter : GenerationReportBase(){ """ |![${cutaway.imageCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( cutaway.imageCaption, resolution = 512 )[0] diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ImageTest.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ImageTest.kt index ccb938f8..7c4fc49f 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ImageTest.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/ImageTest.kt @@ -7,7 +7,7 @@ import java.io.File class ImageTest: GenerationReportBase() { @Test fun imageGenerationTest() { - val image = proxy.api.text_to_image("Hello World")[0] + val image = proxy.api.render("Hello World")[0] // Write the image to a file val file = File.createTempFile("image", ".png") javax.imageio.ImageIO.write(image, "png", file) diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/RecipeBook.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/RecipeBook.kt index 1c06cfb3..c2bd932c 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/RecipeBook.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/RecipeBook.kt @@ -78,7 +78,7 @@ class RecipeBook : GenerationReportBase() { | |![${recipe.image!!.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( recipe.image.detailedCaption, resolution = 512 )[0] diff --git a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/VideoGame.kt b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/VideoGame.kt index 3cd1ee49..cfab3eb8 100644 --- a/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/VideoGame.kt +++ b/src/test/kotlin/com/github/simiacryptus/aicoder/proxy/VideoGame.kt @@ -90,7 +90,7 @@ class VideoGame : GenerationReportBase() { | |![${character.image!!.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( character.image.detailedCaption, resolution = 512 )[0] @@ -138,7 +138,7 @@ class VideoGame : GenerationReportBase() { | |![${level.image!!.detailedCaption}](${ writeImage( - proxy.api.text_to_image( + proxy.api.render( level.image.detailedCaption, resolution = 512 )[0]