Skip to content

Commit

Permalink
Merge pull request #210 from devchat-ai/improve-code-completion
Browse files Browse the repository at this point in the history
Improve code completion
  • Loading branch information
pplam authored Oct 9, 2024
2 parents d0ea7f7 + 6183876 commit 244143e
Show file tree
Hide file tree
Showing 9 changed files with 417 additions and 53 deletions.
33 changes: 33 additions & 0 deletions src/main/kotlin/ai/devchat/common/Constants.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,37 @@ package ai.devchat.common
object Constants {
val ASSISTANT_NAME_ZH = DevChatBundle.message("assistant.name.zh")
val ASSISTANT_NAME_EN = DevChatBundle.message("assistant.name.en")
val FUNC_TYPE_NAMES: Set<String> = setOf(
"FUN", // Kotlin
"METHOD", // Java
"FUNCTION_DEFINITION", // C, C++
"Py:FUNCTION_DECLARATION", // Python
"FUNCTION_DECLARATION", "METHOD_DECLARATION", // Golang
"JS:FUNCTION_DECLARATION", "JS:FUNCTION_EXPRESSION", // JS
"JS:TYPESCRIPT_FUNCTION", "JS:TYPESCRIPT_FUNCTION_EXPRESSION", // TS
"CLASS_METHOD", // PHP
"FUNCTION", // PHP, Rust
"Ruby:METHOD", // Ruby
)
val CALL_EXPRESSION_ELEMENT_TYPE_NAMES: Set<String> = setOf(
"CALL_EXPRESSION", // Kotlin, C, C++, Python
"METHOD_CALL_EXPRESSION", // Java
"CALL_EXPR", // Go, Rust
"JS_CALL_EXPRESSION", // JS
"TS_CALL_EXPRESSION", // TS
"PHP_METHOD_REFERENCE", // PHP
"CALL", // Ruby
)
val LANGUAGE_COMMENT_PREFIX: Map<String, String> = mapOf(
"kotlin" to "//",
"java" to "//",
"cpp" to "//",
"python" to "#",
"go" to "//",
"javascript" to "//",
"typescript" to "//",
"php" to "//", // PHP also supports `#` for comments
"rust" to "//",
"ruby" to "#"
)
}
208 changes: 208 additions & 0 deletions src/main/kotlin/ai/devchat/common/IDEUtils.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package ai.devchat.common

import com.intellij.codeInsight.navigation.actions.TypeDeclarationProvider
import com.intellij.lang.folding.FoldingDescriptor
import com.intellij.lang.folding.LanguageFolding
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.ReadAction
import com.intellij.openapi.roots.ProjectFileIndex
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiFile
import com.intellij.psi.PsiNameIdentifierOwner
import com.intellij.psi.PsiPolyVariantReference
import com.intellij.psi.util.elementType
import com.intellij.psi.util.findParentInFile
import com.intellij.refactoring.suggested.startOffset
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CountDownLatch
import kotlin.system.measureTimeMillis


