From 4537e4759bda66f163a409b356f7240b7de838c0 Mon Sep 17 00:00:00 2001 From: Andrew Charneski Date: Thu, 12 Jan 2023 09:26:29 -0500 Subject: [PATCH] 0.1.8 (#6) --- CHANGELOG.md | 12 + gradle.properties | 2 +- .../aicoder/ComputerLanguage.java | 34 +- .../simiacryptus/aicoder/EditorMenu.java | 662 +++++++++--------- .../aicoder/TextReplacementAction.java | 63 -- .../aicoder/config/AppSettingsComponent.java | 66 +- .../aicoder/config/AppSettingsState.java | 22 +- .../config/SimpleSettingsComponent.java | 24 +- .../aicoder/openai/CompletionRequest.java | 45 +- .../aicoder/openai/CompletionResponse.java | 2 +- .../simiacryptus/aicoder/openai/LogProbs.java | 4 +- .../simiacryptus/aicoder/openai/OpenAI.java | 185 ----- .../aicoder/openai/OpenAI_API.java | 290 ++++++++ .../translate/BaseTranslationRequest.java | 42 +- .../openai/translate/TranslationRequest.java | 18 +- .../translate/TranslationRequest_XML.java | 4 +- .../aicoder/psi/PsiClassContext.java | 4 +- .../aicoder/psi/PsiMarkdownContext.java | 4 +- .../simiacryptus/aicoder/psi/PsiUtil.java | 28 +- .../aicoder/text/StringTools.java | 150 ---- .../aicoder/{text => util}/BlockComment.java | 24 +- .../aicoder/{text => util}/IndentedText.java | 24 +- .../aicoder/{text => util}/LineComment.java | 12 +- .../aicoder/util/StringTools.java | 270 +++++++ .../aicoder/{ => util}/StyleUtil.java | 68 +- .../aicoder/{text => util}/TextBlock.java | 10 +- .../{text => util}/TextBlockFactory.java | 4 +- .../aicoder/util/TextReplacementAction.java | 77 ++ .../simiacryptus/aicoder/util/UITools.java | 224 ++++++ 29 files changed, 1447 insertions(+), 927 deletions(-) delete mode 100644 src/main/java/com/github/simiacryptus/aicoder/TextReplacementAction.java delete mode 100644 src/main/java/com/github/simiacryptus/aicoder/openai/OpenAI.java create mode 100644 src/main/java/com/github/simiacryptus/aicoder/openai/OpenAI_API.java delete mode 100644 src/main/java/com/github/simiacryptus/aicoder/text/StringTools.java rename src/main/java/com/github/simiacryptus/aicoder/{text => util}/BlockComment.java (64%) rename src/main/java/com/github/simiacryptus/aicoder/{text => util}/IndentedText.java (68%) rename src/main/java/com/github/simiacryptus/aicoder/{text => util}/LineComment.java (75%) create mode 100644 src/main/java/com/github/simiacryptus/aicoder/util/StringTools.java rename src/main/java/com/github/simiacryptus/aicoder/{ => util}/StyleUtil.java (61%) rename src/main/java/com/github/simiacryptus/aicoder/{text => util}/TextBlock.java (60%) rename src/main/java/com/github/simiacryptus/aicoder/{text => util}/TextBlockFactory.java (62%) create mode 100644 src/main/java/com/github/simiacryptus/aicoder/util/TextReplacementAction.java create mode 100644 src/main/java/com/github/simiacryptus/aicoder/util/UITools.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 989644cd..8b16c74e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,18 @@ ## [Unreleased] +## [0.1.8] + +### Added +- Asynchronous operations now include modal progress +- Added "retry last" and generic "append" and "insert" operations +- Added SCSS support + +## [0.1.7] + +### Added +- All API calls are handled asynchronously - no UI thread blocking! + ## [0.1.6] ### Added diff --git a/gradle.properties b/gradle.properties index 291d12ea..c375c43e 100644 --- a/gradle.properties +++ b/gradle.properties @@ -4,7 +4,7 @@ pluginGroup = com.github.simiacryptus pluginName = intellij-aicoder pluginRepositoryUrl = https://github.com/SimiaCryptus/intellij-aicoder # SemVer format -> https://semver.org -pluginVersion = 0.1.6 +pluginVersion = 0.1.8 # Supported build number ranges and IntelliJ Platform versions -> https://plugins.jetbrains.com/docs/intellij/build-number-ranges.html pluginSinceBuild = 203 diff --git a/src/main/java/com/github/simiacryptus/aicoder/ComputerLanguage.java b/src/main/java/com/github/simiacryptus/aicoder/ComputerLanguage.java index ebe781c2..d0e988b4 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/ComputerLanguage.java +++ b/src/main/java/com/github/simiacryptus/aicoder/ComputerLanguage.java @@ -1,8 +1,8 @@ package com.github.simiacryptus.aicoder; -import com.github.simiacryptus.aicoder.text.BlockComment; -import com.github.simiacryptus.aicoder.text.LineComment; -import com.github.simiacryptus.aicoder.text.TextBlockFactory; +import com.github.simiacryptus.aicoder.util.BlockComment; +import com.github.simiacryptus.aicoder.util.LineComment; +import com.github.simiacryptus.aicoder.util.TextBlockFactory; import org.jetbrains.annotations.Nullable; import java.util.Arrays; @@ -21,6 +21,9 @@ public enum ComputerLanguage { .setBlockComments(new BlockComment.Factory("/*", "", "*/")) .setDocComments(new BlockComment.Factory("/**", "*", "*/")) .setFileExtensions("cpp")), + Bash(new Configuration() + .setLineComments(new LineComment.Factory("#")) + .setFileExtensions("sh")), Markdown(new Configuration() .setDocumentationStyle("Markdown") .setLineComments(new BlockComment.Factory("")) @@ -42,9 +45,6 @@ public enum ComputerLanguage { .setBlockComments(new BlockComment.Factory("/*", "", "*/")) .setDocComments(new BlockComment.Factory("/**", "*", "*/")) .setFileExtensions("basic", "bs")), - Bash(new Configuration() - .setLineComments(new LineComment.Factory("#")) - .setFileExtensions("sh")), C(new Configuration() .setDocumentationStyle("Doxygen") .setLineComments(new LineComment.Factory("//")) @@ -221,6 +221,12 @@ public enum ComputerLanguage { .setBlockComments(new BlockComment.Factory("/*", "", "*/")) .setDocComments(new BlockComment.Factory("/**", "*", "*/")) .setFileExtensions("scheme")), + SCSS(new Configuration() + .setDocumentationStyle("SCSS") + .setLineComments(new LineComment.Factory("//")) + .setBlockComments(new BlockComment.Factory("/*", "", "*/")) + .setDocComments(new LineComment.Factory("///")) + .setFileExtensions("scss")), SQL(new Configuration() .setLineComments(new LineComment.Factory("--")) .setBlockComments(new BlockComment.Factory("/*", "", "*/")) @@ -259,7 +265,7 @@ public enum ComputerLanguage { .setLineComments(new LineComment.Factory("#")) .setFileExtensions("zsh")); - public final List extensions; + public final List extensions; public final String docStyle; public final TextBlockFactory lineComment; public final TextBlockFactory blockComment; @@ -274,11 +280,11 @@ public enum ComputerLanguage { } @Nullable - public static ComputerLanguage findByExtension(String extension) { + public static ComputerLanguage findByExtension(CharSequence extension) { return Arrays.stream(values()).filter(x -> x.extensions.contains(extension)).findAny().orElse(null); } - public String getMultilineCommentSuffix() { + public CharSequence getMultilineCommentSuffix() { if (docComment instanceof BlockComment.Factory) { return ((BlockComment.Factory) docComment).blockSuffix; } @@ -286,14 +292,14 @@ public String getMultilineCommentSuffix() { } public TextBlockFactory getCommentModel(String text) { - if(docComment.looksLike(text)) return docComment; - if(blockComment.looksLike(text)) return blockComment; + if (docComment.looksLike(text)) return docComment; + if (blockComment.looksLike(text)) return blockComment; return lineComment; } static class Configuration { private String documentationStyle = ""; - private String[] fileExtensions = new String[] {}; + private CharSequence[] fileExtensions = new CharSequence[]{}; private TextBlockFactory lineComments = null; private TextBlockFactory blockComments = null; private TextBlockFactory docComments = null; @@ -307,11 +313,11 @@ public Configuration setDocumentationStyle(String documentationStyle) { return this; } - public String[] getFileExtensions() { + public CharSequence[] getFileExtensions() { return fileExtensions; } - public Configuration setFileExtensions(String... fileExtensions) { + public Configuration setFileExtensions(CharSequence... fileExtensions) { this.fileExtensions = fileExtensions; return this; } diff --git a/src/main/java/com/github/simiacryptus/aicoder/EditorMenu.java b/src/main/java/com/github/simiacryptus/aicoder/EditorMenu.java index 8b977ff7..cc430afa 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/EditorMenu.java +++ b/src/main/java/com/github/simiacryptus/aicoder/EditorMenu.java @@ -1,53 +1,44 @@ package com.github.simiacryptus.aicoder; import com.github.simiacryptus.aicoder.config.AppSettingsState; +import com.github.simiacryptus.aicoder.openai.CompletionRequest; import com.github.simiacryptus.aicoder.openai.ModerationException; import com.github.simiacryptus.aicoder.psi.PsiClassContext; import com.github.simiacryptus.aicoder.psi.PsiMarkdownContext; import com.github.simiacryptus.aicoder.psi.PsiUtil; -import com.github.simiacryptus.aicoder.text.IndentedText; -import com.github.simiacryptus.aicoder.text.StringTools; -import com.github.simiacryptus.aicoder.text.TextBlockFactory; +import com.github.simiacryptus.aicoder.util.*; +import com.intellij.core.CoreBundle; import com.intellij.openapi.actionSystem.ActionGroup; import com.intellij.openapi.actionSystem.AnAction; import com.intellij.openapi.actionSystem.AnActionEvent; import com.intellij.openapi.actionSystem.CommonDataKeys; -import com.intellij.openapi.command.WriteCommandAction; import com.intellij.openapi.diagnostic.Logger; import com.intellij.openapi.editor.Caret; import com.intellij.openapi.editor.CaretModel; import com.intellij.openapi.editor.Document; import com.intellij.openapi.editor.Editor; import com.intellij.openapi.ide.CopyPasteManager; +import com.intellij.openapi.util.TextRange; import com.intellij.openapi.vfs.VirtualFile; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiFile; +import org.jetbrains.annotations.Nls; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import javax.swing.*; import java.awt.datatransfer.DataFlavor; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Random; +import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; public class EditorMenu extends ActionGroup { - private static final Logger log = Logger.getInstance(EditorMenu.class); - - public static void handle(@NotNull Throwable ex) { - if (!(ex instanceof ModerationException)) log.error(ex); - JOptionPane.showMessageDialog(null, ex.getMessage(), "Warning", JOptionPane.WARNING_MESSAGE); - } - public static boolean hasSelection(@NotNull AnActionEvent e) { - Caret caret = e.getData(CommonDataKeys.CARET); - return null != caret && caret.hasSelection(); - } + private static final Logger log = Logger.getInstance(EditorMenu.class); + public static final @NotNull + @Nls String DEFAULT_ACTION_MESSAGE = CoreBundle.message("command.name.undefined"); /** * This method is used to get the children of the action. @@ -73,6 +64,18 @@ public static boolean hasSelection(@NotNull AnActionEvent e) { ArrayList children = new ArrayList<>(); ComputerLanguage language = ComputerLanguage.findByExtension(extension); + Caret caret = e.getData(CommonDataKeys.CARET); + + if (null != caret) { + addIfNotNull(children, redoLast()); + if (!caret.hasSelection()) { + children.add(genericInsert()); + } else { + children.add(genericAppend()); + } + } + + if (language != null) { addIfNotNull(children, rewordCommentAction(e, language, inputHumanLanguage)); @@ -85,30 +88,34 @@ public static boolean hasSelection(@NotNull AnActionEvent e) { children.add(pasteAction(language.name())); } - if (!language.docStyle.isEmpty()) children.add(docAction(extension, language)); + if (language.docStyle.length() > 0) { + children.add(docAction(extension, language)); + } if (language == ComputerLanguage.Markdown) { addIfNotNull(children, markdownListAction(e)); addIfNotNull(children, markdownNewTableRowsAction(e)); addIfNotNull(children, markdownNewTableColsAction(e)); - addIfNotNull(children, markdownNewTableColsAction2(e)); + addIfNotNull(children, markdownNewTableColAction(e)); } - if (hasSelection(e)) { - children.add(customEdit(language.name())); - children.add(recentEdits(language.name())); - switch (language) { - case Markdown: - addIfNotNull(children, markdownContextAction(e, inputHumanLanguage)); - break; - default: - addIfNotNull(children, psiClassContextAction(e, language, inputHumanLanguage)); - break; + if (null != caret) { + if (caret.hasSelection()) { + children.add(customEdit(language.name())); + children.add(recentEdits(language.name())); + switch (language) { + case Markdown: + addIfNotNull(children, markdownContextAction(e, inputHumanLanguage)); + break; + default: + addIfNotNull(children, psiClassContextAction(e, language, inputHumanLanguage)); + break; + } + children.add(describeAction(outputHumanLanguage, language)); + children.add(addCodeCommentsAction(outputHumanLanguage, language)); + children.add(fromHumanLanguageAction(inputHumanLanguage, language)); + children.add(toHumanLanguageAction(outputHumanLanguage, language)); } - children.add(describeAction(outputHumanLanguage, language)); - children.add(addCodeCommentsAction(outputHumanLanguage, language)); - children.add(fromHumanLanguageAction(inputHumanLanguage, language)); - children.add(toHumanLanguageAction(outputHumanLanguage, language)); } } @@ -116,83 +123,134 @@ public static boolean hasSelection(@NotNull AnActionEvent e) { } @NotNull - public static TextReplacementAction toHumanLanguageAction(String outputHumanLanguage, ComputerLanguage language) { + protected AnAction toHumanLanguageAction(String outputHumanLanguage, ComputerLanguage language) { String computerLanguage = language.name(); - return TextReplacementAction.create("_To " + outputHumanLanguage, String.format("Describe %s -> %s", outputHumanLanguage, computerLanguage), null, (event, string) -> { - AppSettingsState settings = AppSettingsState.getInstance(); - return settings.createTranslationRequest() - .setInstruction(getInstruction(settings.style, "Describe this code")) - .setInputText(string) - .setInputType(computerLanguage) - .setInputAttribute("type", "input") - .setOutputType(outputHumanLanguage.toLowerCase()) - .setOutputAttrute("type", "output") - .setOutputAttrute("style", settings.style) - .buildCompletionRequest() - .complete(getIndent(event.getData(CommonDataKeys.CARET))); - }); + String description = String.format("Describe %s -> %s", outputHumanLanguage, computerLanguage); + return TextReplacementAction.create("_To " + outputHumanLanguage, description, null, + (event, string) -> { + AppSettingsState settings = AppSettingsState.getInstance(); + return settings.createTranslationRequest() + .setInstruction(UITools.getInstruction("Describe this code")) + .setInputText(string) + .setInputType(computerLanguage) + .setInputAttribute("type", "input") + .setOutputType(outputHumanLanguage.toLowerCase()) + .setOutputAttrute("type", "output") + .setOutputAttrute("style", settings.style) + .buildCompletionRequest(); + }); } @NotNull - public static TextReplacementAction fromHumanLanguageAction(String inputHumanLanguage, ComputerLanguage language) { + protected AnAction fromHumanLanguageAction(String inputHumanLanguage, ComputerLanguage language) { String computerLanguage = language.name(); - return TextReplacementAction.create("_From " + inputHumanLanguage, String.format("Implement %s -> %s", inputHumanLanguage, computerLanguage), null, (event, string) -> { - return AppSettingsState.getInstance().createTranslationRequest() - .setInputType(inputHumanLanguage.toLowerCase()) - .setOutputType(computerLanguage) - .setInstruction("Implement this specification") - .setInputAttribute("type", "input") - .setOutputAttrute("type", "output") - .setInputText(string) - .buildCompletionRequest() - .complete(getIndent(event.getData(CommonDataKeys.CARET))); - }); + String description = String.format("Implement %s -> %s", inputHumanLanguage, computerLanguage); + return TextReplacementAction.create("_From " + inputHumanLanguage, description, null, (event, string) -> + AppSettingsState.getInstance().createTranslationRequest() + .setInputType(inputHumanLanguage.toLowerCase()) + .setOutputType(computerLanguage) + .setInstruction("Implement this specification") + .setInputAttribute("type", "input") + .setOutputAttrute("type", "output") + .setInputText(string) + .buildCompletionRequest()); } @NotNull - public static TextReplacementAction addCodeCommentsAction(String outputHumanLanguage, ComputerLanguage language) { + protected AnAction addCodeCommentsAction(CharSequence outputHumanLanguage, ComputerLanguage language) { String computerLanguage = language.name(); return TextReplacementAction.create("Add Code _Comments", "Add Code Comments", null, (event, string) -> { AppSettingsState settings = AppSettingsState.getInstance(); return settings.createTranslationRequest() .setInputType(computerLanguage) .setOutputType(computerLanguage) - .setInstruction(getInstruction(settings.style, "Rewrite to include detailed " + outputHumanLanguage + " code comments for every line")) + .setInstruction(UITools.getInstruction("Rewrite to include detailed " + outputHumanLanguage + " code comments for every line")) .setInputAttribute("type", "uncommented") .setOutputAttrute("type", "commented") .setOutputAttrute("style", settings.style) .setInputText(string) - .buildCompletionRequest() - .complete(getIndent(event.getData(CommonDataKeys.CARET))); + .buildCompletionRequest(); }); } @NotNull - public static TextReplacementAction describeAction(String outputHumanLanguage, ComputerLanguage language) { - return TextReplacementAction.create("_Describe Code and Prepend Comment", "Add JavaDoc Comments", null, (event, inputString) -> { - AppSettingsState settings = AppSettingsState.getInstance(); - String indent = getIndent(event.getData(CommonDataKeys.CARET)); - String description = settings.createTranslationRequest() - .setInputType(language.name()) - .setOutputType(outputHumanLanguage) - .setInstruction(getInstruction(settings.style, "Explain this " + language.name() + " in " + outputHumanLanguage)) - .setInputAttribute("type", "code") - .setOutputAttrute("type", "description") - .setOutputAttrute("style", settings.style) - .setInputText(IndentedText.fromString(inputString).getTextBlock().trim()) - .buildCompletionRequest() - .complete(indent); - return "\n" + indent + language.blockComment.fromString(StringTools.lineWrapping(description.trim(), 120)).withIndent(indent) + "\n" + indent + inputString; + protected AnAction describeAction(String outputHumanLanguage, ComputerLanguage language) { + return TextReplacementAction.create("_Describe Code and Prepend Comment", "Add JavaDoc Comments", null, new TextReplacementAction.ActionTextEditorFunction() { + @Override + public CompletionRequest apply(AnActionEvent event, String inputString) throws IOException, ModerationException { + AppSettingsState settings = AppSettingsState.getInstance(); + return settings.createTranslationRequest() + .setInputType(language.name()) + .setOutputType(outputHumanLanguage) + .setInstruction(UITools.getInstruction("Explain this " + language.name() + " in " + outputHumanLanguage)) + .setInputAttribute("type", "code") + .setOutputAttrute("type", "description") + .setOutputAttrute("style", settings.style) + .setInputText(IndentedText.fromString(inputString).getTextBlock().trim()) + .buildCompletionRequest(); + } + + @Override + public CharSequence postTransform(AnActionEvent event, CharSequence prompt, CharSequence completion) { + CharSequence indent = UITools.getIndent(event); + String wrapping = StringTools.lineWrapping(completion.toString().trim(), 120); + return "\n" + indent + language.blockComment.fromString(wrapping).withIndent(indent) + "\n" + indent + prompt; + } }); } - public static String getInstruction(String style, String instruction) { - if (style.isEmpty()) return instruction; - return String.format("%s (%s)", instruction, style); + @NotNull + protected AnAction genericInsert() { + return new AnAction("_Insert Text", "Insert Text", null) { + @Override + public void actionPerformed(@NotNull AnActionEvent event) { + Caret caret = event.getData(CommonDataKeys.CARET); + Document document = caret.getEditor().getDocument(); + int caretPosition = caret.getOffset(); + CharSequence before = StringTools.getSuffixForContext(document.getText(new TextRange(0, caretPosition))); + CharSequence after = StringTools.getPrefixForContext(document.getText(new TextRange(caretPosition, document.getTextLength()))); + AppSettingsState settings = AppSettingsState.getInstance(); + CompletionRequest completionRequest = settings.createCompletionRequest() + .appendPrompt(before) + .setSuffix(after); + UITools.redoableRequest(completionRequest, "", event, complete -> { + final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); + return UITools.insertString(editor.getDocument(), caretPosition, complete); + }); + } + }; } @NotNull - public static TextReplacementAction customEdit(String computerLanguage) { + protected AnAction genericAppend() { + return new AnAction("_Append Text", "Append Text", null) { + @Override + public void actionPerformed(@NotNull AnActionEvent event) { + Caret caret = event.getData(CommonDataKeys.CARET); + CharSequence before = caret.getSelectedText(); + AppSettingsState settings = AppSettingsState.getInstance(); + CompletionRequest completionRequest = settings.createCompletionRequest() + .appendPrompt(before); + UITools.redoableRequest(completionRequest, "", event, complete -> { + final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); + return UITools.insertString(editor.getDocument(), caret.getSelectionEnd(), complete); + }); + } + }; + } + + @Nullable + protected AnAction redoLast() { + if(UITools.retry.isEmpty()) return null; + return new AnAction("_Redo Last", "Redo last", null) { + @Override + public void actionPerformed(@NotNull AnActionEvent event) { + UITools.retry.pop().run(); + } + }; + } + + protected AnAction customEdit(String computerLanguage) { return TextReplacementAction.create("_Edit...", "Edit...", null, (event, string) -> { String instruction = JOptionPane.showInputDialog(null, "Instruction:", "Edit Code", JOptionPane.QUESTION_MESSAGE); AppSettingsState settings = AppSettingsState.getInstance(); @@ -204,8 +262,7 @@ public static TextReplacementAction customEdit(String computerLanguage) { .setInputAttribute("type", "before") .setOutputAttrute("type", "after") .setInputText(IndentedText.fromString(string).getTextBlock()) - .buildCompletionRequest() - .complete(getIndent(event.getData(CommonDataKeys.CARET))); + .buildCompletionRequest(); }); } @@ -217,7 +274,7 @@ public static TextReplacementAction customEdit(String computerLanguage) { * @return a new ActionGroup for recent edits */ @NotNull - public static ActionGroup recentEdits(String computerLanguage) { + protected ActionGroup recentEdits(String computerLanguage) { return new ActionGroup("Recent Edits", true) { @Override public AnAction @NotNull [] getChildren(@Nullable AnActionEvent e) { @@ -229,39 +286,36 @@ public static ActionGroup recentEdits(String computerLanguage) { } @NotNull - public static AnAction docAction(String extension, ComputerLanguage language) { + protected AnAction docAction(String extension, ComputerLanguage language) { return new AnAction("_Add " + language.docStyle + " Comments", "Add " + language.docStyle + " Comments", null) { @Override public void actionPerformed(@NotNull final AnActionEvent event) { - try { - Caret caret = event.getData(CommonDataKeys.CARET); - PsiFile psiFile = event.getRequiredData(CommonDataKeys.PSI_FILE); - PsiElement smallestIntersectingMethod = PsiUtil.getSmallestIntersecting(psiFile, caret.getSelectionStart(), caret.getSelectionEnd()); - if (null == smallestIntersectingMethod) return; - AppSettingsState settings = AppSettingsState.getInstance(); - String code = smallestIntersectingMethod.getText(); - IndentedText indentedInput = IndentedText.fromString(code); - String indent = indentedInput.getIndent(); - String rawDocString = settings.createTranslationRequest() - .setInputType(extension) - .setOutputType(extension) - .setInstruction(getInstruction(settings.style, "Rewrite to include detailed " + language.docStyle)) - .setInputAttribute("type", "uncommented") - .setOutputAttrute("type", "commented") - .setOutputAttrute("style", settings.style) - .setInputText(indentedInput.getTextBlock()) - .buildCompletionRequest() - .addStops(language.getMultilineCommentSuffix()) - .complete("").trim(); - final String newText = language.docComment.fromString(rawDocString).withIndent(indent) + "\n" + indent + StringTools.trimPrefix(indentedInput.toString()); - //language.docComment.fromString(rawDocString).withIndent(indent).toString() - WriteCommandAction.runWriteCommandAction(event.getProject(), () -> { - final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); - editor.getDocument().replaceString(smallestIntersectingMethod.getTextRange().getStartOffset(), smallestIntersectingMethod.getTextRange().getEndOffset(), newText); - }); - } catch (ModerationException | IOException ex) { - handle(ex); - } + Caret caret = event.getData(CommonDataKeys.CARET); + PsiFile psiFile = event.getRequiredData(CommonDataKeys.PSI_FILE); + PsiElement smallestIntersectingMethod = PsiUtil.getSmallestIntersecting(psiFile, caret.getSelectionStart(), caret.getSelectionEnd()); + if (null == smallestIntersectingMethod) return; + AppSettingsState settings = AppSettingsState.getInstance(); + String code = smallestIntersectingMethod.getText(); + IndentedText indentedInput = IndentedText.fromString(code); + CharSequence indent = indentedInput.getIndent(); + CompletionRequest completionRequest = settings.createTranslationRequest() + .setInputType(extension) + .setOutputType(extension) + .setInstruction(UITools.getInstruction("Rewrite to include detailed " + language.docStyle)) + .setInputAttribute("type", "uncommented") + .setOutputAttrute("type", "commented") + .setOutputAttrute("style", settings.style) + .setInputText(indentedInput.getTextBlock()) + .buildCompletionRequest() + .addStops(language.getMultilineCommentSuffix()); + int startOffset = smallestIntersectingMethod.getTextRange().getStartOffset(); + int endOffset = smallestIntersectingMethod.getTextRange().getEndOffset(); + UITools.redoableRequest(completionRequest, "", event, (CharSequence docString) -> { + TextBlock reindented = language.docComment.fromString(docString.toString().trim()).withIndent(indent); + final CharSequence newText = reindented + "\n" + indent + StringTools.trimPrefix(indentedInput.toString()); + final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); + return UITools.replaceString(editor.getDocument(), startOffset, endOffset, newText); + }); } }; } @@ -273,7 +327,7 @@ public void actionPerformed(@NotNull final AnActionEvent event) { * @return a {@link TextReplacementAction} that pastes the contents of the clipboard into the given language */ @NotNull - public static TextReplacementAction pasteAction(@NotNull String language) { + protected AnAction pasteAction(@NotNull CharSequence language) { return TextReplacementAction.create("_Paste", "Paste", null, (event, string) -> { String text = CopyPasteManager.getInstance().getContents(DataFlavor.stringFlavor).toString().trim(); return AppSettingsState.getInstance().createTranslationRequest() @@ -283,19 +337,12 @@ public static TextReplacementAction pasteAction(@NotNull String language) { .setInputAttribute("language", "autodetect") .setOutputAttrute("language", language) .setInputText(text) - .buildCompletionRequest() - .complete(getIndent(event.getData(CommonDataKeys.CARET))); + .buildCompletionRequest(); }); } - public static String getIndent(Caret caret) { - if (null == caret) return ""; - Document document = caret.getEditor().getDocument(); - return IndentedText.fromString(document.getText().split("\n")[document.getLineNumber(caret.getSelectionStart())]).getIndent(); - } - @NotNull - public static TextReplacementAction customEdit(String computerLanguage, String instruction) { + protected AnAction customEdit(CharSequence computerLanguage, CharSequence instruction) { return TextReplacementAction.create(instruction, instruction, null, (event, string) -> { AppSettingsState settings = AppSettingsState.getInstance(); settings.addInstructionToHistory(instruction); @@ -306,35 +353,48 @@ public static TextReplacementAction customEdit(String computerLanguage, String i .setInputAttribute("type", "before") .setOutputAttrute("type", "after") .setInputText(IndentedText.fromString(string).getTextBlock()) - .buildCompletionRequest() - .complete(getIndent(event.getData(CommonDataKeys.CARET))); + .buildCompletionRequest(); }); } @Nullable - public static AnAction markdownListAction(@NotNull AnActionEvent e) { + protected AnAction markdownListAction(@NotNull AnActionEvent e) { try { Caret caret = e.getData(CommonDataKeys.CARET); - if(null == caret) return null; + if (null == caret) return null; PsiFile psiFile = e.getData(CommonDataKeys.PSI_FILE); - if(null == psiFile) return null; + if (null == psiFile) return null; PsiElement list = PsiUtil.getSmallestIntersecting(psiFile, caret.getSelectionStart(), caret.getSelectionEnd(), "MarkdownListImpl"); if (null == list) return null; return new AnAction("Add _List Items", "Add list items", null) { @Override public void actionPerformed(@NotNull AnActionEvent event) { AppSettingsState settings = AppSettingsState.getInstance(); - List items = trim(PsiUtil.getAll(list, "MarkdownListItemImpl").stream().map(item -> PsiUtil.getAll(item, "MarkdownParagraphImpl").get(0).getText()).collect(Collectors.toList()), 10, false); - String indent = getIndent(caret); - String n = Integer.toString(items.size() * 2); - List newItems = getNewItems(settings, items, n); - String strippedList = Arrays.stream(list.getText().split("\n")).map(String::trim).filter(x -> !x.isEmpty()).collect(Collectors.joining("\n")); - String bulletString = Stream.of("- [ ] ", "- ", "* ") - .filter(strippedList::startsWith).findFirst().orElse("1. "); - String itemText = indent + newItems.stream().map(x -> bulletString + x).collect(Collectors.joining("\n" + indent)); - WriteCommandAction.runWriteCommandAction(event.getProject(), () -> { + List items = StringTools.trim(PsiUtil.getAll(list, "MarkdownListItemImpl") + .stream().map(item -> PsiUtil.getAll(item, "MarkdownParagraphImpl").get(0).getText()).collect(Collectors.toList()), 10, false); + CharSequence indent = UITools.getIndent(caret); + CharSequence n = Integer.toString(items.size() * 2); + int endOffset = list.getTextRange().getEndOffset(); + String listPrefix = "* "; + CompletionRequest completionRequest = settings.createTranslationRequest() + .setInstruction(UITools.getInstruction("List " + n + " items")) + .setInputType("instruction") + .setInputText("List " + n + " items") + .setOutputType("list") + .setOutputAttrute("style", settings.style) + .buildCompletionRequest() + .appendPrompt(items.stream().map(x2 -> listPrefix + x2).collect(Collectors.joining("\n")) + "\n" + listPrefix); + UITools.redoableRequest(completionRequest, "", event, complete -> { + List newItems = Arrays.stream(complete.toString().split("\n")).map(String::trim) + .filter(x1 -> x1 != null && x1.length() > 0).map(x1 -> StringTools.stripPrefix(x1, listPrefix)).collect(Collectors.toList()); + String strippedList = Arrays.stream(list.getText().split("\n")) + .map(String::trim).filter(x -> x.length() > 0).collect(Collectors.joining("\n")); + String bulletString = Stream.of("- [ ] ", "- ", "* ") + .filter(strippedList::startsWith).findFirst().orElse("1. "); + CharSequence itemText = indent + newItems.stream().map(x -> bulletString + x) + .collect(Collectors.joining("\n" + indent)); final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); - editor.getDocument().insertString(list.getTextRange().getEndOffset(), "\n" + itemText); + return UITools.insertString(editor.getDocument(), endOffset, "\n" + itemText); }); } }; @@ -344,135 +404,93 @@ public void actionPerformed(@NotNull AnActionEvent event) { } } - @NotNull - private static List trim(List items, int max, boolean preserveHead) { - items = new ArrayList<>(items); - Random random = new Random(); - while (items.size() > max) { - int index = random.nextInt(items.size()); - if (preserveHead && index == 0) continue; - items.remove(index); - } - return items; - } - - @NotNull - private static List getNewItems(AppSettingsState settings, List items, String n) { - String listPrefix = "* "; - String complete; - try { - complete = settings.createTranslationRequest() - .setInstruction(getInstruction(settings.style, "List " + n + " items")) - .setInputType("instruction") - .setInputText("List " + n + " items") - .setOutputType("list") - .setOutputAttrute("style", settings.style) - .buildCompletionRequest() - .appendPrompt(items.stream().map(x -> listPrefix + x).collect(Collectors.joining("\n")) + "\n" + listPrefix) - .complete(""); - } catch (IOException ex) { - throw new RuntimeException(ex); - } catch (ModerationException ex) { - throw new RuntimeException(ex); - } - - return Arrays.stream(complete.split("\n")).map(String::trim).filter(x -> x != null && !x.isEmpty()).map(x -> StringTools.stripPrefix(x, listPrefix)).collect(Collectors.toList()); - } - + /** + * This method creates an action to add new columns to a Markdown table. + * + * @param e The action event + * @return An action to add new columns to a Markdown table, or null if the action cannot be created + */ @Nullable - public static AnAction markdownNewTableColsAction(@NotNull AnActionEvent e) { + protected AnAction markdownNewTableColsAction(@NotNull AnActionEvent e) { Caret caret = e.getData(CommonDataKeys.CARET); if (null == caret) return null; PsiFile psiFile = e.getData(CommonDataKeys.PSI_FILE); if (null == psiFile) return null; PsiElement table = PsiUtil.getSmallestIntersecting(psiFile, caret.getSelectionStart(), caret.getSelectionEnd(), "MarkdownTableImpl"); if (null == table) return null; - List rows = Arrays.asList(transposeMarkdownTable(PsiUtil.getAll(table, "MarkdownTableRowImpl").stream().map(PsiElement::getText).collect(Collectors.joining("\n")), false, false).split("\n")); - String n = Integer.toString(rows.size() * 2); + List rows = Arrays.asList(StringTools.transposeMarkdownTable(PsiUtil.getAll(table, "MarkdownTableRowImpl") + .stream().map(PsiElement::getText).collect(Collectors.joining("\n")), false, false).split("\n")); + CharSequence n = Integer.toString(rows.size() * 2); return new AnAction("Add _Table Columns", "Add table columns", null) { @Override public void actionPerformed(@NotNull AnActionEvent event) { + CharSequence originalText = table.getText(); AppSettingsState settings = AppSettingsState.getInstance(); - String indent = getIndent(caret); - List newRows = newRows(settings, n, rows, ""); - String newTableTxt = transposeMarkdownTable(Stream.concat(rows.stream(), newRows.stream()).collect(Collectors.joining("\n")), false, true); - WriteCommandAction.runWriteCommandAction(event.getProject(), () -> { - final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); - editor.getDocument().replaceString(table.getTextRange().getStartOffset(), table.getTextRange().getEndOffset(), newTableTxt.replace("\n", "\n" + indent)); - }); + CharSequence indent = UITools.getIndent(caret); + UITools.redoableRequest(newRowsRequest(settings, n, rows, ""), + "", + event, + (CharSequence complete) -> { + List newRows = Arrays.stream(("" + complete).split("\n")).map(String::trim) + .filter(x -> x.length() > 0).collect(Collectors.toList()); + String newTableTxt = StringTools.transposeMarkdownTable(Stream.concat(rows.stream(), newRows.stream()) + .collect(Collectors.joining("\n")), false, true); + final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); + return UITools.replaceString( + editor.getDocument(), + table.getTextRange().getStartOffset(), + table.getTextRange().getEndOffset(), + newTableTxt.replace("\n", "\n" + indent)); + }); } }; } + /** + * Creates an action to add a new column to a Markdown table. + * + * @param e The action event. + * @return An action to add a new column to a Markdown table, or null if the action cannot be created. + */ @Nullable - public static AnAction markdownNewTableColsAction2(@NotNull AnActionEvent e) { + protected AnAction markdownNewTableColAction(@NotNull AnActionEvent e) { Caret caret = e.getData(CommonDataKeys.CARET); if (null == caret) return null; PsiFile psiFile = e.getData(CommonDataKeys.PSI_FILE); if (null == psiFile) return null; PsiElement table = PsiUtil.getSmallestIntersecting(psiFile, caret.getSelectionStart(), caret.getSelectionEnd(), "MarkdownTableImpl"); if (null == table) return null; - List rows = Arrays.asList(transposeMarkdownTable(PsiUtil.getAll(table, "MarkdownTableRowImpl").stream().map(PsiElement::getText).collect(Collectors.joining("\n")), false, false).split("\n")); - String n = Integer.toString(rows.size() * 2); + List rows = Arrays.asList(StringTools.transposeMarkdownTable(PsiUtil.getAll(table, "MarkdownTableRowImpl") + .stream().map(PsiElement::getText).collect(Collectors.joining("\n")), false, false).split("\n")); + CharSequence n = Integer.toString(rows.size() * 2); return new AnAction("Add Table _Column...", "Add table column...", null) { @Override public void actionPerformed(@NotNull AnActionEvent event) { AppSettingsState settings = AppSettingsState.getInstance(); - String indent = getIndent(caret); - String columnName = JOptionPane.showInputDialog(null, "Column Name:", "Add Column", JOptionPane.QUESTION_MESSAGE); - List newRows = newRows(settings, n, rows, "| " + columnName + " | "); - String newTableTxt = transposeMarkdownTable(Stream.concat(rows.stream(), newRows.stream()).collect(Collectors.joining("\n")), false, true); - WriteCommandAction.runWriteCommandAction(event.getProject(), () -> { - final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); - editor.getDocument().replaceString(table.getTextRange().getStartOffset(), table.getTextRange().getEndOffset(), newTableTxt.replace("\n", "\n" + indent)); - }); + CharSequence indent = UITools.getIndent(caret); + CharSequence columnName = JOptionPane.showInputDialog(null, "Column Name:", "Add Column", JOptionPane.QUESTION_MESSAGE); + UITools.redoableRequest( + newRowsRequest(settings, n, rows, "| " + columnName + " | "), + "", + event, + (CharSequence complete) -> { + List newRows = Arrays.stream(("" + complete).split("\n")) + .map(String::trim).filter(x -> x.length() > 0).collect(Collectors.toList()); + String newTableTxt = StringTools.transposeMarkdownTable(Stream.concat(rows.stream(), + newRows.stream()).collect(Collectors.joining("\n")), false, true); + final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); + return UITools.replaceString( + editor.getDocument(), + table.getTextRange().getStartOffset(), + table.getTextRange().getEndOffset(), + newTableTxt.replace("\n", "\n" + indent)); + }); } }; } - static String transposeMarkdownTable(String table, boolean inputHeader, boolean outputHeader) { - String[][] cells = parseMarkdownTable(table, inputHeader); - StringBuilder transposedTable = new StringBuilder(); - int columns = cells[0].length; - int rows = cells.length; - if (outputHeader) columns = columns + 1; - for (int column = 0; column < columns; column++) { - transposedTable.append("|"); - for (int row = 0; row < rows; row++) { - String cellValue; - String[] rowCells = cells[row]; - if (outputHeader) { - if (column < 1) { - cellValue = rowCells[column].trim(); - } else if (column == 1) { - cellValue = "---"; - } else if ((column - 1) >= rowCells.length) { - cellValue = ""; - } else { - cellValue = rowCells[column - 1].trim(); - } - } else { - cellValue = rowCells[column].trim(); - } - transposedTable.append(" ").append(cellValue).append(" |"); - } - transposedTable.append("\n"); - } - return transposedTable.toString(); - } - - private static String[][] parseMarkdownTable(String table, boolean removeHeader) { - ArrayList rows = new ArrayList(Arrays.stream(table.split("\n")).map(x -> Arrays.stream(x.split("\\|")).filter(cell -> !cell.isEmpty()).toArray(String[]::new)).collect(Collectors.toList())); - if (removeHeader) { - rows.remove(1); - } - return rows.stream() - //.filter(x -> x.length == rows.get(0).length) - .toArray(String[][]::new); - } - @Nullable - public static AnAction markdownNewTableRowsAction(@NotNull AnActionEvent e) { + protected AnAction markdownNewTableRowsAction(@NotNull AnActionEvent e) { Caret caret = e.getData(CommonDataKeys.CARET); if (null == caret) return null; PsiFile psiFile = e.getData(CommonDataKeys.PSI_FILE); @@ -480,19 +498,24 @@ public static AnAction markdownNewTableRowsAction(@NotNull AnActionEvent e) { PsiElement table = PsiUtil.getSmallestIntersecting(psiFile, caret.getSelectionStart(), caret.getSelectionEnd(), "MarkdownTableImpl"); if (null == table) return null; if (null != table) { - List rows = trim(PsiUtil.getAll(table, "MarkdownTableRowImpl").stream().map(PsiElement::getText).collect(Collectors.toList()), 10, true); - String n = Integer.toString(rows.size() * 2); + List rows = StringTools.trim(PsiUtil.getAll(table, "MarkdownTableRowImpl") + .stream().map(PsiElement::getText).collect(Collectors.toList()), 10, true); + CharSequence n = Integer.toString(rows.size() * 2); return new AnAction("Add _Table Rows", "Add table rows", null) { @Override public void actionPerformed(@NotNull AnActionEvent event) { AppSettingsState settings = AppSettingsState.getInstance(); - String indent = getIndent(caret); - List newRows = newRows(settings, n, rows, ""); - String itemText = indent + newRows.stream().collect(Collectors.joining("\n" + indent)); - WriteCommandAction.runWriteCommandAction(event.getProject(), () -> { - final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); - editor.getDocument().insertString(table.getTextRange().getEndOffset(), "\n" + itemText); - }); + CharSequence indent = UITools.getIndent(caret); + UITools.redoableRequest(newRowsRequest(settings, n, rows, ""), + "", + event, + (CharSequence complete) -> { + List newRows = Arrays.stream(("" + complete).split("\n")) + .map(String::trim).filter(x -> x.length() > 0).collect(Collectors.toList()); + CharSequence itemText = indent + newRows.stream().collect(Collectors.joining("\n" + indent)); + final Editor editor = event.getRequiredData(CommonDataKeys.EDITOR); + return UITools.insertString(editor.getDocument(), table.getTextRange().getEndOffset(), "\n" + itemText); + }); } }; } @@ -500,26 +523,26 @@ public void actionPerformed(@NotNull AnActionEvent event) { } @NotNull - private static List newRows(AppSettingsState settings, String n, List rows, String rowPrefix) { - String complete; - try { - complete = settings.createTranslationRequest() - .setInstruction(getInstruction(settings.style, "List " + n + " items")) - .setInputType("instruction") - .setInputText("List " + n + " items") - .setOutputType("markdown") - .setOutputAttrute("style", settings.style) - .buildCompletionRequest() - .appendPrompt("\n" + String.join("\n", rows) + "\n" + rowPrefix) - .complete(""); - } catch (IOException | ModerationException ex) { - throw new RuntimeException(ex); - } - return Arrays.stream((rowPrefix + complete).split("\n")).map(String::trim).filter(x -> !x.isEmpty()).collect(Collectors.toList()); + protected CompletionRequest newRowsRequest(AppSettingsState settings, CharSequence n, List rows, CharSequence rowPrefix) { + return settings.createTranslationRequest() + .setInstruction(UITools.getInstruction("List " + n + " items")) + .setInputType("instruction") + .setInputText("List " + n + " items") + .setOutputType("markdown") + .setOutputAttrute("style", settings.style) + .buildCompletionRequest() + .appendPrompt("\n" + String.join("\n", rows) + "\n" + rowPrefix); } + /** + * Creates a {@link TextReplacementAction} for the given {@link AnActionEvent} and human language. + * + * @param e the action event + * @param humanLanguage the human language + * @return the {@link TextReplacementAction} or {@code null} if no action can be created + */ @Nullable - public static TextReplacementAction markdownContextAction(@NotNull AnActionEvent e, String humanLanguage) { + protected AnAction markdownContextAction(@NotNull AnActionEvent e, CharSequence humanLanguage) { Caret caret = e.getData(CommonDataKeys.CARET); if (null != caret) { int selectionStart = caret.getSelectionStart(); @@ -533,15 +556,13 @@ public static TextReplacementAction markdownContextAction(@NotNull AnActionEvent context = context + "\n"; return settings.createTranslationRequest() .setOutputType("markdown") - .setInstruction(getInstruction(settings.style, String.format("Using Markdown and %s", humanLanguage))) + .setInstruction(UITools.getInstruction(String.format("Using Markdown and %s", humanLanguage))) .setInputType("instruction") .setInputText(humanDescription) .setOutputAttrute("type", "document") .setOutputAttrute("style", settings.style) .buildCompletionRequest() - //.addStops(new String[]{"#"}) - .appendPrompt(context) - .complete(getIndent(caret)); + .appendPrompt(context); }); } } @@ -553,7 +574,7 @@ public static void addIfNotNull(@NotNull ArrayList children, AnAction } @Nullable - public static AnAction printTreeAction(@NotNull AnActionEvent e) { + protected AnAction printTreeAction(@NotNull AnActionEvent e) { Caret caret = e.getData(CommonDataKeys.CARET); if (null == caret) return null; PsiElement psiFile = e.getData(CommonDataKeys.PSI_FILE); @@ -573,7 +594,7 @@ public void actionPerformed(@NotNull final AnActionEvent e1) { } @Nullable - public static AnAction rewordCommentAction(@NotNull AnActionEvent e, ComputerLanguage computerLanguage, String humanLanguage) { + protected AnAction rewordCommentAction(@NotNull AnActionEvent e, ComputerLanguage computerLanguage, String humanLanguage) { PsiFile psiFile = e.getData(CommonDataKeys.PSI_FILE); if (null == psiFile) return null; Caret caret = e.getData(CommonDataKeys.CARET); @@ -587,40 +608,36 @@ public static AnAction rewordCommentAction(@NotNull AnActionEvent e, ComputerLan public void actionPerformed(@NotNull final AnActionEvent e1) { final Editor editor = e1.getRequiredData(CommonDataKeys.EDITOR); AppSettingsState settings = AppSettingsState.getInstance(); - try { - TextBlockFactory commentModel = computerLanguage.getCommentModel(largestIntersectingComment.getText()); - String commentText = commentModel.fromString(largestIntersectingComment.getText().trim()).stream() - .map(String::trim) - .filter(x -> !x.isEmpty()) - .reduce((a, b) -> a + "\n" + b).get(); - String result = settings.createTranslationRequest() - .setInstruction(getInstruction(settings.style, "Reword")) - .setInputText(commentText) - .setInputType(humanLanguage) - .setOutputAttrute("type", "input") - .setOutputType(humanLanguage) - .setOutputAttrute("type", "output") - .setOutputAttrute("style", settings.style) - .buildCompletionRequest() - .complete(""); - String indent = getIndent(caret); - String finalResult = indent + commentModel.fromString(StringTools.lineWrapping(result, 120)).withIndent(indent); - WriteCommandAction.runWriteCommandAction(e1.getProject(), () -> { - editor.getDocument().replaceString( - largestIntersectingComment.getTextRange().getStartOffset(), - largestIntersectingComment.getTextRange().getEndOffset(), - finalResult); - }); - } catch (ModerationException | IOException ex) { - handle(ex); - } + String text = largestIntersectingComment.getText(); + TextBlockFactory commentModel = computerLanguage.getCommentModel(text); + String commentText = commentModel.fromString(text.trim()).stream() + .map(Object::toString) + .map(String::trim) + .filter(x -> !x.isEmpty()) + .reduce((a, b) -> a + "\n" + b).get(); + int startOffset = largestIntersectingComment.getTextRange().getStartOffset(); + int endOffset = largestIntersectingComment.getTextRange().getEndOffset(); + CharSequence indent = UITools.getIndent(caret); + UITools.redoableRequest(settings.createTranslationRequest() + .setInstruction(UITools.getInstruction("Reword")) + .setInputText(commentText) + .setInputType(humanLanguage) + .setOutputAttrute("type", "input") + .setOutputType(humanLanguage) + .setOutputAttrute("type", "output") + .setOutputAttrute("style", settings.style) + .buildCompletionRequest(), "", e1, (CharSequence result) -> { + String lineWrapping = StringTools.lineWrapping(result, 120); + CharSequence finalResult = indent.toString() + commentModel.fromString(lineWrapping).withIndent(indent); + return UITools.replaceString(editor.getDocument(), startOffset, endOffset, finalResult); + }); } }; } @Nullable - public static AnAction psiClassContextAction(@NotNull AnActionEvent e, ComputerLanguage computerLanguage, String humanLanguage) { + protected AnAction psiClassContextAction(@NotNull AnActionEvent e, ComputerLanguage computerLanguage, String humanLanguage) { PsiFile psiFile = e.getData(CommonDataKeys.PSI_FILE); if (null == psiFile) return null; Caret caret = e.getData(CommonDataKeys.CARET); @@ -637,35 +654,32 @@ public void actionPerformed(@NotNull final AnActionEvent e1) { final Caret primaryCaret = caretModel.getPrimaryCaret(); @NotNull String selectedText = primaryCaret.getSelectedText(); AppSettingsState settings = AppSettingsState.getInstance(); - try { - - String instruct = (selectedText.split(" ").length > 4 ? selectedText : largestIntersectingComment.getText()).trim(); - String specification = computerLanguage.getCommentModel(instruct).fromString(instruct).stream() - .map(String::trim) - .filter(x -> !x.isEmpty()) - .reduce((a, b) -> a + " " + b).get(); - String result = settings.createTranslationRequest() - .setInstruction("Implement " + humanLanguage + " as " + computerLanguage.name() + " code") - .setInputType(humanLanguage) - .setInputAttribute("type", "instruction") - .setInputText(specification) - .setOutputType(computerLanguage.name()) - .setOutputAttrute("type", "code") - .setOutputAttrute("style", settings.style) - .buildCompletionRequest() - .appendPrompt(PsiClassContext.getContext(psiFile, selectionStart, selectionEnd) + "\n") - .complete(getIndent(caret)); - WriteCommandAction.runWriteCommandAction(e1.getProject(), () -> { - editor.getDocument().insertString(largestIntersectingComment.getTextRange().getEndOffset(), "\n" + result); - }); - } catch (ModerationException | IOException ex) { - handle(ex); - } - } + String instruct = (selectedText.split(" ").length > 4 ? selectedText : largestIntersectingComment.getText()).trim(); + String specification = computerLanguage.getCommentModel(instruct).fromString(instruct).stream() + .map(Object::toString) + .map(String::trim) + .filter(x -> !x.isEmpty()) + .reduce((a, b) -> a + " " + b).get(); + int endOffset = largestIntersectingComment.getTextRange().getEndOffset(); + UITools.redoableRequest(settings.createTranslationRequest() + .setInstruction("Implement " + humanLanguage + " as " + computerLanguage.name() + " code") + .setInputType(humanLanguage) + .setInputAttribute("type", "instruction") + .setInputText(specification) + .setOutputType(computerLanguage.name()) + .setOutputAttrute("type", "code") + .setOutputAttrute("style", settings.style) + .buildCompletionRequest() + .appendPrompt(PsiClassContext.getContext(psiFile, selectionStart, selectionEnd) + "\n"), + UITools.getIndent(caret), + e1, + (CharSequence result) -> UITools.insertString(editor.getDocument(), endOffset, "\n" + result)); + } }; } + @Override public void update(@NotNull AnActionEvent e) { e.getPresentation().setEnabledAndVisible(true); diff --git a/src/main/java/com/github/simiacryptus/aicoder/TextReplacementAction.java b/src/main/java/com/github/simiacryptus/aicoder/TextReplacementAction.java deleted file mode 100644 index 7dd5fa31..00000000 --- a/src/main/java/com/github/simiacryptus/aicoder/TextReplacementAction.java +++ /dev/null @@ -1,63 +0,0 @@ -package com.github.simiacryptus.aicoder; - -import com.github.simiacryptus.aicoder.openai.ModerationException; -import com.intellij.openapi.actionSystem.AnAction; -import com.intellij.openapi.actionSystem.AnActionEvent; -import com.intellij.openapi.actionSystem.CommonDataKeys; -import com.intellij.openapi.command.WriteCommandAction; -import com.intellij.openapi.editor.Caret; -import com.intellij.openapi.editor.CaretModel; -import com.intellij.openapi.editor.Editor; -import com.intellij.openapi.util.NlsActions; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; - -import javax.swing.*; -import java.io.IOException; - -/** - * TextReplacementAction is an abstract class that extends the AnAction class. - * It provides a static method create() that takes in a text, description, icon, and an ActionTextEditorFunction. - * It also provides an actionPerformed() method that is called when the action is performed. - * This method gets the editor, caret model, and primary caret from the AnActionEvent. - * It then calls the edit() method, which is implemented by the subclasses, and replaces the selected text with the new text. - * The ActionTextEditorFunction is a functional interface that takes in an AnActionEvent and a String and returns a String. - */ -public abstract class TextReplacementAction extends AnAction { - - public TextReplacementAction(@Nullable @NlsActions.ActionText String text, @Nullable @NlsActions.ActionDescription String description, @Nullable Icon icon) { - super(text, description, icon); - } - - public static @NotNull TextReplacementAction create(@Nullable @NlsActions.ActionText String text, @Nullable @NlsActions.ActionDescription String description, @Nullable Icon icon, @NotNull ActionTextEditorFunction fn) { - return new TextReplacementAction(text, description, icon) { - @Override - protected String edit(@NotNull AnActionEvent e, String previousText) throws IOException, ModerationException { - return fn.apply(e, previousText); - } - }; - } - - @Override - public void actionPerformed(@NotNull final AnActionEvent e) { - final Editor editor = e.getRequiredData(CommonDataKeys.EDITOR); - final CaretModel caretModel = editor.getCaretModel(); - final Caret primaryCaret = caretModel.getPrimaryCaret(); - final String newText; - try { - newText = edit(e, primaryCaret.getSelectedText()); - WriteCommandAction.runWriteCommandAction(e.getProject(), () -> { - editor.getDocument().replaceString(primaryCaret.getSelectionStart(), primaryCaret.getSelectionEnd(), newText); - }); - } catch (ModerationException | IOException ex) { - EditorMenu.handle(ex); - } - } - - protected abstract String edit(@NotNull AnActionEvent e, String previousText) throws IOException, ModerationException; - - public interface ActionTextEditorFunction { - String apply(AnActionEvent actionEvent, String input) throws IOException, ModerationException; - } - -} diff --git a/src/main/java/com/github/simiacryptus/aicoder/config/AppSettingsComponent.java b/src/main/java/com/github/simiacryptus/aicoder/config/AppSettingsComponent.java index 5b073817..105405a5 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/config/AppSettingsComponent.java +++ b/src/main/java/com/github/simiacryptus/aicoder/config/AppSettingsComponent.java @@ -1,9 +1,8 @@ package com.github.simiacryptus.aicoder.config; import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.github.simiacryptus.aicoder.StyleUtil; -import com.github.simiacryptus.aicoder.openai.OpenAI; +import com.github.simiacryptus.aicoder.util.StyleUtil; +import com.github.simiacryptus.aicoder.openai.OpenAI_API; import com.intellij.openapi.diagnostic.Logger; import com.intellij.openapi.ui.ComboBox; import com.intellij.ui.components.JBCheckBox; @@ -18,6 +17,7 @@ public class AppSettingsComponent extends SimpleSettingsComponent { private static final Logger log = Logger.getInstance(AppSettingsComponent.class); + @Name("API Base") public final JBTextField apiBase = new JBTextField(); @Name("API Key") @@ -27,19 +27,26 @@ public class AppSettingsComponent extends SimpleSettingsComponent 0) { + try { + ComboBox comboBox = new ComboBox<>(new CharSequence[]{settings.model}); + OpenAI_API.onSuccess(OpenAI_API.INSTANCE.getEngines(), engines -> { + JsonNode data = engines.get("data"); + CharSequence[] items = new CharSequence[data.size()]; + for (int i = 0; i < data.size(); i++) { + items[i] = data.get(i).get("id").asText(); + } + Arrays.sort(items); + Arrays.stream(items).forEach(comboBox::addItem); + }); + return comboBox; + } catch (Throwable e) { + log.warn(e); } - Arrays.sort(items); - return new ComboBox(items); - } catch (Throwable e) { - log.warn(e); - return new JBTextField(); } + return new JBTextField(); } @Name("Style") @@ -66,33 +73,28 @@ public void actionPerformed(ActionEvent e) { StyleUtil.demoStyle(style.getText()); } }); + @Name("Token Counter") + public final JBTextField tokenCounter = new JBTextField(); + public final JButton clearCounter = new JButton(new AbstractAction("Clear Token Counter") { + @Override + public void actionPerformed(ActionEvent e) { + tokenCounter.setText("0"); + } + }); @Name("Developer Tools") public final JBCheckBox devActions = new JBCheckBox(); @Name("API Log Level") - public final ComboBox apiLogLevel = new ComboBox(Arrays.stream(LogLevel.values()).map(x->x.name()).toArray(String[]::new)); + public final ComboBox apiLogLevel = new ComboBox(Arrays.stream(LogLevel.values()).map(x -> x.name()).toArray(CharSequence[]::new)); // @Name("API Envelope") // public final ComboBox translationRequestTemplate = new ComboBox(Arrays.stream(TranslationRequestTemplate.values()).map(x->x.name()).toArray(String[]::new)); - public @NotNull JComponent getPreferredFocusedComponent() { - return apiKey; + public AppSettingsComponent() { + tokenCounter.setEditable(false); } - public static String queryAPIKey() { - JPanel panel = new JPanel(); - JLabel label = new JLabel("Enter OpenAI API Key:"); - JPasswordField pass = new JPasswordField(100); - panel.add(label); - panel.add(pass); - String[] options = new String[]{"OK", "Cancel"}; - int option = JOptionPane.showOptionDialog(null, panel, "API Key", - JOptionPane.NO_OPTION, JOptionPane.PLAIN_MESSAGE, - null, options, options[1]); - if (option == 0) { - char[] password = pass.getPassword(); - return new String(password); - } - return null; + public @NotNull JComponent getPreferredFocusedComponent() { + return apiKey; } } \ No newline at end of file diff --git a/src/main/java/com/github/simiacryptus/aicoder/config/AppSettingsState.java b/src/main/java/com/github/simiacryptus/aicoder/config/AppSettingsState.java index eb93fc92..2df7891b 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/config/AppSettingsState.java +++ b/src/main/java/com/github/simiacryptus/aicoder/config/AppSettingsState.java @@ -1,5 +1,6 @@ package com.github.simiacryptus.aicoder.config; +import com.github.simiacryptus.aicoder.openai.CompletionRequest; import com.github.simiacryptus.aicoder.openai.translate.TranslationRequest; import com.github.simiacryptus.aicoder.openai.translate.TranslationRequestTemplate; import com.intellij.openapi.application.ApplicationManager; @@ -31,6 +32,7 @@ public class AppSettingsState implements PersistentStateComponent mostUsedHistory = new HashMap<>(); private @NotNull List mostRecentHistory = new ArrayList<>(); public int historyLimit = 10; @@ -51,6 +53,16 @@ public TranslationRequest createTranslationRequest() { return translationRequestTemplate.get(this); } + public CompletionRequest createCompletionRequest() { + return new CompletionRequest( + "", + temperature, + maxTokens, + null, + true + ); + } + @Nullable @Override public AppSettingsState getState() { @@ -85,15 +97,15 @@ public int hashCode() { return Objects.hash(apiBase, apiKey, model, maxTokens, temperature, translationRequestTemplate, apiLogLevel, devActions, style); } - public void addInstructionToHistory(String instruction) { + public void addInstructionToHistory(CharSequence instruction) { synchronized (mostRecentHistory) { - mostRecentHistory.add(instruction); + mostRecentHistory.add(instruction.toString()); while(mostRecentHistory.size() > historyLimit) { mostRecentHistory.remove(0); } } synchronized (mostUsedHistory) { - mostUsedHistory.put(instruction, mostUsedHistory.getOrDefault(instruction, 0) + 1); + mostUsedHistory.put(instruction.toString(), mostUsedHistory.getOrDefault(instruction, 0) + 1); } // If the instruction history is bigger than the history limit, @@ -106,10 +118,10 @@ public void addInstructionToHistory(String instruction) { // Then we'll remove all the ones we don't want to keep, // And that's how we'll make sure the instruction history is neat! if (mostUsedHistory.size() > historyLimit) { - List retain = mostUsedHistory.entrySet().stream() + List retain = mostUsedHistory.entrySet().stream() .sorted(Map.Entry.comparingByValue().reversed()) .limit(historyLimit).map(Map.Entry::getKey).collect(Collectors.toList()); - HashSet toRemove = new HashSet<>(mostUsedHistory.keySet()); + HashSet toRemove = new HashSet<>(mostUsedHistory.keySet()); toRemove.removeAll(retain); toRemove.removeAll(mostRecentHistory); toRemove.forEach(mostUsedHistory::remove); diff --git a/src/main/java/com/github/simiacryptus/aicoder/config/SimpleSettingsComponent.java b/src/main/java/com/github/simiacryptus/aicoder/config/SimpleSettingsComponent.java index adc6de1b..fd3f283b 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/config/SimpleSettingsComponent.java +++ b/src/main/java/com/github/simiacryptus/aicoder/config/SimpleSettingsComponent.java @@ -10,6 +10,7 @@ import javax.swing.*; import javax.swing.text.JTextComponent; import java.lang.reflect.Field; +import java.lang.reflect.Modifier; public class SimpleSettingsComponent { private static final Logger log = Logger.getInstance(SimpleSettingsComponent.class); @@ -24,6 +25,7 @@ public class SimpleSettingsComponent { private JPanel buildMainPanel() { FormBuilder formBuilder = FormBuilder.createFormBuilder(); for (Field field : this.getClass().getDeclaredFields()) { + if(Modifier.isStatic(field.getModifiers())) continue; try { field.setAccessible(true); Name nameAnnotation = field.getDeclaredAnnotation(Name.class); @@ -56,7 +58,7 @@ public void getProperties(@NotNull T settings) { if (uiFieldVal instanceof JTextComponent) { newSettingsValue = ((JTextComponent) uiFieldVal).getText(); } else if (uiFieldVal instanceof ComboBox) { - newSettingsValue = ((ComboBox) uiFieldVal).getItem(); + newSettingsValue = ((ComboBox) uiFieldVal).getItem(); } break; case "int": @@ -64,6 +66,11 @@ public void getProperties(@NotNull T settings) { newSettingsValue = Integer.parseInt(((JTextComponent) uiFieldVal).getText()); } break; + case "long": + if (uiFieldVal instanceof JTextComponent) { + newSettingsValue = Long.parseLong(((JTextComponent) uiFieldVal).getText()); + } + break; case "double": if (uiFieldVal instanceof JTextComponent) { newSettingsValue = Double.parseDouble(((JTextComponent) uiFieldVal).getText()); @@ -80,9 +87,9 @@ public void getProperties(@NotNull T settings) { if (java.lang.Enum.class.isAssignableFrom(settingsField.getType())) { if (uiFieldVal instanceof ComboBox) { - ComboBox comboBox = (ComboBox) uiFieldVal; - String item = comboBox.getItem(); - newSettingsValue = Enum.valueOf((Class) settingsField.getType(), item); + ComboBox comboBox = (ComboBox) uiFieldVal; + CharSequence item = comboBox.getItem(); + newSettingsValue = Enum.valueOf((Class) settingsField.getType(), item.toString()); } } break; @@ -108,7 +115,7 @@ public void setProperties(@NotNull T settings) { if (uiVal instanceof JTextComponent) { ((JTextComponent) uiVal).setText((String) settingsVal); } else if (uiVal instanceof ComboBox) { - ((ComboBox) uiVal).setItem(settingsVal.toString()); + ((ComboBox) uiVal).setItem(settingsVal.toString()); } break; case "int": @@ -116,6 +123,11 @@ public void setProperties(@NotNull T settings) { ((JTextComponent) uiVal).setText(Integer.toString((Integer) settingsVal)); } break; + case "long": + if (uiVal instanceof JTextComponent) { + ((JTextComponent) uiVal).setText(Long.toString((Integer) settingsVal)); + } + break; case "double": if (uiVal instanceof JTextComponent) { ((JTextComponent) uiVal).setText(Double.toString(((Double) settingsVal))); @@ -130,7 +142,7 @@ public void setProperties(@NotNull T settings) { break; default: if (uiVal instanceof ComboBox) { - ((ComboBox) uiVal).setItem(settingsVal.toString()); + ((ComboBox) uiVal).setItem(settingsVal.toString()); } break; } diff --git a/src/main/java/com/github/simiacryptus/aicoder/openai/CompletionRequest.java b/src/main/java/com/github/simiacryptus/aicoder/openai/CompletionRequest.java index 497ea17a..4ee31536 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/openai/CompletionRequest.java +++ b/src/main/java/com/github/simiacryptus/aicoder/openai/CompletionRequest.java @@ -1,31 +1,31 @@ package com.github.simiacryptus.aicoder.openai; -import com.github.simiacryptus.aicoder.text.IndentedText; +import com.github.simiacryptus.aicoder.util.IndentedText; +import com.google.common.util.concurrent.ListenableFuture; +import com.intellij.openapi.project.Project; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; -import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Objects; -import static com.github.simiacryptus.aicoder.text.StringTools.stripPrefix; -import static com.github.simiacryptus.aicoder.text.StringTools.stripUnbalancedTerminators; +import static com.github.simiacryptus.aicoder.util.StringTools.stripPrefix; +import static com.github.simiacryptus.aicoder.util.StringTools.stripUnbalancedTerminators; /** * The CompletionRequest class is used to create a request for completion of a given prompt. */ public class CompletionRequest { public String prompt; + public String suffix = null; public double temperature; public int max_tokens; - public String[] stop; + public CharSequence[] stop; public Integer logprobs; public boolean echo; - public CompletionRequest() { - } - - public CompletionRequest(String prompt, double temperature, int max_tokens, Integer logprobs, boolean echo, String... stop) { + public CompletionRequest(String prompt, double temperature, int max_tokens, Integer logprobs, boolean echo, CharSequence... stop) { this.prompt = prompt; this.temperature = temperature; this.max_tokens = max_tokens; @@ -35,10 +35,10 @@ public CompletionRequest(String prompt, double temperature, int max_tokens, Inte } @NotNull - public String complete(String indent) throws IOException, ModerationException { - CompletionResponse response = OpenAI.INSTANCE.complete(this); - return response + public ListenableFuture complete(@Nullable Project project, CharSequence indent) { + return OpenAI_API.map(OpenAI_API.INSTANCE.complete(project, this), response -> response .getFirstChoice() + .map(Objects::toString) .map(String::trim) .map(completion -> stripPrefix(completion, this.prompt.trim())) .map(String::trim) @@ -47,27 +47,32 @@ public String complete(String indent) throws IOException, ModerationException { .map(indentedText -> indentedText.withIndent(indent)) .map(IndentedText::toString) .map(indentedText -> indent + indentedText) - .orElse(""); + .orElse("")); } - public @NotNull CompletionRequest appendPrompt(String prompt) { + public @NotNull CompletionRequest appendPrompt(CharSequence prompt) { this.prompt = this.prompt + prompt; return this; } - public @NotNull CompletionRequest addStops(@NotNull String... newStops) { - ArrayList stops = new ArrayList<>(); - for (String x : newStops) { + public @NotNull CompletionRequest addStops(@NotNull CharSequence... newStops) { + ArrayList stops = new ArrayList<>(); + for (CharSequence x : newStops) { if (x != null) { - if (!x.isEmpty()) { + if (x.length() > 0) { stops.add(x); } } } if (!stops.isEmpty()) { - Arrays.stream(this.stop).forEach(stops::add); - this.stop = stops.stream().distinct().toArray(String[]::new); + if(null != this.stop) Arrays.stream(this.stop).forEach(stops::add); + this.stop = stops.stream().distinct().toArray(CharSequence[]::new); } return this; } + + public CompletionRequest setSuffix(CharSequence suffix) { + this.suffix = suffix.toString(); + return this; + } } diff --git a/src/main/java/com/github/simiacryptus/aicoder/openai/CompletionResponse.java b/src/main/java/com/github/simiacryptus/aicoder/openai/CompletionResponse.java index 7b7f1fa4..449ebfee 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/openai/CompletionResponse.java +++ b/src/main/java/com/github/simiacryptus/aicoder/openai/CompletionResponse.java @@ -27,7 +27,7 @@ public CompletionResponse(String id, String object, int created, String model, C this.error = error; } - public @NotNull Optional getFirstChoice() { + public @NotNull Optional getFirstChoice() { return Optional.ofNullable(this.choices).flatMap(choices -> Arrays.stream(choices).findFirst()).map(choice -> choice.text.trim()); } } diff --git a/src/main/java/com/github/simiacryptus/aicoder/openai/LogProbs.java b/src/main/java/com/github/simiacryptus/aicoder/openai/LogProbs.java index 30efa968..4be2e717 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/openai/LogProbs.java +++ b/src/main/java/com/github/simiacryptus/aicoder/openai/LogProbs.java @@ -3,7 +3,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode; public class LogProbs { - public String[] tokens; + public CharSequence[] tokens; public double[] token_logprobs; public ObjectNode[] top_logprobs; public int[] text_offset; @@ -11,7 +11,7 @@ public class LogProbs { public LogProbs() { } - public LogProbs(String[] tokens, double[] token_logprobs, ObjectNode[] top_logprobs, int[] text_offset) { + public LogProbs(CharSequence[] tokens, double[] token_logprobs, ObjectNode[] top_logprobs, int[] text_offset) { this.tokens = tokens; this.token_logprobs = token_logprobs; this.top_logprobs = top_logprobs; diff --git a/src/main/java/com/github/simiacryptus/aicoder/openai/OpenAI.java b/src/main/java/com/github/simiacryptus/aicoder/openai/OpenAI.java deleted file mode 100644 index c6506082..00000000 --- a/src/main/java/com/github/simiacryptus/aicoder/openai/OpenAI.java +++ /dev/null @@ -1,185 +0,0 @@ -package com.github.simiacryptus.aicoder.openai; - -import com.fasterxml.jackson.databind.MapperFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.SerializationFeature; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.github.simiacryptus.aicoder.config.AppSettingsComponent; -import com.github.simiacryptus.aicoder.config.AppSettingsState; -import com.google.gson.Gson; -import com.google.gson.JsonObject; -import com.intellij.openapi.diagnostic.Logger; -import com.jetbrains.rd.util.LogLevel; -import org.apache.http.HttpEntity; -import org.apache.http.HttpResponse; -import org.apache.http.client.methods.HttpGet; -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.methods.HttpRequestBase; -import org.apache.http.entity.StringEntity; -import org.apache.http.impl.client.HttpClientBuilder; -import org.apache.http.util.EntityUtils; -import org.jetbrains.annotations.NotNull; - -import java.io.IOException; -import java.util.Map; - -import static com.github.simiacryptus.aicoder.text.StringTools.stripPrefix; - -public class OpenAI { - - private static final Logger log = Logger.getInstance(OpenAI.class); - - public static final OpenAI INSTANCE = new OpenAI(); - private transient AppSettingsState settings = null; - - protected AppSettingsState getSettingsState() { - if (null == this.settings) { - this.settings = AppSettingsState.getInstance(); - } - return settings; - } - - public ObjectNode getEngines() throws IOException { - return getMapper().readValue(get(getSettingsState().apiBase + "/engines"), ObjectNode.class); - } - - protected String post(String url, @NotNull String body) throws IOException, InterruptedException { - return post(url, body, 3); - } - - public CompletionResponse complete(@NotNull CompletionRequest completionRequest) throws IOException, ModerationException { - try { - AppSettingsState settings = getSettingsState(); - if (completionRequest.prompt.length() > settings.maxPrompt) - throw new IOException("Prompt too long:" + completionRequest.prompt.length() + " chars"); - moderate(completionRequest.prompt); - String request = getMapper().writeValueAsString(completionRequest); - String result = post(settings.apiBase + "/engines/" + settings.model + "/completions", request); - JsonObject jsonObject = new Gson().fromJson(result, JsonObject.class); - if (jsonObject.has("error")) { - JsonObject errorObject = jsonObject.getAsJsonObject("error"); - String errorMessage = errorObject.get("message").getAsString(); - log.error(errorMessage); - throw new IOException(errorMessage); - } - CompletionResponse completionResponse = getMapper().readValue(result, CompletionResponse.class); - String completionResult = stripPrefix(completionResponse.getFirstChoice().orElse("").trim(), completionRequest.prompt.trim()); - log(settings.apiLogLevel, String.format("Text Completion Request\nPrefix:\n\t%s\n\nCompletion:\n\t%s", completionRequest.prompt.replace("\n", "\n\t"), completionResult.replace("\n", "\n\t"))); - return completionResponse; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } - } - - private void log(LogLevel level, String msg) { - String message = msg.trim().replace("\n", "\n\t"); - switch (level) { - case Error: - log.error(message); - break; - case Warn: - log.warn(message); - break; - case Info: - log.info(message); - break; - default: - log.debug(message); - break; - } - } - - public void moderate(@NotNull String text) throws IOException, InterruptedException, ModerationException { - String body = getMapper().writeValueAsString(Map.of("input", text)); - AppSettingsState settings = getSettingsState(); - String result = post(settings.apiBase + "/moderations", body); - JsonObject jsonObject = new Gson().fromJson(result, JsonObject.class); - if (jsonObject.has("error")) { - JsonObject errorObject = jsonObject.getAsJsonObject("error"); - throw new IOException(errorObject.get("message").getAsString()); - } - JsonObject moderationResult = jsonObject.getAsJsonArray("results").get(0).getAsJsonObject(); - log(LogLevel.Debug, String.format("Moderation Request\nText:\n%s\n\nResult:\n%s", text.replace("\n", "\n\t"), result)); - if (moderationResult.get("flagged").getAsBoolean()) { - JsonObject categoriesObj = moderationResult.get("categories").getAsJsonObject(); - throw new ModerationException("Moderation flagged this request due to " + categoriesObj.keySet().stream().filter(c -> categoriesObj.get(c).getAsBoolean()).reduce((a, b) -> a + ", " + b).orElse("???")); - } - } - - /** - * Posts a request to the given URL with the given JSON body and retries if an IOException is thrown. - * - * @param url The URL to post the request to. - * @param json The JSON body of the request. - * @param retries The number of times to retry the request if an IOException is thrown. - * @return The response from the request. - * @throws IOException If an IOException is thrown and the number of retries is exceeded. - * @throws InterruptedException If the thread is interrupted while sleeping. - */ - protected String post(String url, @NotNull String json, int retries) throws IOException, InterruptedException { - try { - HttpClientBuilder client = HttpClientBuilder.create(); - HttpPost request = new HttpPost(url); - request.addHeader("Content-Type", "application/json"); - request.addHeader("Accept", "application/json"); - authorize(request); - request.setEntity(new StringEntity(json)); - HttpResponse response = client.build().execute(request); - HttpEntity entity = response.getEntity(); - return EntityUtils.toString(entity); - } catch (IOException e) { - if (retries > 0) { - e.printStackTrace(); - Thread.sleep(15000); - return post(url, json, retries - 1); - } - throw e; - } - } - - protected @NotNull ObjectMapper getMapper() { - ObjectMapper mapper = new ObjectMapper(); - mapper - .enable(SerializationFeature.INDENT_OUTPUT) - .enable(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS) - .enable(MapperFeature.USE_STD_BEAN_NAMING) - //.registerModule(DefaultScalaModule) - .activateDefaultTyping(mapper.getPolymorphicTypeValidator()); - return mapper; - } - - protected void authorize(@NotNull HttpRequestBase request) throws IOException { - AppSettingsState settingsState = getSettingsState(); - String apiKey = settingsState.apiKey; - if (apiKey == null || apiKey.isEmpty()) { - synchronized (settingsState) { - apiKey = settingsState.apiKey; - if (apiKey == null || apiKey.isEmpty()) { - apiKey = AppSettingsComponent.queryAPIKey(); - settingsState.apiKey = apiKey; - } - } - } - request.addHeader("Authorization", "Bearer " + apiKey); - } - - /** - * Gets the response from the given URL. - * - * @param url The URL to GET the response from. - * @return The response from the given URL. - * @throws IOException If an I/O error occurs. - */ - public String get(String url) throws IOException { - HttpClientBuilder client = HttpClientBuilder.create(); - HttpGet request = new HttpGet(url); - request.addHeader("Content-Type", "application/json"); - request.addHeader("Accept", "application/json"); - authorize(request); - HttpResponse response = client.build().execute(request); - HttpEntity entity = response.getEntity(); - return EntityUtils.toString(entity); - } - -} \ No newline at end of file diff --git a/src/main/java/com/github/simiacryptus/aicoder/openai/OpenAI_API.java b/src/main/java/com/github/simiacryptus/aicoder/openai/OpenAI_API.java new file mode 100644 index 00000000..ea82e63f --- /dev/null +++ b/src/main/java/com/github/simiacryptus/aicoder/openai/OpenAI_API.java @@ -0,0 +1,290 @@ +package com.github.simiacryptus.aicoder.openai; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.simiacryptus.aicoder.config.AppSettingsState; +import com.github.simiacryptus.aicoder.util.UITools; +import com.google.common.base.Function; +import com.google.common.util.concurrent.*; +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import com.intellij.openapi.diagnostic.Logger; +import com.intellij.openapi.progress.ProgressIndicator; +import com.intellij.openapi.progress.ProgressManager; +import com.intellij.openapi.progress.Task; +import com.intellij.openapi.progress.util.AbstractProgressIndicatorBase; +import com.intellij.openapi.project.Project; +import com.jetbrains.rd.util.LogLevel; +import org.apache.http.HttpEntity; +import org.apache.http.HttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.util.EntityUtils; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.function.Consumer; + +import static com.github.simiacryptus.aicoder.util.StringTools.stripPrefix; + +public final class OpenAI_API { + + private static final Logger log = Logger.getInstance(OpenAI_API.class); + + public static final OpenAI_API INSTANCE = new OpenAI_API(); + public final ListeningExecutorService pool; + private transient AppSettingsState settings = null; + + protected AppSettingsState getSettingsState() { + if (null == this.settings) { + this.settings = AppSettingsState.getInstance(); + } + return settings; + } + + protected OpenAI_API() { + this.pool = MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor()); + } + + @NotNull + public ListenableFuture getEngines() { + return pool.submit(() -> getMapper().readValue(get(getSettingsState().apiBase + "/engines"), ObjectNode.class)); + } + + protected String post(String url, @NotNull String body) throws IOException, InterruptedException { + return post(url, body, 3); + } + + public ListenableFuture complete(@Nullable Project project, @NotNull CompletionRequest completionRequest) { + AppSettingsState settings = getSettingsState(); + if (null != completionRequest.suffix) { + if (completionRequest.suffix.trim().length() == 0) { + completionRequest.setSuffix(null); + } else { + completionRequest.echo = false; + } + } + if (null != completionRequest.stop && completionRequest.stop.length == 0) { + completionRequest.stop = null; + } + if (completionRequest.prompt.length() > settings.maxPrompt) + throw new IllegalArgumentException("Prompt too long:" + completionRequest.prompt.length() + " chars"); + return OpenAI_API.map(moderateAsync(project, completionRequest.prompt), x -> { + try { + Task.WithResult task = new Task.WithResult<>(project, "OpenAI Text Completion", false) { + @Override + protected CompletionResponse compute(@NotNull ProgressIndicator indicator) throws Exception { + try { + String request = getMapper().writeValueAsString(completionRequest); + String result = post(settings.apiBase + "/engines/" + settings.model + "/completions", request); + JsonObject jsonObject = new Gson().fromJson(result, JsonObject.class); + if (jsonObject.has("error")) { + JsonObject errorObject = jsonObject.getAsJsonObject("error"); + String errorMessage = errorObject.get("message").getAsString(); + log.error(errorMessage); + throw new IOException(errorMessage); + } + CompletionResponse completionResponse = getMapper().readValue(result, CompletionResponse.class); + if (completionResponse.usage != null) { + settings.tokenCounter += completionResponse.usage.total_tokens; + } + String completionResult = stripPrefix(completionResponse.getFirstChoice().orElse("").toString().trim(), completionRequest.prompt.trim()); + if (completionRequest.suffix == null) { + log(settings.apiLogLevel, String.format("Text Completion Request\nPrefix:\n\t%s\nCompletion:\n\t%s", + completionRequest.prompt.replace("\n", "\n\t"), + completionResult.replace("\n", "\n\t"))); + } else { + log(settings.apiLogLevel, String.format("Text Completion Request\nPrefix:\n\t%s\nSuffix:\n\t%s\nCompletion:\n\t%s", + completionRequest.prompt.replace("\n", "\n\t"), + completionRequest.suffix.replace("\n", "\n\t"), + completionResult.replace("\n", "\n\t"))); + } + return completionResponse; + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + } + }; + if (null != project) { + return ProgressManager.getInstance().run(task); + } else { + task.run(new AbstractProgressIndicatorBase()); + return task.getResult(); + } + } catch (RuntimeException e) { + throw e; + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + public static + ListenableFuture map(ListenableFuture moderateAsync, Function o) { + return Futures.transform(moderateAsync, o, INSTANCE.pool); + } + + public static void onSuccess(ListenableFuture moderateAsync, Consumer o) { + Futures.addCallback(moderateAsync, new FutureCallback() { + @Override + public void onSuccess(I result) { + o.accept(result); + } + + @Override + public void onFailure(Throwable t) { + UITools.handle(t); + } + }, INSTANCE.pool); + } + + private void log(LogLevel level, String msg) { + String message = msg.trim().replace("\n", "\n\t"); + switch (level) { + case Error: + log.error(message); + break; + case Warn: + log.warn(message); + break; + case Info: + log.info(message); + break; + default: + log.debug(message); + break; + } + } + + @NotNull + private ListenableFuture moderateAsync(@Nullable Project project, @NotNull String text) { + Task.WithResult, Exception> task = new Task.WithResult, Exception>(project, "OpenAI Moderation", false) { + @Override + protected ListenableFuture compute(@NotNull ProgressIndicator indicator) throws Exception { + return pool.submit(() -> { + String body = null; + try { + body = getMapper().writeValueAsString(Map.of("input", text)); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + AppSettingsState settings1 = getSettingsState(); + String result = null; + try { + result = post(settings1.apiBase + "/moderations", body); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + JsonObject jsonObject = new Gson().fromJson(result, JsonObject.class); + if (jsonObject.has("error")) { + JsonObject errorObject = jsonObject.getAsJsonObject("error"); + throw new RuntimeException(new IOException(errorObject.get("message").getAsString())); + } + JsonObject moderationResult = jsonObject.getAsJsonArray("results").get(0).getAsJsonObject(); + log(LogLevel.Debug, String.format("Moderation Request\nText:\n%s\n\nResult:\n%s", text.replace("\n", "\n\t"), result)); + if (moderationResult.get("flagged").getAsBoolean()) { + JsonObject categoriesObj = moderationResult.get("categories").getAsJsonObject(); + throw new RuntimeException(new ModerationException("Moderation flagged this request due to " + categoriesObj.keySet().stream().filter(c -> categoriesObj.get(c).getAsBoolean()).reduce((a, b) -> a + ", " + b).orElse("???"))); + } + }); + } + }; + try { + if (null != project) { + return ProgressManager.getInstance().run(task); + } else { + task.run(new AbstractProgressIndicatorBase()); + return task.getResult(); + } + } catch (RuntimeException e) { + throw e; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Posts a request to the given URL with the given JSON body and retries if an IOException is thrown. + * + * @param url The URL to post the request to. + * @param json The JSON body of the request. + * @param retries The number of times to retry the request if an IOException is thrown. + * @return The response from the request. + * @throws IOException If an IOException is thrown and the number of retries is exceeded. + * @throws InterruptedException If the thread is interrupted while sleeping. + */ + protected String post(String url, @NotNull String json, int retries) throws IOException, InterruptedException { + try { + HttpClientBuilder client = HttpClientBuilder.create(); + HttpPost request = new HttpPost(url); + request.addHeader("Content-Type", "application/json"); + request.addHeader("Accept", "application/json"); + authorize(request); + request.setEntity(new StringEntity(json)); + HttpResponse response = client.build().execute(request); + HttpEntity entity = response.getEntity(); + return EntityUtils.toString(entity); + } catch (IOException e) { + if (retries > 0) { + e.printStackTrace(); + Thread.sleep(15000); + return post(url, json, retries - 1); + } + throw e; + } + } + + protected @NotNull ObjectMapper getMapper() { + ObjectMapper mapper = new ObjectMapper(); + mapper + .enable(SerializationFeature.INDENT_OUTPUT) + .enable(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS) + .enable(MapperFeature.USE_STD_BEAN_NAMING) + //.registerModule(DefaultScalaModule) + .activateDefaultTyping(mapper.getPolymorphicTypeValidator()); + return mapper; + } + + protected void authorize(@NotNull HttpRequestBase request) throws IOException { + AppSettingsState settingsState = getSettingsState(); + String apiKey = settingsState.apiKey; + if (apiKey == null || apiKey.length() == 0) { + synchronized (settingsState) { + apiKey = settingsState.apiKey; + if (apiKey == null || apiKey.length() == 0) { + apiKey = UITools.queryAPIKey(); + settingsState.apiKey = apiKey; + } + } + } + request.addHeader("Authorization", "Bearer " + apiKey); + } + + /** + * Gets the response from the given URL. + * + * @param url The URL to GET the response from. + * @return The response from the given URL. + * @throws IOException If an I/O error occurs. + */ + public String get(String url) throws IOException { + HttpClientBuilder client = HttpClientBuilder.create(); + HttpGet request = new HttpGet(url); + request.addHeader("Content-Type", "application/json"); + request.addHeader("Accept", "application/json"); + authorize(request); + HttpResponse response = client.build().execute(request); + HttpEntity entity = response.getEntity(); + return EntityUtils.toString(entity); + } + +} \ No newline at end of file diff --git a/src/main/java/com/github/simiacryptus/aicoder/openai/translate/BaseTranslationRequest.java b/src/main/java/com/github/simiacryptus/aicoder/openai/translate/BaseTranslationRequest.java index 9c853ac0..ae52505d 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/openai/translate/BaseTranslationRequest.java +++ b/src/main/java/com/github/simiacryptus/aicoder/openai/translate/BaseTranslationRequest.java @@ -8,47 +8,47 @@ public abstract class BaseTranslationRequest> implements TranslationRequest { - private String inputTag; - private String outputTag; - private String instruction; + private CharSequence inputTag; + private CharSequence outputTag; + private CharSequence instruction; @NotNull - private Map inputAttr = new HashMap<>(); + private Map inputAttr = new HashMap<>(); @NotNull - private Map outputAttr = new HashMap<>(); - private String originalText; + private Map outputAttr = new HashMap<>(); + private CharSequence originalText; private double temperature; private int maxTokens; @Override public String getInputTag() { - return inputTag; + return inputTag.toString(); } @Override public String getOutputTag() { - return outputTag; + return outputTag.toString(); } @Override - public String getInstruction() { - return instruction; + public CharSequence getInstruction() { + return instruction.toString(); } @Override @NotNull - public Map getInputAttr() { + public Map getInputAttr() { return Collections.unmodifiableMap(inputAttr); } @Override @NotNull - public Map getOutputAttr() { + public Map getOutputAttr() { return Collections.unmodifiableMap(outputAttr); } @Override public String getOriginalText() { - return originalText; + return originalText.toString(); } @Override @@ -62,25 +62,25 @@ public int getMaxTokens() { } @Override - public T setInputType(String inputTag) { + public T setInputType(CharSequence inputTag) { this.inputTag = inputTag; return (T) this; } @Override - public T setOutputType(String outputTag) { + public T setOutputType(CharSequence outputTag) { this.outputTag = outputTag; return (T) this; } @Override - public T setInstruction(String instruction) { + public T setInstruction(CharSequence instruction) { this.instruction = instruction; return (T) this; } @Override - public T setInputText(String originalText) { + public T setInputText(CharSequence originalText) { this.originalText = originalText; return (T) this; } @@ -98,8 +98,8 @@ public T setMaxTokens(int maxTokens) { } @Override - public T setInputAttribute(String key, String value) { - if(null == value || value.isEmpty()) { + public T setInputAttribute(CharSequence key, CharSequence value) { + if(null == value || value.length()==0) { inputAttr.remove(key); } else { inputAttr.put(key, value); @@ -108,8 +108,8 @@ public T setInputAttribute(String key, String value) { } @Override - public T setOutputAttrute(String key, String value) { - if(null == value || value.isEmpty()) { + public T setOutputAttrute(CharSequence key, CharSequence value) { + if(null == value || value.length()==0) { outputAttr.remove(key); } else { outputAttr.put(key, value); diff --git a/src/main/java/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest.java b/src/main/java/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest.java index 7899f4b6..c25dc7cc 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest.java +++ b/src/main/java/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest.java @@ -12,11 +12,11 @@ public interface TranslationRequest { String getOutputTag(); - String getInstruction(); + CharSequence getInstruction(); - @NotNull Map getInputAttr(); + @NotNull Map getInputAttr(); - @NotNull Map getOutputAttr(); + @NotNull Map getOutputAttr(); String getOriginalText(); @@ -24,16 +24,16 @@ public interface TranslationRequest { int getMaxTokens(); - TranslationRequest setInputType(String inputTag); + TranslationRequest setInputType(CharSequence inputTag); - TranslationRequest setOutputType(String outputTag); + TranslationRequest setOutputType(CharSequence outputTag); - TranslationRequest setInstruction(String instruction); + TranslationRequest setInstruction(CharSequence instruction); - TranslationRequest setInputAttribute(String key, String value); - TranslationRequest setOutputAttrute(String key, String value); + TranslationRequest setInputAttribute(CharSequence key, CharSequence value); + TranslationRequest setOutputAttrute(CharSequence key, CharSequence value); - TranslationRequest setInputText(String originalText); + TranslationRequest setInputText(CharSequence originalText); TranslationRequest setTemperature(double temperature); diff --git a/src/main/java/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest_XML.java b/src/main/java/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest_XML.java index 98b25593..58334ecc 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest_XML.java +++ b/src/main/java/com/github/simiacryptus/aicoder/openai/translate/TranslationRequest_XML.java @@ -16,8 +16,8 @@ public TranslationRequest_XML(AppSettingsState settings) { @Override @NotNull public CompletionRequest buildCompletionRequest() { - String inputAttrStr = getInputAttr().isEmpty() ? "" : (" " + getInputAttr().entrySet().stream().map(t -> String.format("%s=\"%s\"", t.getKey(), t.getValue())).collect(Collectors.joining(" "))); - String outputAttrStr = getOutputAttr().isEmpty() ? "" : (" " + getOutputAttr().entrySet().stream().map(t1 -> String.format("%s=\"%s\"", t1.getKey(), t1.getValue())).collect(Collectors.joining(" "))); + CharSequence inputAttrStr = getInputAttr().isEmpty() ? "" : (" " + getInputAttr().entrySet().stream().map(t -> String.format("%s=\"%s\"", t.getKey(), t.getValue())).collect(Collectors.joining(" "))); + CharSequence outputAttrStr = getOutputAttr().isEmpty() ? "" : (" " + getOutputAttr().entrySet().stream().map(t1 -> String.format("%s=\"%s\"", t1.getKey(), t1.getValue())).collect(Collectors.joining(" "))); String text = String.format("\n<%s%s>%s\n<%s%s>", getInstruction(), getInputTag().toLowerCase(), diff --git a/src/main/java/com/github/simiacryptus/aicoder/psi/PsiClassContext.java b/src/main/java/com/github/simiacryptus/aicoder/psi/PsiClassContext.java index 26d9f2ad..049b81bc 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/psi/PsiClassContext.java +++ b/src/main/java/com/github/simiacryptus/aicoder/psi/PsiClassContext.java @@ -1,6 +1,6 @@ package com.github.simiacryptus.aicoder.psi; -import com.github.simiacryptus.aicoder.text.StringTools; +import com.github.simiacryptus.aicoder.util.StringTools; import com.intellij.openapi.util.TextRange; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiElementVisitor; @@ -57,7 +57,7 @@ public void visitElement(@NotNull PsiElement element) { (textRangeStartOffset <= selectionStart && textRangeEndOffset >= selectionStart) || (textRangeStartOffset <= selectionEnd && textRangeEndOffset >= selectionEnd); // Check if the element is within the selection boolean within = (textRangeStartOffset <= selectionStart && textRangeEndOffset > selectionStart) && (textRangeStartOffset <= selectionEnd && textRangeEndOffset > selectionEnd); - String simpleName = element.getClass().getSimpleName(); + CharSequence simpleName = element.getClass().getSimpleName(); if (simpleName.equals("PsiImportListImpl")) { currentContext.children.add(new PsiClassContext(text.trim(), isPrior, isOverlap)); } else if (simpleName.equals("PsiCommentImpl") || simpleName.equals("PsiDocCommentImpl")) { diff --git a/src/main/java/com/github/simiacryptus/aicoder/psi/PsiMarkdownContext.java b/src/main/java/com/github/simiacryptus/aicoder/psi/PsiMarkdownContext.java index adedbae2..3f111fbe 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/psi/PsiMarkdownContext.java +++ b/src/main/java/com/github/simiacryptus/aicoder/psi/PsiMarkdownContext.java @@ -37,7 +37,7 @@ public int headerLevel() { public @NotNull PsiMarkdownContext init(@NotNull PsiFile psiFile, int selectionStart, int selectionEnd) { AtomicReference visitor = new AtomicReference<>(); visitor.set(new PsiElementVisitor() { - final String indent = ""; + final CharSequence indent = ""; @NotNull PsiMarkdownContext section = PsiMarkdownContext.this; @Override @@ -54,7 +54,7 @@ public void visitElement(@NotNull PsiElement element) { // Check if the element is within the selection boolean within = (textRangeStartOffset <= selectionStart && textRangeEndOffset > selectionStart) && (textRangeStartOffset <= selectionEnd && textRangeEndOffset > selectionEnd); if (!isPrior && !isOverlap) return; - String simpleName = element.getClass().getSimpleName(); + CharSequence simpleName = element.getClass().getSimpleName(); if (simpleName.equals("MarkdownHeaderImpl")) { PsiMarkdownContext content = new PsiMarkdownContext(section, text.trim(), element.getTextOffset()); while (content.headerLevel() <= section.headerLevel() && section.parent != null) { diff --git a/src/main/java/com/github/simiacryptus/aicoder/psi/PsiUtil.java b/src/main/java/com/github/simiacryptus/aicoder/psi/PsiUtil.java index f4d4d3f2..423f9981 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/psi/PsiUtil.java +++ b/src/main/java/com/github/simiacryptus/aicoder/psi/PsiUtil.java @@ -1,6 +1,6 @@ package com.github.simiacryptus.aicoder.psi; -import com.github.simiacryptus.aicoder.text.StringTools; +import com.github.simiacryptus.aicoder.util.StringTools; import com.intellij.openapi.util.TextRange; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiElementVisitor; @@ -28,7 +28,7 @@ public static PsiElement getLargestIntersectingComment(@NotNull PsiElement eleme * @param types The types of elements to search for. * @return The largest element that intersects with the given selection range. */ - public static PsiElement getLargestIntersecting(@NotNull PsiElement element, int selectionStart, int selectionEnd, String... types) { + public static PsiElement getLargestIntersecting(@NotNull PsiElement element, int selectionStart, int selectionEnd, CharSequence... types) { final AtomicReference largest = new AtomicReference<>(null); final AtomicReference visitor = new AtomicReference<>(); visitor.set(new PsiElementVisitor() { @@ -37,7 +37,7 @@ public void visitElement(@NotNull PsiElement element) { if (null == element) return; TextRange textRange = element.getTextRange(); boolean within = (textRange.getStartOffset() <= selectionStart && textRange.getEndOffset() + 1 >= selectionStart && textRange.getStartOffset() <= selectionEnd && textRange.getEndOffset() + 1 >= selectionEnd); - String simpleName = element.getClass().getSimpleName(); + CharSequence simpleName = element.getClass().getSimpleName(); if (Arrays.asList(expand(types)).contains(simpleName)) { if (within) { largest.updateAndGet(s -> (s == null ? 0 : s.getText().length()) > element.getText().length() ? s : element); @@ -50,7 +50,7 @@ public void visitElement(@NotNull PsiElement element) { element.accept(visitor.get()); return largest.get(); } - public static List getAll(@NotNull PsiElement element, String... types) { + public static List getAll(@NotNull PsiElement element, CharSequence... types) { final List elements = new ArrayList<>(); final AtomicReference visitor = new AtomicReference<>(); visitor.set(new PsiElementVisitor() { @@ -90,7 +90,7 @@ public static PsiElement getSmallestIntersecting(@NotNull PsiElement element, in * @param types The types of the elements to be retrieved. * @return The smallest intersecting entity from the given PsiElement. */ - public static PsiElement getSmallestIntersecting(@NotNull PsiElement element, int selectionStart, int selectionEnd, String... types) { + public static PsiElement getSmallestIntersecting(@NotNull PsiElement element, int selectionStart, int selectionEnd, CharSequence... types) { final AtomicReference largest = new AtomicReference<>(null); final AtomicReference visitor = new AtomicReference<>(); visitor.set(new PsiElementVisitor() { @@ -99,7 +99,7 @@ public void visitElement(@NotNull PsiElement element) { if (null == element) return; TextRange textRange = element.getTextRange(); boolean within = (textRange.getStartOffset() <= selectionStart && textRange.getEndOffset() + 1 >= selectionStart && textRange.getStartOffset() <= selectionEnd && textRange.getEndOffset() + 1 >= selectionEnd); - String simpleName = element.getClass().getSimpleName(); + CharSequence simpleName = element.getClass().getSimpleName(); if (Arrays.asList(expand(types)).contains(simpleName)) { if (within) { largest.updateAndGet(s -> (s == null ? Integer.MAX_VALUE : s.getText().length()) < element.getText().length() ? s : element); @@ -114,11 +114,11 @@ public void visitElement(@NotNull PsiElement element) { return largest.get(); } - private static String[] expand(String[] types) { - return Arrays.stream(types).flatMap(x-> Stream.of(x, StringTools.stripSuffix(x, "Impl"))).distinct().toArray(String[]::new); + private static CharSequence[] expand(CharSequence[] types) { + return Arrays.stream(types).flatMap(x-> Stream.of(x, StringTools.stripSuffix(x, "Impl"))).distinct().toArray(CharSequence[]::new); } - public static PsiElement getFirstBlock(@NotNull PsiElement element, String blockType) { + public static PsiElement getFirstBlock(@NotNull PsiElement element, CharSequence blockType) { PsiElement[] children = element.getChildren(); if(null == children || 0 == children.length) return null; PsiElement first = children[0]; @@ -126,13 +126,13 @@ public static PsiElement getFirstBlock(@NotNull PsiElement element, String block return null; } - public static PsiElement getLargestBlock(@NotNull PsiElement element, String blockType) { + public static PsiElement getLargestBlock(@NotNull PsiElement element, CharSequence blockType) { AtomicReference largest = new AtomicReference<>(null); AtomicReference visitor = new AtomicReference<>(); visitor.set(new PsiElementVisitor() { @Override public void visitElement(@NotNull PsiElement element) { - String simpleName = element.getClass().getSimpleName(); + CharSequence simpleName = element.getClass().getSimpleName(); if (simpleName.equals(blockType)) { largest.updateAndGet(s -> s != null && s.getText().length() > element.getText().length() ? s : element); super.visitElement(element); @@ -155,13 +155,13 @@ public void visitElement(@NotNull PsiElement element) { * @return A {@link HashSet} of {@link String}s containing the simple names of all the {@link PsiElement}s contained * within the given {@link PsiElement}. */ - public static @NotNull HashSet getAllElementNames(@NotNull PsiElement element) { - HashSet set = new HashSet<>(); + public static @NotNull HashSet getAllElementNames(@NotNull PsiElement element) { + HashSet set = new HashSet<>(); AtomicReference visitor = new AtomicReference<>(); visitor.set(new PsiElementVisitor() { @Override public void visitElement(@NotNull PsiElement element) { - String simpleName = element.getClass().getSimpleName(); + CharSequence simpleName = element.getClass().getSimpleName(); set.add(simpleName); element.acceptChildren(visitor.get()); } diff --git a/src/main/java/com/github/simiacryptus/aicoder/text/StringTools.java b/src/main/java/com/github/simiacryptus/aicoder/text/StringTools.java deleted file mode 100644 index 73732d69..00000000 --- a/src/main/java/com/github/simiacryptus/aicoder/text/StringTools.java +++ /dev/null @@ -1,150 +0,0 @@ -package com.github.simiacryptus.aicoder.text; - -import org.jetbrains.annotations.NotNull; - -import java.util.Arrays; -import java.util.Comparator; -import java.util.concurrent.atomic.AtomicInteger; - -public class StringTools { - - /* - * - * Strips unbalanced terminators from a given input string. - * - * @param input The input string to strip unbalanced terminators from. - * @return The input string with unbalanced terminators removed. - * @throws IllegalArgumentException If the input string is unbalanced. - */ - public static String stripUnbalancedTerminators(String input) { - int openCount = 0; - boolean inQuotes = false; - StringBuilder output = new StringBuilder(); - for (int i = 0; i < input.length(); i++) { - char c = input.charAt(i); - if (c == '"') { - inQuotes = !inQuotes; - } else if (inQuotes && c == '\\') { - // Skip the next character - i++; - } else if (!inQuotes) { - switch (c) { - case '{': - case '[': - case '(': - openCount++; - break; - case '}': - case ']': - case ')': - openCount--; - break; - } - } - if (openCount >= 0) { - output.append(c); - } else { - openCount++; // Dropping character - } - } - if (openCount != 0) { - throw new IllegalArgumentException("Unbalanced input"); - } - return output.toString(); - } - - public static @NotNull String stripPrefix(@NotNull String text, @NotNull String prefix) { - boolean startsWith = text.startsWith(prefix); - if (startsWith) { - return text.substring(prefix.length()); - } else { - return text; - } - } - - public static @NotNull String trimPrefix(@NotNull String text) { - String prefix = getWhitespacePrefix(text); - return stripPrefix(text, prefix); - } - - public static @NotNull String trimSuffix(@NotNull String text) { - String suffix = getWhitespaceSuffix(text); - return stripSuffix(text, suffix); - } - - public static @NotNull String stripSuffix(@NotNull String text, @NotNull String suffix) { - boolean endsWith = text.endsWith(suffix); - if (endsWith) { - return text.substring(0, text.length() - suffix.length()); - } else { - return text; - } - } - - public static String lineWrapping(String description, int width) { - StringBuilder output = new StringBuilder(); - String[] lines = description.split("\n"); - int lineLength = 0; - for (String line : lines) { - AtomicInteger sentenceLength = new AtomicInteger(lineLength); - String sentanceBuffer = wrapSentence(line, width, sentenceLength); - if (lineLength + sentanceBuffer.length() > width && sentanceBuffer.length() < width) { - output.append("\n"); - lineLength = 0; - sentenceLength.set(lineLength); - sentanceBuffer = wrapSentence(line, width, sentenceLength); - } else { - output.append(" "); - sentenceLength.addAndGet(1); - } - output.append(sentanceBuffer); - lineLength = sentenceLength.get(); - } - return output.toString(); - } - - private static String wrapSentence(String line, int width, AtomicInteger xPointer) { - StringBuilder sentenceBuffer = new StringBuilder(); - String[] words = line.split(" "); - for (String word : words) { - if (xPointer.get() + word.length() > width) { - sentenceBuffer.append("\n"); - xPointer.set(0); - } else { - sentenceBuffer.append(" "); - xPointer.addAndGet(1); - } - sentenceBuffer.append(word); - xPointer.addAndGet(word.length()); - } - return sentenceBuffer.toString(); - } - - public static String toString(int[] ints) { - char[] chars = new char[ints.length]; - for (int i = 0; i < ints.length; i++) { - chars[i] = (char) ints[i]; - } - return String.valueOf(chars); - } - - @NotNull - public static String getWhitespacePrefix(String... lines) { - return Arrays.stream(lines) - .map(l -> toString(l.chars().takeWhile(i -> Character.isWhitespace(i)).toArray())) - .filter(x->!x.isEmpty()) - .min(Comparator.comparing(x -> x.length())).orElse(""); - } - - @NotNull - public static String getWhitespaceSuffix(String... lines) { - return reverse(Arrays.stream(lines) - .map(StringTools::reverse) - .map(l -> toString(l.chars().takeWhile(i -> Character.isWhitespace(i)).toArray())) - .max(Comparator.comparing(x -> x.length())).orElse("")).toString(); - } - - private static CharSequence reverse(String l) { - return new StringBuffer(l).reverse().toString(); - } -} diff --git a/src/main/java/com/github/simiacryptus/aicoder/text/BlockComment.java b/src/main/java/com/github/simiacryptus/aicoder/util/BlockComment.java similarity index 64% rename from src/main/java/com/github/simiacryptus/aicoder/text/BlockComment.java rename to src/main/java/com/github/simiacryptus/aicoder/util/BlockComment.java index 16a979ad..21b0cbce 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/text/BlockComment.java +++ b/src/main/java/com/github/simiacryptus/aicoder/util/BlockComment.java @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.text; +package com.github.simiacryptus.aicoder.util; import org.jetbrains.annotations.NotNull; @@ -20,13 +20,13 @@ public Factory(String blockPrefix, String linePrefix, String blockSuffix) { @Override public BlockComment fromString(String text) { text = StringTools.stripSuffix(StringTools.trimSuffix(text.replace("\t", TAB_REPLACEMENT)), blockSuffix.trim()); - String indent = StringTools.getWhitespacePrefix(text.split(DELIMITER)); + @NotNull CharSequence indent = StringTools.getWhitespacePrefix(text.split(DELIMITER)); return new BlockComment(blockPrefix, linePrefix, blockSuffix, indent, Arrays.stream(text.split(DELIMITER)) .map(s -> StringTools.stripPrefix(s, indent)) .map(StringTools::trimPrefix) .map(s -> StringTools.stripPrefix(s, blockPrefix.trim())) .map(s -> StringTools.stripPrefix(s, linePrefix.trim())) - .toArray(String[]::new)); + .toArray(CharSequence[]::new)); } @Override @@ -35,12 +35,12 @@ public boolean looksLike(String text) { } } - public final String linePrefix; - public final String blockPrefix; - public final String blockSuffix; + public final CharSequence linePrefix; + public final CharSequence blockPrefix; + public final CharSequence blockSuffix; - public BlockComment(String blockPrefix, String linePrefix, String blockSuffix, String indent, String... textBlock) { + public BlockComment(CharSequence blockPrefix, CharSequence linePrefix, CharSequence blockSuffix, CharSequence indent, CharSequence... textBlock) { super(indent, textBlock); this.linePrefix = linePrefix; this.blockPrefix = blockPrefix; @@ -49,13 +49,13 @@ public BlockComment(String blockPrefix, String linePrefix, String blockSuffix, S @Override public @NotNull String toString() { - String indent = getIndent(); - String delimiter = DELIMITER + indent; - String joined = Arrays.stream(rawString()).map(x->linePrefix + " " + x).collect(Collectors.joining(delimiter)); - return blockPrefix + delimiter + joined + delimiter + blockSuffix; + CharSequence indent = getIndent(); + CharSequence delimiter = DELIMITER + indent; + CharSequence joined = Arrays.stream(rawString()).map(x->linePrefix + " " + x).collect(Collectors.joining(delimiter)); + return blockPrefix.toString() + delimiter + joined + delimiter + blockSuffix; } - public @NotNull IndentedText withIndent(String indent) { + public @NotNull BlockComment withIndent(CharSequence indent) { return new BlockComment(blockPrefix, linePrefix, blockSuffix, indent, textBlock); } } diff --git a/src/main/java/com/github/simiacryptus/aicoder/text/IndentedText.java b/src/main/java/com/github/simiacryptus/aicoder/util/IndentedText.java similarity index 68% rename from src/main/java/com/github/simiacryptus/aicoder/text/IndentedText.java rename to src/main/java/com/github/simiacryptus/aicoder/util/IndentedText.java index 57b98eb8..64ccbdb1 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/text/IndentedText.java +++ b/src/main/java/com/github/simiacryptus/aicoder/util/IndentedText.java @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.text; +package com.github.simiacryptus.aicoder.util; import org.jetbrains.annotations.NotNull; @@ -19,7 +19,7 @@ */ public class IndentedText implements TextBlock { - public String getIndent() { + public CharSequence getIndent() { return indent; } @@ -35,21 +35,21 @@ public boolean looksLike(String text) { } } - protected String indent; - protected String textBlock[]; + protected CharSequence indent; + protected CharSequence textBlock[]; - public IndentedText(String indent, String... textBlock) { + public IndentedText(CharSequence indent, CharSequence... textBlock) { this.indent = indent; this.textBlock = textBlock; } - public static @NotNull IndentedText fromString(String text) { - text = text.replace("\t", TAB_REPLACEMENT); - String indent = StringTools.getWhitespacePrefix(text.split(DELIMITER)); + public static @NotNull IndentedText fromString(CharSequence text) { + text = text.toString().replace("\t", TAB_REPLACEMENT); + @NotNull CharSequence indent = StringTools.getWhitespacePrefix(text.toString().split(DELIMITER)); return new IndentedText(indent, - Arrays.stream(text.split(DELIMITER)) + Arrays.stream(text.toString().split(DELIMITER)) .map(s -> StringTools.stripPrefix(s, indent)) - .toArray(String[]::new)); + .toArray(CharSequence[]::new)); } @Override @@ -57,12 +57,12 @@ public IndentedText(String indent, String... textBlock) { return Arrays.stream(rawString()).collect(Collectors.joining(DELIMITER + getIndent())); } - public @NotNull IndentedText withIndent(String indent) { + public @NotNull IndentedText withIndent(CharSequence indent) { return new IndentedText(indent, textBlock); } @Override - public String[] rawString() { + public CharSequence[] rawString() { return this.textBlock; } } diff --git a/src/main/java/com/github/simiacryptus/aicoder/text/LineComment.java b/src/main/java/com/github/simiacryptus/aicoder/util/LineComment.java similarity index 75% rename from src/main/java/com/github/simiacryptus/aicoder/text/LineComment.java rename to src/main/java/com/github/simiacryptus/aicoder/util/LineComment.java index 16914d52..21098d7c 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/text/LineComment.java +++ b/src/main/java/com/github/simiacryptus/aicoder/util/LineComment.java @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.text; +package com.github.simiacryptus.aicoder.util; import org.jetbrains.annotations.NotNull; @@ -16,12 +16,12 @@ public Factory(String commentPrefix) { @Override public LineComment fromString(String text) { text = text.replace("\t", TAB_REPLACEMENT); - String indent = StringTools.getWhitespacePrefix(text.split(DELIMITER)); + CharSequence indent = StringTools.getWhitespacePrefix(text.split(DELIMITER)); return new LineComment(commentPrefix, indent, Arrays.stream(text.split(DELIMITER)) .map(s -> StringTools.stripPrefix(s, indent)) .map(StringTools::trimPrefix) .map(s -> StringTools.stripPrefix(s, commentPrefix)) - .toArray(String[]::new)); + .toArray(CharSequence[]::new)); } @Override @@ -30,9 +30,9 @@ public boolean looksLike(String text) { } } - private final String commentPrefix; + private final CharSequence commentPrefix; - public LineComment(String commentPrefix, String indent, String... textBlock) { + public LineComment(CharSequence commentPrefix, CharSequence indent, CharSequence... textBlock) { super(indent, textBlock); this.commentPrefix = commentPrefix; } @@ -42,7 +42,7 @@ public LineComment(String commentPrefix, String indent, String... textBlock) { return commentPrefix + " " + Arrays.stream(rawString()).collect(Collectors.joining(DELIMITER + getIndent() + commentPrefix + " ")); } - public @NotNull IndentedText withIndent(String indent) { + public @NotNull LineComment withIndent(CharSequence indent) { return new LineComment(commentPrefix, indent, textBlock); } } diff --git a/src/main/java/com/github/simiacryptus/aicoder/util/StringTools.java b/src/main/java/com/github/simiacryptus/aicoder/util/StringTools.java new file mode 100644 index 00000000..e0c09a24 --- /dev/null +++ b/src/main/java/com/github/simiacryptus/aicoder/util/StringTools.java @@ -0,0 +1,270 @@ +package com.github.simiacryptus.aicoder.util; + +import org.jetbrains.annotations.NotNull; + +import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class StringTools { + + /* + * + * Strips unbalanced terminators from a given input string. + * + * @param input The input string to strip unbalanced terminators from. + * @return The input string with unbalanced terminators removed. + * @throws IllegalArgumentException If the input string is unbalanced. + */ + public static CharSequence stripUnbalancedTerminators(CharSequence input) { + int openCount = 0; + boolean inQuotes = false; + StringBuilder output = new StringBuilder(); + for (int i = 0; i < input.length(); i++) { + char c = input.charAt(i); + if (c == '"') { + inQuotes = !inQuotes; + } else if (inQuotes && c == '\\') { + // Skip the next character + i++; + } else if (!inQuotes) { + switch (c) { + case '{': + case '[': + case '(': + openCount++; + break; + case '}': + case ']': + case ')': + openCount--; + break; + } + } + if (openCount >= 0) { + output.append(c); + } else { + openCount++; // Dropping character + } + } + if (openCount != 0) { + throw new IllegalArgumentException("Unbalanced input"); + } + return output.toString(); + } + + public static @NotNull String stripPrefix(@NotNull CharSequence text, @NotNull CharSequence prefix) { + boolean startsWith = text.toString().startsWith(prefix.toString()); + if (startsWith) { + return text.toString().substring(prefix.length()); + } else { + return text.toString(); + } + } + + public static @NotNull CharSequence trimPrefix(@NotNull CharSequence text) { + @NotNull CharSequence prefix = getWhitespacePrefix(text); + return stripPrefix(text, prefix); + } + + public static @NotNull String trimSuffix(@NotNull CharSequence text) { + String suffix = getWhitespaceSuffix(text); + return stripSuffix(text, suffix); + } + + public static @NotNull String stripSuffix(@NotNull CharSequence text, @NotNull CharSequence suffix) { + boolean endsWith = text.toString().endsWith(suffix.toString()); + if (endsWith) { + return text.toString().substring(0, text.length() - suffix.length()); + } else { + return text.toString(); + } + } + + public static String lineWrapping(CharSequence description, int width) { + StringBuilder output = new StringBuilder(); + String[] lines = description.toString().split("\n"); + int lineLength = 0; + for (String line : lines) { + AtomicInteger sentenceLength = new AtomicInteger(lineLength); + String sentanceBuffer = wrapSentence(line, width, sentenceLength); + if (lineLength + sentanceBuffer.length() > width && sentanceBuffer.length() < width) { + output.append("\n"); + lineLength = 0; + sentenceLength.set(lineLength); + sentanceBuffer = wrapSentence(line, width, sentenceLength); + } else { + output.append(" "); + sentenceLength.addAndGet(1); + } + output.append(sentanceBuffer); + lineLength = sentenceLength.get(); + } + return output.toString(); + } + + private static String wrapSentence(CharSequence line, int width, AtomicInteger xPointer) { + StringBuilder sentenceBuffer = new StringBuilder(); + String[] words = line.toString().split(" "); + for (String word : words) { + if (xPointer.get() + word.length() > width) { + sentenceBuffer.append("\n"); + xPointer.set(0); + } else { + sentenceBuffer.append(" "); + xPointer.addAndGet(1); + } + sentenceBuffer.append(word); + xPointer.addAndGet(word.length()); + } + return sentenceBuffer.toString(); + } + + public static CharSequence toString(int[] ints) { + char[] chars = new char[ints.length]; + for (int i = 0; i < ints.length; i++) { + chars[i] = (char) ints[i]; + } + return String.valueOf(chars); + } + + @NotNull + public static CharSequence getWhitespacePrefix(CharSequence... lines) { + return Arrays.stream(lines) + .map(l -> toString(l.chars().takeWhile(i -> Character.isWhitespace(i)).toArray())) + .filter(x -> x.length()>0) + .min(Comparator.comparing(x -> x.length())).orElse(""); + } + + @NotNull + public static String getWhitespaceSuffix(CharSequence... lines) { + return reverse(Arrays.stream(lines) + .map(StringTools::reverse) + .map(l -> toString(l.chars().takeWhile(i -> Character.isWhitespace(i)).toArray())) + .max(Comparator.comparing(x -> x.length())).orElse("")).toString(); + } + + private static CharSequence reverse(CharSequence l) { + return new StringBuffer(l).reverse().toString(); + } + + @NotNull + public static List trim(List items, int max, boolean preserveHead) { + items = new ArrayList<>(items); + Random random = new Random(); + while (items.size() > max) { + int index = random.nextInt(items.size()); + if (preserveHead && index == 0) continue; + items.remove(index); + } + return items; + } + + public static String transposeMarkdownTable(String table, boolean inputHeader, boolean outputHeader) { + String[][] cells = parseMarkdownTable(table, inputHeader); + StringBuilder transposedTable = new StringBuilder(); + int columns = cells[0].length; + int rows = cells.length; + if (outputHeader) columns = columns + 1; + for (int column = 0; column < columns; column++) { + transposedTable.append("|"); + for (int row = 0; row < rows; row++) { + String cellValue; + String[] rowCells = cells[row]; + if (outputHeader) { + if (column < 1) { + cellValue = rowCells[column].trim(); + } else if (column == 1) { + cellValue = "---"; + } else if ((column - 1) >= rowCells.length) { + cellValue = ""; + } else { + cellValue = rowCells[column - 1].trim(); + } + } else { + cellValue = rowCells[column].trim(); + } + transposedTable.append(" ").append(cellValue).append(" |"); + } + transposedTable.append("\n"); + } + return transposedTable.toString(); + } + + private static String[][] parseMarkdownTable(String table, boolean removeHeader) { + ArrayList rows = new ArrayList(Arrays.stream(table.split("\n")).map(x -> Arrays.stream(x.split("\\|")).filter(cell -> cell.length() > 0).toArray(CharSequence[]::new)).collect(Collectors.toList())); + if (removeHeader) { + rows.remove(1); + } + return rows.stream() + //.filter(x -> x.length == rows.get(0).length) + .toArray(String[][]::new); + } + + public static CharSequence getPrefixForContext(String text) { + return getPrefixForContext(text, 512, ".", "\n", ",", ";"); + } + + /** + * Get the prefix for the given context. + * + * @param text The text to get the prefix from. + * @param idealLength The ideal length of the prefix. + * @param delimiters The delimiters to split the text by. + * @return The prefix for the given context. + */ + public static CharSequence getPrefixForContext(String text, int idealLength, CharSequence... delimiters) { + List candidates = Stream.of(delimiters).flatMap(d -> { + StringBuilder sb = new StringBuilder(); + String[] split = text.split(Pattern.quote(d.toString())); + for (int i = 0; i < split.length; i++) { + String s = split[i]; + if (Math.abs(sb.length() - idealLength) < Math.abs((sb.length() + s.length()) - idealLength)) break; + if (sb.length() > 0) sb.append(d); + sb.append(s); + if (sb.length() > idealLength) break; + } + if (split.length == 0) return Stream.empty(); + return Stream.of(sb.toString()); + }).collect(Collectors.toList()); + Optional winner = candidates.stream().min(Comparator.comparing(s -> Math.abs(s.length() - idealLength))); + return winner.get(); + } + + public static CharSequence getSuffixForContext(String text) { + return getSuffixForContext(text, 512, ".", "\n", ",", ";"); + } + + /** + * + * Get the suffix for the given context. + * + * @param text The text to get the suffix from. + * @param idealLength The ideal length of the suffix. + * @param delimiters The delimiters to use when splitting the text. + * @return The suffix for the given context. + */ + @NotNull + public static CharSequence getSuffixForContext(String text, int idealLength, CharSequence... delimiters) { + List candidates = Stream.of(delimiters).flatMap(d -> { + StringBuilder sb = new StringBuilder(); + String[] split = text.split(Pattern.quote(d.toString())); + for (int i = split.length - 1; i >= 0; i--) { + String s = split[i]; + if (Math.abs(sb.length() - idealLength) < Math.abs((sb.length() + s.length()) - idealLength)) break; + if (sb.length() > 0 || text.endsWith(d.toString())) sb.insert(0, d); + sb.insert(0, s); + if (sb.length() > idealLength) { + //if (i > 0) sb.insert(0, d); + break; + } + } + if (split.length == 0) return Stream.empty(); + return Stream.of(sb.toString()); + }).collect(Collectors.toList()); + Optional winner = candidates.stream().min(Comparator.comparing(s -> Math.abs(s.length() - idealLength))); + return winner.get(); + } +} diff --git a/src/main/java/com/github/simiacryptus/aicoder/StyleUtil.java b/src/main/java/com/github/simiacryptus/aicoder/util/StyleUtil.java similarity index 61% rename from src/main/java/com/github/simiacryptus/aicoder/StyleUtil.java rename to src/main/java/com/github/simiacryptus/aicoder/util/StyleUtil.java index 8cbec3a1..47a28eb5 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/StyleUtil.java +++ b/src/main/java/com/github/simiacryptus/aicoder/util/StyleUtil.java @@ -1,13 +1,13 @@ -package com.github.simiacryptus.aicoder; +package com.github.simiacryptus.aicoder.util; +import com.github.simiacryptus.aicoder.ComputerLanguage; import com.github.simiacryptus.aicoder.config.AppSettingsState; -import com.github.simiacryptus.aicoder.openai.ModerationException; -import com.github.simiacryptus.aicoder.text.IndentedText; -import com.github.simiacryptus.aicoder.text.StringTools; +import com.github.simiacryptus.aicoder.openai.CompletionRequest; +import com.github.simiacryptus.aicoder.openai.OpenAI_API; +import com.google.common.util.concurrent.ListenableFuture; import com.intellij.openapi.diagnostic.Logger; import javax.swing.*; -import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.Random; @@ -18,7 +18,7 @@ public class StyleUtil { /** * A list of style keywords used to describe the type of writing. */ - private static final List styleKeywords = Arrays.asList( + private static final List styleKeywords = Arrays.asList( "Analytical", "Casual", "Comic-book", @@ -42,7 +42,7 @@ public class StyleUtil { /** * A list of dialect keywords for use in writing. */ - private static final List dialectKeywords = Arrays.asList( + private static final List dialectKeywords = Arrays.asList( "Academic Writing", "Business Writing", "Character Monologues", @@ -98,14 +98,14 @@ public class StyleUtil { * @return A string in the format of "Dialect - Casual, Inspirational" */ public static String randomStyle() { - String dialect = dialectKeywords.get(new Random().nextInt(dialectKeywords.size())); - String style1 = styleKeywords.get(new Random().nextInt(styleKeywords.size())); - String style2 = style1; + CharSequence dialect = dialectKeywords.get(new Random().nextInt(dialectKeywords.size())); + CharSequence style1 = styleKeywords.get(new Random().nextInt(styleKeywords.size())); + CharSequence style2 = style1; while (style2.equals(style1)) style2 = styleKeywords.get(new Random().nextInt(styleKeywords.size())); return String.format("%s - %s, %s", dialect, style1, style2); } - public static void demoStyle(String style) { + public static void demoStyle(CharSequence style) { demoStyle(style, ComputerLanguage.Java, "List items = new ArrayList<>();\n" + @@ -124,39 +124,33 @@ public static void demoStyle(String style) { * @param language The language of the code snippet. * @param code The code snippet to be described. */ - public static void demoStyle(String style, ComputerLanguage language, String code) { - String codeDescription = describeTest(style, language, code); - String message = String.format("This code:\n %s\nwas described as:\n %s", code.replace("\n", "\n "), codeDescription.replace("\n", "\n ")); - JOptionPane.showMessageDialog(null, message, "Style Demo", JOptionPane.INFORMATION_MESSAGE); + public static void demoStyle(CharSequence style, ComputerLanguage language, String code) { + OpenAI_API.onSuccess(describeTest(style, language, code), description -> { + CharSequence message = String.format("This code:\n %s\nwas described as:\n %s", code.replace("\n", "\n "), description.toString().replace("\n", "\n ")); + JOptionPane.showMessageDialog(null, message, "Style Demo", JOptionPane.INFORMATION_MESSAGE); + }); } /** * Describes some test code in the specified style and language. * - * @param style The style of the description. - * @param language The language of the test. - * @param code The code. + * @param style The style of the description. + * @param language The language of the test. + * @param code The code. * @return A description of the test in the specified style and language. */ - public static String describeTest(String style, ComputerLanguage language, String code) { + public static ListenableFuture describeTest(CharSequence style, ComputerLanguage language, String code) { AppSettingsState settings = AppSettingsState.getInstance(); - try { - return StringTools.lineWrapping(settings.createTranslationRequest() - .setInstruction(String.format("Explain this %s in %s (%s)", language.name(), settings.humanLanguage, style)) - .setInputText(IndentedText.fromString(code).getTextBlock().trim()) - .setInputType(language.name()) - .setInputAttribute("type", "code") - .setOutputType(settings.humanLanguage) - .setOutputAttrute("type", "description") - .setOutputAttrute("style", style) - .buildCompletionRequest() - .complete("") - .trim(), 120); - } catch (ModerationException e) { - return e.getMessage(); - } catch (IOException e) { - log.error(e); - return e.getMessage(); - } + CompletionRequest completionRequest = settings.createTranslationRequest() + .setInstruction(String.format("Explain this %s in %s (%s)", language.name(), settings.humanLanguage, style)) + .setInputText(IndentedText.fromString(code).getTextBlock().trim()) + .setInputType(language.name()) + .setInputAttribute("type", "code") + .setOutputType(settings.humanLanguage) + .setOutputAttrute("type", "description") + .setOutputAttrute("style", style) + .buildCompletionRequest(); + ListenableFuture future = completionRequest.complete(null, ""); + return OpenAI_API.map(future, x->StringTools.lineWrapping(x.toString().trim(), 120)); } } diff --git a/src/main/java/com/github/simiacryptus/aicoder/text/TextBlock.java b/src/main/java/com/github/simiacryptus/aicoder/util/TextBlock.java similarity index 60% rename from src/main/java/com/github/simiacryptus/aicoder/text/TextBlock.java rename to src/main/java/com/github/simiacryptus/aicoder/util/TextBlock.java index c26ea214..7bbe3b0f 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/text/TextBlock.java +++ b/src/main/java/com/github/simiacryptus/aicoder/util/TextBlock.java @@ -1,4 +1,4 @@ -package com.github.simiacryptus.aicoder.text; +package com.github.simiacryptus.aicoder.util; import org.jetbrains.annotations.NotNull; @@ -8,18 +8,18 @@ public interface TextBlock { - public static final String TAB_REPLACEMENT = " "; + public static final CharSequence TAB_REPLACEMENT = " "; public static final String DELIMITER = "\n"; - String[] rawString(); + CharSequence[] rawString(); default String getTextBlock() { return Arrays.stream(rawString()).collect(Collectors.joining(DELIMITER)); } - @NotNull TextBlock withIndent(String indent); + @NotNull TextBlock withIndent(CharSequence indent); - default Stream stream() { + default Stream stream() { return Arrays.stream(rawString()); } } diff --git a/src/main/java/com/github/simiacryptus/aicoder/text/TextBlockFactory.java b/src/main/java/com/github/simiacryptus/aicoder/util/TextBlockFactory.java similarity index 62% rename from src/main/java/com/github/simiacryptus/aicoder/text/TextBlockFactory.java rename to src/main/java/com/github/simiacryptus/aicoder/util/TextBlockFactory.java index 285de95a..9eb29abf 100644 --- a/src/main/java/com/github/simiacryptus/aicoder/text/TextBlockFactory.java +++ b/src/main/java/com/github/simiacryptus/aicoder/util/TextBlockFactory.java @@ -1,8 +1,8 @@ -package com.github.simiacryptus.aicoder.text; +package com.github.simiacryptus.aicoder.util; public interface TextBlockFactory { T fromString(String text); - default String toString(T text) { + default CharSequence toString(T text) { return text.toString(); } boolean looksLike(String text); diff --git a/src/main/java/com/github/simiacryptus/aicoder/util/TextReplacementAction.java b/src/main/java/com/github/simiacryptus/aicoder/util/TextReplacementAction.java new file mode 100644 index 00000000..d3781557 --- /dev/null +++ b/src/main/java/com/github/simiacryptus/aicoder/util/TextReplacementAction.java @@ -0,0 +1,77 @@ +package com.github.simiacryptus.aicoder.util; + +import com.github.simiacryptus.aicoder.openai.CompletionRequest; +import com.github.simiacryptus.aicoder.openai.ModerationException; +import com.intellij.openapi.actionSystem.AnAction; +import com.intellij.openapi.actionSystem.AnActionEvent; +import com.intellij.openapi.actionSystem.CommonDataKeys; +import com.intellij.openapi.editor.Caret; +import com.intellij.openapi.editor.CaretModel; +import com.intellij.openapi.editor.Editor; +import com.intellij.openapi.util.NlsActions; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import javax.swing.*; +import java.io.IOException; + +import static com.github.simiacryptus.aicoder.util.UITools.replaceString; + +/** + * TextReplacementAction is an abstract class that extends the AnAction class. + * It provides a static method create() that takes in a text, description, icon, and an ActionTextEditorFunction. + * It also provides an actionPerformed() method that is called when the action is performed. + * This method gets the editor, caret model, and primary caret from the AnActionEvent. + * It then calls the edit() method, which is implemented by the subclasses, and replaces the selected text with the new text. + * The ActionTextEditorFunction is a functional interface that takes in an AnActionEvent and a String and returns a String. + */ +public class TextReplacementAction extends AnAction { + + private final ActionTextEditorFunction fn; + + private TextReplacementAction(@Nullable @NlsActions.ActionText CharSequence text, @Nullable @NlsActions.ActionDescription CharSequence description, @Nullable Icon icon, @NotNull ActionTextEditorFunction fn) { + super(text.toString(), description.toString(), icon); + this.fn = fn; + } + + public static @NotNull TextReplacementAction create(@Nullable @NlsActions.ActionText CharSequence text, @Nullable @NlsActions.ActionDescription CharSequence description, @Nullable Icon icon, @NotNull ActionTextEditorFunction fn) { + return new TextReplacementAction(text, description, icon, fn); + } + + @Override + public void actionPerformed(@NotNull final AnActionEvent e) { + final Editor editor = e.getRequiredData(CommonDataKeys.EDITOR); + final CaretModel caretModel = editor.getCaretModel(); + final Caret primaryCaret = caretModel.getPrimaryCaret(); + try { + int selectionStart = primaryCaret.getSelectionStart(); + int selectionEnd = primaryCaret.getSelectionEnd(); + String selectedText = primaryCaret.getSelectedText(); + CompletionRequest request = fn.apply(e, selectedText); + Caret caret = e.getData(CommonDataKeys.CARET); + CharSequence indent = UITools.getIndent(caret); + UITools.redoableRequest(request, indent, e, (CharSequence x) -> { + CharSequence newText = fn.postTransform(e, selectedText, x); + return replaceString(editor.getDocument(), selectionStart, selectionEnd, newText); + }); + } catch (ModerationException | IOException ex) { + UITools.handle(ex); + } + } + + public interface ActionTextEditorFunction { + CompletionRequest apply(AnActionEvent actionEvent, String input) throws IOException, ModerationException; + + /** + * + * Override this method to post-transform the completion string. + * + * @param event The action event + * @param prompt The prompt string + * @param completion The completion string + * @return The transformed string + */ + default CharSequence postTransform(AnActionEvent event, CharSequence prompt, CharSequence completion) { return completion; } + } + +} diff --git a/src/main/java/com/github/simiacryptus/aicoder/util/UITools.java b/src/main/java/com/github/simiacryptus/aicoder/util/UITools.java new file mode 100644 index 00000000..32af4f43 --- /dev/null +++ b/src/main/java/com/github/simiacryptus/aicoder/util/UITools.java @@ -0,0 +1,224 @@ +package com.github.simiacryptus.aicoder.util; + +import com.github.simiacryptus.aicoder.config.AppSettingsState; +import com.github.simiacryptus.aicoder.openai.CompletionRequest; +import com.github.simiacryptus.aicoder.openai.ModerationException; +import com.github.simiacryptus.aicoder.openai.OpenAI_API; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.intellij.openapi.actionSystem.AnActionEvent; +import com.intellij.openapi.actionSystem.CommonDataKeys; +import com.intellij.openapi.command.WriteCommandAction; +import com.intellij.openapi.diagnostic.Logger; +import com.intellij.openapi.editor.Caret; +import com.intellij.openapi.editor.Document; +import com.intellij.openapi.editor.Editor; +import com.intellij.openapi.progress.ProgressIndicator; +import com.intellij.openapi.progress.ProgressManager; +import com.intellij.openapi.project.Project; +import com.intellij.openapi.util.TextRange; +import org.jetbrains.annotations.NotNull; + +import javax.swing.*; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.function.Function; + +public class UITools { + + private static final Logger log = Logger.getInstance(UITools.class); + + public static final ConcurrentLinkedDeque retry = new ConcurrentLinkedDeque<>(); + + /** + * This method is responsible for making a redoable request. + * + * @param request The completion request to be made. + * @param indent The indentation to be used. + * @param event The project to be used. + * @param action The action to be taken when the request is completed. + * @return A {@link Runnable} that can be used to redo the request. + */ + public static void redoableRequest(CompletionRequest request, CharSequence indent, @NotNull AnActionEvent event, Function action) { + Editor editor = event.getData(CommonDataKeys.EDITOR); + Document document = editor.getDocument(); + //document.setReadOnly(true); + ProgressManager progressManager = ProgressManager.getInstance(); + ProgressIndicator progressIndicator = progressManager.getProgressIndicator(); + if(null != progressIndicator) { + progressIndicator.setIndeterminate(true); + progressIndicator.setText("Talking to OpenAI..."); + } + ListenableFuture resultFuture = request.complete(event.getProject(), indent); + Futures.addCallback(resultFuture, new FutureCallback() { + @Override + public void onSuccess(CharSequence result) { + //document.setReadOnly(false); + if(null != progressIndicator) { + progressIndicator.cancel(); + } + WriteCommandAction.runWriteCommandAction(event.getProject(), () -> { + retry.add(getRetry(request, indent, event, action, action.apply(result.toString()))); + }); + } + + @Override + public void onFailure(Throwable t) { + //document.setReadOnly(false); + if(null != progressIndicator) { + progressIndicator.cancel(); + } + handle(t); + } + }, OpenAI_API.INSTANCE.pool); + } + + /** + * Get a retry for the given {@link CompletionRequest}. + * + *

This method will create a {@link Runnable} that will attempt to complete the given {@link CompletionRequest} + * with the given {@code indent}. If the completion is successful, the given {@code action} will be applied to the + * result and the given {@code undo} will be run. + * + * @param request the {@link CompletionRequest} to complete + * @param indent the indent to use for the completion + * @param event the {@link Project} to use for the completion + * @param action the {@link Function} to apply to the result of the completion + * @param undo the {@link Runnable} to run if the completion is successful + * @return a {@link Runnable} that will attempt to complete the given {@link CompletionRequest} + */ + @NotNull + private static Runnable getRetry(CompletionRequest request, CharSequence indent, AnActionEvent event, Function action, Runnable undo) { + Document document = event.getData(CommonDataKeys.EDITOR).getDocument(); + //document.setReadOnly(true); + ProgressIndicator progressIndicator = ProgressManager.getInstance().getProgressIndicator(); + if(null != progressIndicator) { + progressIndicator.setIndeterminate(true); + } + return () -> { + ListenableFuture retryFuture = request.complete(event.getProject(), indent); + Futures.addCallback(retryFuture, new FutureCallback() { + @Override + public void onSuccess(CharSequence result) { + WriteCommandAction.runWriteCommandAction(event.getProject(), () -> { + //document.setReadOnly(false); + if(null != progressIndicator) { + progressIndicator.cancel(); + } + if (null != undo) undo.run(); + retry.add(getRetry(request, indent, event, action, action.apply(result.toString()))); + }); + } + + @Override + public void onFailure(Throwable t) { + //document.setReadOnly(false); + if(null != progressIndicator) { + progressIndicator.cancel(); + } + handle(t); + } + }, OpenAI_API.INSTANCE.pool); + }; + } + + /** + * Get an instruction with a style + * + * @param instruction The instruction to be returned + * @return A string containing the instruction and the style + */ + public static String getInstruction(String instruction) { + CharSequence style = AppSettingsState.getInstance().style; + if (style.length() == 0) return instruction; + return String.format("%s (%s)", instruction, style); + } + + /** + * Replaces a string in a document with a new string. + * + * @param document The document to replace the string in. + * @param startOffset The start offset of the string to be replaced. + * @param endOffset The end offset of the string to be replaced. + * @param newText The new string to replace the old string. + * @return A Runnable that can be used to undo the replacement. + */ + public static Runnable replaceString(Document document, int startOffset, int endOffset, CharSequence newText) { + CharSequence oldText = document.getText(new TextRange(startOffset, endOffset)); + document.replaceString(startOffset, endOffset, newText); + return () -> { + if (!document.getText(new TextRange(startOffset, startOffset + newText.length())).equals(newText)) + throw new AssertionError(); + document.replaceString(startOffset, startOffset + newText.length(), oldText); + }; + } + + /** + * Inserts a string into a document at a given offset and returns a Runnable to undo the insertion. + * + * @param document The document to insert the string into. + * @param startOffset The offset at which to insert the string. + * @param newText The string to insert. + * @return A Runnable that can be used to undo the insertion. + */ + public static Runnable insertString(Document document, int startOffset, CharSequence newText) { + document.insertString(startOffset, newText); + return () -> { + if (!document.getText(new TextRange(startOffset, startOffset + newText.length())).equals(newText)) + throw new AssertionError(); + document.deleteString(startOffset, startOffset + newText.length()); + }; + } + + public static Runnable deleteString(Document document, int startOffset, int endOffset) { + CharSequence oldText = document.getText(new TextRange(startOffset, endOffset)); + document.deleteString(startOffset, endOffset); + return () -> { + document.insertString(startOffset, oldText); + }; + } + + public static CharSequence getIndent(Caret caret) { + if (null == caret) return ""; + Document document = caret.getEditor().getDocument(); + return IndentedText.fromString(document.getText().split("\n")[document.getLineNumber(caret.getSelectionStart())]).getIndent(); + } + + public static boolean hasSelection(@NotNull AnActionEvent e) { + Caret caret = e.getData(CommonDataKeys.CARET); + return null != caret && caret.hasSelection(); + } + + public static void handle(@NotNull Throwable ex) { + if (!(ex instanceof ModerationException)) log.error(ex); + JOptionPane.showMessageDialog(null, ex.getMessage(), "Warning", JOptionPane.WARNING_MESSAGE); + } + + public static CharSequence getIndent(AnActionEvent event) { + Caret caret = event.getData(CommonDataKeys.CARET); + CharSequence indent; + if (null == caret) { + indent = ""; + } else { + indent = getIndent(caret); + } + return indent; + } + + public static String queryAPIKey() { + JPanel panel = new JPanel(); + JLabel label = new JLabel("Enter OpenAI API Key:"); + JPasswordField pass = new JPasswordField(100); + panel.add(label); + panel.add(pass); + CharSequence[] options = new CharSequence[]{"OK", "Cancel"}; + int option = JOptionPane.showOptionDialog(null, panel, "API Key", + JOptionPane.NO_OPTION, JOptionPane.PLAIN_MESSAGE, + null, options, options[1]); + if (option == 0) { + char[] password = pass.getPassword(); + return new String(password); + } + return null; + } +}