diff --git a/bun/lsp.ts b/bun/lsp.ts index f881aef..dae8421 100644 --- a/bun/lsp.ts +++ b/bun/lsp.ts @@ -1,5 +1,6 @@ import { type Nvim } from "bunvim"; import type { NvimBuffer } from "./nvim/buffer.ts"; +import type { PositionString } from "./nvim/window.ts"; export class Lsp { private requestCounter = 0; @@ -21,8 +22,7 @@ export class Lsp { requestHover( buffer: NvimBuffer, - row: number, - col: number, + pos: PositionString, ): Promise { return new Promise((resolve, reject) => { const requestId = this.getRequestId(); @@ -37,8 +37,8 @@ export class Lsp { uri = vim.uri_from_bufnr(${buffer.id}) }, position = { - line = ${row}, - character = ${col} + line = ${pos.row}, + character = ${pos.col} } }, function(responses) require('magenta').lsp_response("${requestId}", responses) @@ -54,8 +54,7 @@ export class Lsp { requestReferences( buffer: NvimBuffer, - row: number, - col: number, + pos: PositionString, ): Promise { return new Promise((resolve, reject) => { const requestId = this.getRequestId(); @@ -70,8 +69,8 @@ export class Lsp { uri = vim.uri_from_bufnr(${buffer.id}) }, position = { - line = ${row}, - character = ${col} + line = ${pos.row}, + character = ${pos.col} }, context = { includeDeclaration = true diff --git a/bun/nvim/window.ts b/bun/nvim/window.ts index 884c0f2..0c65581 100644 --- a/bun/nvim/window.ts +++ b/bun/nvim/window.ts @@ -4,6 +4,16 @@ import { NvimBuffer, type BufNr } from "./buffer.ts"; export type Row0Indexed = number & { __row0Indexed: true }; export type Row1Indexed = number & { __row1Indexed: true }; export type ByteIdx = number & { __byteIdx: true }; + +/** A coordinate in a js string, which are utf-16 encoded by default. This is the coordinate that lsp clients typically expect. + */ +export type StringIdx = number & { __charIdx: true }; + +export type PositionString = { + row: Row0Indexed; + col: StringIdx; +}; + export type Position1Indexed = { row: Row1Indexed; col: ByteIdx; diff --git a/bun/tea/util.ts b/bun/tea/util.ts index c3b169b..83db326 100644 --- a/bun/tea/util.ts +++ b/bun/tea/util.ts @@ -1,6 +1,11 @@ import type { Nvim } from "bunvim"; import type { Line, NvimBuffer } from "../nvim/buffer.ts"; -import type { ByteIdx, Position0Indexed } from "../nvim/window.ts"; +import type { + PositionString, + ByteIdx, + Position0Indexed, + StringIdx, +} from "../nvim/window.ts"; export async function replaceBetweenPositions({ buffer, @@ -37,10 +42,10 @@ export async function replaceBetweenPositions({ export function calculatePosition( startPos: Position0Indexed, buf: Buffer, - indexInText: number, + indexInText: ByteIdx, ): Position0Indexed { let { row, col } = startPos; - let currentIndex = 0; + let currentIndex: ByteIdx = 0 as ByteIdx; while (currentIndex < indexInText) { // 10 == '\n' in hex @@ -56,6 +61,28 @@ export function calculatePosition( return { row, col }; } +export function calculateStringPosition( + startPos: PositionString, + content: string, + indexInText: StringIdx, +): PositionString { + let { row, col } = startPos; + let currentIndex = 0 as StringIdx; + + while (currentIndex < indexInText) { + // 10 == '\n' in hex + if (content[currentIndex] == "\n") { + row++; + col = 0 as StringIdx; + } else { + col++; + } + currentIndex++; + } + + return { row, col }; +} + export async function logBuffer(buffer: NvimBuffer, context: { nvim: Nvim }) { const lines = await buffer.getLines({ start: 0, diff --git a/bun/test/fixtures/test.ts b/bun/test/fixtures/test.ts new file mode 100644 index 0000000..34cfbd3 --- /dev/null +++ b/bun/test/fixtures/test.ts @@ -0,0 +1,17 @@ +type Nested = { + a: { + b: { + c: "test"; + }; + }; +}; + +const val: Nested = { + a: { + b: { + c: "test", + }, + }, +}; + +console.log(val.a.b.c); diff --git a/bun/tools/findReferences.spec.ts b/bun/tools/findReferences.spec.ts new file mode 100644 index 0000000..640c67d --- /dev/null +++ b/bun/tools/findReferences.spec.ts @@ -0,0 +1,66 @@ +import { type ToolRequestId } from "./toolManager.ts"; +import { describe, it, expect } from "bun:test"; +import { withDriver } from "../test/preamble"; +import { pollUntil } from "../utils/async.ts"; + +describe("bun/tools/findReferences.spec.ts", () => { + it.only("findReferences end-to-end", async () => { + await withDriver(async (driver) => { + await driver.editFile("bun/test/fixtures/test.ts"); + await driver.showSidebar(); + + await driver.inputMagentaText(`Try finding references for a symbol`); + await driver.send(); + + const toolRequestId = "id" as ToolRequestId; + await driver.mockAnthropic.respond({ + stopReason: "tool_use", + text: "ok, here goes", + toolRequests: [ + { + status: "ok", + value: { + type: "tool_use", + id: toolRequestId, + name: "find_references", + input: { + filePath: "bun/test/fixtures/test.ts", + symbol: "val.a.b.c", + }, + }, + }, + ], + }); + + const result = await pollUntil( + () => { + const state = driver.magenta.chatApp.getState(); + if (state.status != "running") { + throw new Error(`app crashed`); + } + + const toolWrapper = + state.model.toolManager.toolWrappers[toolRequestId]; + if (!toolWrapper) { + throw new Error( + `could not find toolWrapper with id ${toolRequestId}`, + ); + } + + if (toolWrapper.model.state.state != "done") { + throw new Error(`Request not done`); + } + + return toolWrapper.model.state.result; + }, + { timeout: 3000 }, + ); + + expect(result).toEqual({ + tool_use_id: toolRequestId, + type: "tool_result", + content: `bun/test/fixtures/test.ts:4:6\nbun/test/fixtures/test.ts:12:6\nbun/test/fixtures/test.ts:17:20\n`, + }); + }); + }); +}); diff --git a/bun/tools/findReferences.ts b/bun/tools/findReferences.ts index 98d4226..61e39ac 100644 --- a/bun/tools/findReferences.ts +++ b/bun/tools/findReferences.ts @@ -7,11 +7,15 @@ import { assertUnreachable } from "../utils/assertUnreachable.ts"; import { getOrOpenBuffer } from "../utils/buffers.ts"; import type { NvimBuffer } from "../nvim/buffer.ts"; import type { Nvim } from "bunvim"; -import type { Lsp } from "../lsp.ts"; +import path from "path"; +import { getcwd } from "../nvim/nvim.ts"; +import type { Lsp } from "../lsp.ts";import { getcwd } from "../nvim/nvim.ts"; +import path from "path"; +import { calculateStringPosition } from "../tea/util.ts"; +import type { PositionString, StringIdx } from "../nvim/window.ts"; export type Model = { type: "find_references"; - autoRespond: boolean; request: ReferencesToolUseRequest; state: | { @@ -51,7 +55,6 @@ export function initModel( ): [Model, Thunk] { const model: Model = { type: "find_references", - autoRespond: true, request, state: { state: "processing", @@ -62,7 +65,6 @@ export function initModel( async (dispatch) => { const { lsp, nvim } = context; const filePath = model.request.input.filePath; - context.nvim.logger?.debug(`request: ${JSON.stringify(model.request)}`); const bufferResult = await getOrOpenBuffer({ relativePath: filePath, context: { nvim }, @@ -85,31 +87,11 @@ export function initModel( }); return; } + const symbolStart = bufferContent.indexOf( + model.request.input.symbol, + ) as StringIdx; - let searchText = bufferContent; - let startOffset = 0; - - // If context is provided, find it first - if (model.request.input.context) { - const contextIndex = bufferContent.indexOf(model.request.input.context); - if (contextIndex === -1) { - dispatch({ - type: "finish", - result: { - type: "tool_result", - tool_use_id: model.request.id, - content: `Context not found in file.`, - is_error: true, - }, - }); - return; - } - searchText = model.request.input.context; - startOffset = contextIndex; - } - - const symbolIndex = searchText.indexOf(model.request.input.symbol); - if (symbolIndex === -1) { + if (symbolStart === -1) { dispatch({ type: "finish", result: { @@ -122,22 +104,22 @@ export function initModel( return; } - const absoluteSymbolIndex = startOffset + symbolIndex; - const precedingText = bufferContent.substring(0, absoluteSymbolIndex); - const row = precedingText.split("\n").length - 1; - const lastNewline = precedingText.lastIndexOf("\n"); - const col = - lastNewline === -1 - ? absoluteSymbolIndex - : absoluteSymbolIndex - lastNewline - 1; + const symbolPos = calculateStringPosition( + { row: 0, col: 0 } as PositionString, + bufferContent, + (symbolStart + model.request.input.symbol.length - 1) as StringIdx, + ); try { - const result = await lsp.requestReferences(buffer, row, col); + const cwd = await getcwd(nvim); + const result = await lsp.requestReferences(buffer, symbolPos); let content = ""; for (const lspResult of result) { if (lspResult != null && lspResult.result) { for (const ref of lspResult.result) { - content += `${ref.uri}:${ref.range.start.line + 1}:${ref.range.start.character}\n`; + const uri = ref.uri.startsWith('file://') ? ref.uri.slice(7) : ref.uri; + const relativePath = path.relative(cwd, uri); + content += `${relativePath}:${ref.range.start.line + 1}:${ref.range.start.character}\n`; } } } @@ -204,14 +186,9 @@ export const spec: Anthropic.Anthropic.Tool = { }, symbol: { type: "string", - description: - "The symbol to find references for. We will use the first occurrence of the symbol.", - }, - context: { - type: "string", - description: `Optionally, you can disambiguate which instance of the symbol you want to find references for. \ -If context is provided, we will first find the first instance of context in the file, and then look for the symbol inside the context. \ -This should be the literal text of the file. Regular expressions are not allowed.`, + description: `The symbol to find references for. +We will use the first occurrence of the symbol. +We will use the right-most character of this string, so if the string is "a.b.c", we will find references for c.`, }, }, required: ["filePath", "symbol"], @@ -224,7 +201,6 @@ export type ReferencesToolUseRequest = { input: { filePath: string; symbol: string; - context?: string; }; name: "find_references"; }; @@ -271,13 +247,6 @@ export function validateToolRequest( return { status: "error", error: "expected input.symbol to be a string" }; } - if (input.context && typeof input.context != "string") { - return { - status: "error", - error: "input.context must be a string if provided", - }; - } - return { status: "ok", value: req as ReferencesToolUseRequest, diff --git a/bun/tools/hover.spec.ts b/bun/tools/hover.spec.ts new file mode 100644 index 0000000..ae37666 --- /dev/null +++ b/bun/tools/hover.spec.ts @@ -0,0 +1,66 @@ +import { type ToolRequestId } from "./toolManager.ts"; +import { describe, it, expect } from "bun:test"; +import { withDriver } from "../test/preamble"; +import { pollUntil } from "../utils/async.ts"; + +describe("bun/tools/hover.spec.ts", () => { + it("hover end-to-end", async () => { + await withDriver(async (driver) => { + await driver.editFile("bun/test/fixtures/test.ts"); + await driver.showSidebar(); + + await driver.inputMagentaText(`Try hovering a symbol`); + await driver.send(); + + const toolRequestId = "id" as ToolRequestId; + await driver.mockAnthropic.respond({ + stopReason: "tool_use", + text: "ok, here goes", + toolRequests: [ + { + status: "ok", + value: { + type: "tool_use", + id: toolRequestId, + name: "hover", + input: { + filePath: "bun/test/fixtures/test.ts", + symbol: "val.a.b.c", + }, + }, + }, + ], + }); + + const result = await pollUntil( + () => { + const state = driver.magenta.chatApp.getState(); + if (state.status != "running") { + throw new Error(`app crashed`); + } + + const toolWrapper = + state.model.toolManager.toolWrappers[toolRequestId]; + if (!toolWrapper) { + throw new Error( + `could not find toolWrapper with id ${toolRequestId}`, + ); + } + + if (toolWrapper.model.state.state != "done") { + throw new Error(`Request not done`); + } + + return toolWrapper.model.state.result; + }, + { timeout: 3000 }, + ); + + expect(result).toEqual({ + tool_use_id: toolRequestId, + type: "tool_result", + content: `(markdown):\n\n\`\`\`typescript\n(property) c: \"test\"\n\`\`\`\n\n`, + }); + }); + }); +}); diff --git a/bun/tools/hover.ts b/bun/tools/hover.ts index 4492f4b..7f90e01 100644 --- a/bun/tools/hover.ts +++ b/bun/tools/hover.ts @@ -8,6 +8,8 @@ import { getOrOpenBuffer } from "../utils/buffers.ts"; import type { NvimBuffer } from "../nvim/buffer.ts"; import type { Nvim } from "bunvim"; import type { Lsp } from "../lsp.ts"; +import { calculateStringPosition } from "../tea/util.ts"; +import type { PositionString, StringIdx } from "../nvim/window.ts"; export type Model = { type: "hover"; @@ -63,7 +65,6 @@ export function initModel( async (dispatch) => { const { lsp } = context; const filePath = model.request.input.filePath; - context.nvim.logger?.debug(`request: ${JSON.stringify(model.request)}`); const bufferResult = await getOrOpenBuffer({ relativePath: filePath, context, @@ -87,30 +88,10 @@ export function initModel( return; } - let searchText = bufferContent; - let startOffset = 0; - - // If context is provided, find it first - if (model.request.input.context) { - const contextIndex = bufferContent.indexOf(model.request.input.context); - if (contextIndex === -1) { - dispatch({ - type: "finish", - result: { - type: "tool_result", - tool_use_id: model.request.id, - content: `Context not found in file.`, - is_error: true, - }, - }); - return; - } - searchText = model.request.input.context; - startOffset = contextIndex; - } - - const symbolIndex = searchText.indexOf(model.request.input.symbol); - if (symbolIndex === -1) { + const symbolStart = bufferContent.indexOf( + model.request.input.symbol, + ) as StringIdx; + if (symbolStart === -1) { dispatch({ type: "finish", result: { @@ -123,17 +104,14 @@ export function initModel( return; } - const absoluteSymbolIndex = startOffset + symbolIndex; - const precedingText = bufferContent.substring(0, absoluteSymbolIndex); - const row = precedingText.split("\n").length - 1; - const lastNewline = precedingText.lastIndexOf("\n"); - const col = - lastNewline === -1 - ? absoluteSymbolIndex - : absoluteSymbolIndex - lastNewline - 1; + const symbolPos = calculateStringPosition( + { row: 0, col: 0 } as PositionString, + bufferContent, + (symbolStart + model.request.input.symbol.length - 1) as StringIdx, + ); try { - const result = await lsp.requestHover(buffer, row, col); + const result = await lsp.requestHover(buffer, symbolPos); let content = ""; for (const lspResult of result) { if (lspResult != null) { @@ -207,14 +185,9 @@ export const spec: Anthropic.Anthropic.Tool = { }, symbol: { type: "string", - description: - "The symbol to get hover information for. We will use the first occurrence of the symbol.", - }, - context: { - type: "string", - description: `Optionally, you can disambiguate which instance of the symbol you want to hover. \ -If context is provided, we will first find the first instance of context in the file, and then look for the symbol inside the context. \ -This should be the literal text of the file. Regular expressions are not allowed.`, + description: `The symbol to get hover information for. +We will use the first occurrence of the symbol. +We will use the right-most character of this string, so if the string is "a.b.c", we will hover c.`, }, }, required: ["filePath", "symbol"], @@ -227,7 +200,6 @@ export type HoverToolUseRequest = { input: { filePath: string; symbol: string; - context?: string; }; name: "hover"; }; @@ -269,13 +241,6 @@ export function validateToolRequest(req: unknown): Result { return { status: "error", error: "expected input.symbol to be a string" }; } - if (input.context && typeof input.context != "string") { - return { - status: "error", - error: "input.context must be a string if provided", - }; - } - return { status: "ok", value: req as HoverToolUseRequest, diff --git a/minimal-init.lua b/minimal-init.lua index 3c1ace7..0addc3d 100644 --- a/minimal-init.lua +++ b/minimal-init.lua @@ -1,2 +1,19 @@ vim.opt.runtimepath:append(".") require("magenta") + +vim.api.nvim_create_autocmd( + "FileType", + { + -- This handler will fire when the buffer's 'filetype' is "python" + pattern = "typescript", + callback = function(ev) + vim.lsp.start( + { + name = "ts_ls", + cmd = {"typescript-language-server", "--stdio"}, + root_dir = vim.fs.root(ev.buf, {"tsconfig.json", "package.json"}) + } + ) + end + } +)