object IDEUtils {
fun <T> runInEdtAndGet(block: () -> T): T {
val app = ApplicationManager.getApplication()
if (app.isDispatchThread) {
return block()
}
val future = CompletableFuture<T>()
val latch = CountDownLatch(1)
app.invokeLater {
try {
val result = block()
future.complete(result)
} catch (e: Exception) {
future.completeExceptionally(e)
} finally {
latch.countDown()
}
}
latch.await()
return future.get()
}

fun findCalleeInParent(element: PsiElement?): List<PsiElement>? {
if (element == null) return null
Log.info("Find callee in parent: ${element.elementType}: ${element.text.replace("\n", "\\n")}")
val nearestCallExpression = element.findParentInFile(withSelf = true) {
if (it is PsiFile) false else {
it.elementType.toString() in Constants.CALL_EXPRESSION_ELEMENT_TYPE_NAMES
}
}

if (nearestCallExpression == null) return null

Log.info("Nearest call expression: ${nearestCallExpression.elementType}: ${nearestCallExpression.text.replace("\n", "\\n")}")

val projectFileIndex = ProjectFileIndex.getInstance(element.project)
val callee = nearestCallExpression.children.asSequence()
.mapNotNull {child ->
child.reference?.let{ref ->
if (ref is PsiPolyVariantReference) {
ref.multiResolve(false).mapNotNull { it.element }
} else listOfNotNull(ref.resolve())
}?.filter {
val containingFile = it.containingFile?.virtualFile
containingFile != null && projectFileIndex.isInContent(containingFile)
}
}
.firstOrNull {it.isNotEmpty()}

if (callee == null) {
Log.info("Callee not found")
} else {
Log.info("Callee: $callee")
}

return callee ?: findCalleeInParent(nearestCallExpression.parent)
}

fun PsiElement.findCalleeInParent(): Sequence<List<PsiElement>> {
val projectFileIndex = ProjectFileIndex.getInstance(this.project)
return generateSequence(this) { it.parent }
.takeWhile { it !is PsiFile }
.filter { it.elementType.toString() in Constants.CALL_EXPRESSION_ELEMENT_TYPE_NAMES }
.mapNotNull { callExpression ->
Log.info("Call expression: ${callExpression.elementType}: ${callExpression.text}")

callExpression.children
.asSequence()
.mapNotNull { child ->
child.reference?.let { ref ->
if (ref is PsiPolyVariantReference) {
ref.multiResolve(false).mapNotNull { it.element }
} else {
listOfNotNull(ref.resolve())
}
.filter { resolved ->
resolved.containingFile.virtualFile?.let { file ->
projectFileIndex.isInContent(file)
} == true
}
}
}
.firstOrNull { it.isNotEmpty() }
}
}


private fun PsiElement.getTypeDeclaration(): PsiElement? = runBlocking(Dispatchers.IO) {
ReadAction.compute<PsiElement?, Throwable> {
TypeDeclarationProvider.EP_NAME.extensionList.asSequence()
.mapNotNull { provider ->
provider.getSymbolTypeDeclarations(this@getTypeDeclaration)?.firstOrNull()
}
.firstOrNull()
}
}

data class CodeNode(
val element: PsiElement,
val isProjectContent: Boolean,
)
data class SymbolTypeDeclaration(
val symbol: PsiNameIdentifierOwner,
val typeDeclaration: CodeNode
)

fun PsiElement.findAccessibleVariables(): Sequence<SymbolTypeDeclaration> {
val projectFileIndex = ProjectFileIndex.getInstance(this.project)
return generateSequence(this.parent) { it.parent }
.takeWhile { it !is PsiFile }
.flatMap { it.children.asSequence().filterIsInstance<PsiNameIdentifierOwner>() }
.plus(this.containingFile.children.asSequence().filterIsInstance<PsiNameIdentifierOwner>())
.filter { !it.name.isNullOrEmpty() && it.nameIdentifier != null }
.mapNotNull {
val typeDeclaration = it.getTypeDeclaration() ?: return@mapNotNull null
val virtualFile = typeDeclaration.containingFile.virtualFile ?: return@mapNotNull null
val isProjectContent = projectFileIndex.isInContent(virtualFile)
SymbolTypeDeclaration(it, CodeNode(typeDeclaration, isProjectContent))
}
}

fun PsiElement.foldTextOfLevel(foldingLevel: Int = 1): String {
val file = this.containingFile
val document = file.viewProvider.document ?: return text
val fileNode = file.node ?: return text

val foldingBuilder = LanguageFolding.INSTANCE.forLanguage(this.language) ?: return text
var descriptors: List<FoldingDescriptor> = listOf()
var timeTaken = measureTimeMillis {
descriptors = foldingBuilder.buildFoldRegions(fileNode, document)
.filter {
textRange.contains(it.range)
// && it.element.textRange.startOffset > textRange.startOffset // Exclude the function itself
}
.sortedBy { it.range.startOffset }
.let {
findDescriptorsOfFoldingLevel(it, foldingLevel)
}
}
Log.info("=============> [$this] Time taken to build fold regions: $timeTaken ms, ${file.virtualFile.path}")
var result = ""
timeTaken = measureTimeMillis {
result = foldTextByDescriptors(descriptors)
}
Log.info("=============> [$this] Time taken to fold text: $timeTaken ms, ${file.virtualFile.path}")
return result
}

private fun findDescriptorsOfFoldingLevel(
descriptors: List<FoldingDescriptor>,
foldingLevel: Int
): List<FoldingDescriptor> {
val nestedDescriptors = mutableListOf<FoldingDescriptor>()
val stack = mutableListOf<FoldingDescriptor>()

for (descriptor in descriptors.sortedBy { it.range.startOffset }) {
while (stack.isNotEmpty() && !stack.last().range.contains(descriptor.range)) {
stack.removeAt(stack.size - 1)
}
stack.add(descriptor)
if (stack.size == foldingLevel) {
nestedDescriptors.add(descriptor)
}
}

return nestedDescriptors
}

private fun PsiElement.foldTextByDescriptors(descriptors: List<FoldingDescriptor>): String {
val sortedDescriptors = descriptors.sortedBy { it.range.startOffset }
val builder = StringBuilder()
var currentIndex = 0

for (descriptor in sortedDescriptors) {
val range = descriptor.range.shiftRight(-startOffset)
if (range.startOffset >= currentIndex) {
builder.append(text, currentIndex, range.startOffset)
builder.append(descriptor.placeholderText)
currentIndex = range.endOffset
}
}
builder.append(text.substring(currentIndex))

return builder.toString()
}
}
26 changes: 1 addition & 25 deletions src/main/kotlin/ai/devchat/plugin/IDEServer.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ai.devchat.plugin

import ai.devchat.common.IDEUtils.runInEdtAndGet
import ai.devchat.common.Log
import ai.devchat.common.Notifier
import ai.devchat.common.PathUtils
Expand Down Expand Up @@ -47,8 +48,6 @@ import kotlinx.serialization.Serializable
import java.awt.Point
import java.io.File
import java.net.ServerSocket
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CountDownLatch
import kotlin.reflect.full.memberFunctions


Expand Down Expand Up @@ -446,29 +445,6 @@ fun Editor.diffWith(newText: String, autoEdit: Boolean) {
}
}

fun <T> runInEdtAndGet(block: () -> T): T {
val app = ApplicationManager.getApplication()
if (app.isDispatchThread) {
return block()
}
val future = CompletableFuture<T>()
val latch = CountDownLatch(1)
app.invokeLater {
try {
val result = block()
future.complete(result)
} catch (e: Exception) {
future.completeExceptionally(e)
} finally {
latch.countDown()
}
}
latch.await()
return future.get()
}



fun Project.getPsiFile(filePath: String): PsiFile = runInEdtAndGet {
ReadAction.compute<PsiFile, Throwable> {
val virtualFile = LocalFileSystem.getInstance().findFileByIoFile(File(filePath))
Expand Down
14 changes: 7 additions & 7 deletions src/main/kotlin/ai/devchat/plugin/completion/agent/Agent.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ai.devchat.storage.CONFIG
import com.google.gson.Gson
import com.google.gson.annotations.SerializedName
import com.intellij.openapi.diagnostic.Logger
import com.intellij.psi.PsiFile
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.launch
Expand Down Expand Up @@ -47,9 +48,8 @@ class Agent(val scope: CoroutineScope) {
}

data class CompletionRequest(
val filepath: String,
val file: PsiFile,
val language: String,
val text: String,
val position: Int,
val manually: Boolean?,
)
Expand Down Expand Up @@ -100,8 +100,9 @@ class Agent(val scope: CoroutineScope) {
) {
companion object {
fun fromCompletionRequest(completionRequest: CompletionRequest): RequestInfo {
val upperPart = completionRequest.text.substring(0, completionRequest.position)
val lowerPart = completionRequest.text.substring(completionRequest.position)
val fileContent = completionRequest.file.text
val upperPart = fileContent.substring(0, completionRequest.position)
val lowerPart = fileContent.substring(completionRequest.position)
val currentLinePrefix = upperPart.substringAfterLast(LINE_SEPARATOR, upperPart)
val currentLineSuffix = lowerPart.lineSequence().firstOrNull()?.second ?: ""
val currentIndent = currentLinePrefix.takeWhile { it.isWhitespace() }.length
Expand All @@ -112,7 +113,7 @@ class Agent(val scope: CoroutineScope) {
i > 0 && v.second.trim().isNotEmpty()
}?.value?.second
return RequestInfo(
filepath = completionRequest.filepath,
filepath = completionRequest.file.virtualFile.path,
language = completionRequest.language,
upperPart = upperPart,
lowerPart = lowerPart,
Expand Down Expand Up @@ -277,8 +278,7 @@ class Agent(val scope: CoroutineScope) {
val model = CONFIG["complete_model"] as? String
var startTime = System.currentTimeMillis()
val prompt = ContextBuilder(
completionRequest.filepath,
completionRequest.text,
completionRequest.file,
completionRequest.position
).createPrompt(model)
val promptBuildingElapse = System.currentTimeMillis() - startTime
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package ai.devchat.plugin.completion.agent

import ai.devchat.storage.CONFIG
import com.intellij.lang.Language
import com.intellij.openapi.Disposable
import com.intellij.openapi.application.ReadAction
Expand All @@ -9,7 +8,8 @@ import com.intellij.openapi.editor.Editor
import com.intellij.psi.PsiDocumentManager
import com.intellij.psi.PsiFile
import io.ktor.util.*
import kotlinx.coroutines.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers

@Service
class AgentService : Disposable {
Expand All @@ -24,9 +24,8 @@ class AgentService : Disposable {
}?.let { file ->
agent.provideCompletions(
Agent.CompletionRequest(
file.virtualFile.path,
file,
file.getLanguageId(),
editor.document.text,
offset,
manually,
)
Expand Down
Loading

0 comments on commit 244143e

Please sign in to comment.