Skip to content

Commit

Permalink
improve findReferences and hover tools
Browse files Browse the repository at this point in the history
  • Loading branch information
dlants committed Dec 30, 2024
1 parent df66bbf commit d477794
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 115 deletions.
15 changes: 7 additions & 8 deletions bun/lsp.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { type Nvim } from "bunvim";
import type { NvimBuffer } from "./nvim/buffer.ts";
import type { PositionString } from "./nvim/window.ts";

export class Lsp {
private requestCounter = 0;
Expand All @@ -21,8 +22,7 @@ export class Lsp {

requestHover(
buffer: NvimBuffer,
row: number,
col: number,
pos: PositionString,
): Promise<LspHoverResponse> {
return new Promise<LspHoverResponse>((resolve, reject) => {
const requestId = this.getRequestId();
Expand All @@ -37,8 +37,8 @@ export class Lsp {
uri = vim.uri_from_bufnr(${buffer.id})
},
position = {
line = ${row},
character = ${col}
line = ${pos.row},
character = ${pos.col}
}
}, function(responses)
require('magenta').lsp_response("${requestId}", responses)
Expand All @@ -54,8 +54,7 @@ export class Lsp {

requestReferences(
buffer: NvimBuffer,
row: number,
col: number,
pos: PositionString,
): Promise<LspReferencesResponse> {
return new Promise((resolve, reject) => {
const requestId = this.getRequestId();
Expand All @@ -70,8 +69,8 @@ export class Lsp {
uri = vim.uri_from_bufnr(${buffer.id})
},
position = {
line = ${row},
character = ${col}
line = ${pos.row},
character = ${pos.col}
},
context = {
includeDeclaration = true
Expand Down
10 changes: 10 additions & 0 deletions bun/nvim/window.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ import { NvimBuffer, type BufNr } from "./buffer.ts";
export type Row0Indexed = number & { __row0Indexed: true };
export type Row1Indexed = number & { __row1Indexed: true };
export type ByteIdx = number & { __byteIdx: true };

/** A coordinate in a js string, which are utf-16 encoded by default. This is the coordinate that lsp clients typically expect.
*/
export type StringIdx = number & { __charIdx: true };

export type PositionString = {
row: Row0Indexed;
col: StringIdx;
};

export type Position1Indexed = {
row: Row1Indexed;
col: ByteIdx;
Expand Down
33 changes: 30 additions & 3 deletions bun/tea/util.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import type { Nvim } from "bunvim";
import type { Line, NvimBuffer } from "../nvim/buffer.ts";
import type { ByteIdx, Position0Indexed } from "../nvim/window.ts";
import type {
PositionString,
ByteIdx,
Position0Indexed,
StringIdx,
} from "../nvim/window.ts";

export async function replaceBetweenPositions({
buffer,
Expand Down Expand Up @@ -37,10 +42,10 @@ export async function replaceBetweenPositions({
export function calculatePosition(
startPos: Position0Indexed,
buf: Buffer,
indexInText: number,
indexInText: ByteIdx,
): Position0Indexed {
let { row, col } = startPos;
let currentIndex = 0;
let currentIndex: ByteIdx = 0 as ByteIdx;

while (currentIndex < indexInText) {
// 10 == '\n' in hex
Expand All @@ -56,6 +61,28 @@ export function calculatePosition(
return { row, col };
}

export function calculateStringPosition(
startPos: PositionString,
content: string,
indexInText: StringIdx,
): PositionString {
let { row, col } = startPos;
let currentIndex = 0 as StringIdx;

while (currentIndex < indexInText) {
// 10 == '\n' in hex
if (content[currentIndex] == "\n") {
row++;
col = 0 as StringIdx;
} else {
col++;
}
currentIndex++;
}

return { row, col };
}

export async function logBuffer(buffer: NvimBuffer, context: { nvim: Nvim }) {
const lines = await buffer.getLines({
start: 0,
Expand Down
17 changes: 17 additions & 0 deletions bun/test/fixtures/test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
type Nested = {
a: {
b: {
c: "test";
};
};
};

const val: Nested = {
a: {
b: {
c: "test",
},
},
};

console.log(val.a.b.c);
66 changes: 66 additions & 0 deletions bun/tools/findReferences.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import { type ToolRequestId } from "./toolManager.ts";
import { describe, it, expect } from "bun:test";
import { withDriver } from "../test/preamble";
import { pollUntil } from "../utils/async.ts";

describe("bun/tools/findReferences.spec.ts", () => {
it.only("findReferences end-to-end", async () => {
await withDriver(async (driver) => {
await driver.editFile("bun/test/fixtures/test.ts");
await driver.showSidebar();

await driver.inputMagentaText(`Try finding references for a symbol`);
await driver.send();

const toolRequestId = "id" as ToolRequestId;
await driver.mockAnthropic.respond({
stopReason: "tool_use",
text: "ok, here goes",
toolRequests: [
{
status: "ok",
value: {
type: "tool_use",
id: toolRequestId,
name: "find_references",
input: {
filePath: "bun/test/fixtures/test.ts",
symbol: "val.a.b.c",
},
},
},
],
});

const result = await pollUntil(
() => {
const state = driver.magenta.chatApp.getState();
if (state.status != "running") {
throw new Error(`app crashed`);
}

const toolWrapper =
state.model.toolManager.toolWrappers[toolRequestId];
if (!toolWrapper) {
throw new Error(
`could not find toolWrapper with id ${toolRequestId}`,
);
}

if (toolWrapper.model.state.state != "done") {
throw new Error(`Request not done`);

Check failure on line 51 in bun/tools/findReferences.spec.ts

View workflow job for this annotation

GitHub Actions / test

error: Request not done

at /home/runner/work/magenta.nvim/magenta.nvim/bun/tools/findReferences.spec.ts:51:19 at /home/runner/work/magenta.nvim/magenta.nvim/bun/utils/async.ts:60:19
}

return toolWrapper.model.state.result;
},
{ timeout: 3000 },
);

expect(result).toEqual({
tool_use_id: toolRequestId,
type: "tool_result",
content: `bun/test/fixtures/test.ts:4:6\nbun/test/fixtures/test.ts:12:6\nbun/test/fixtures/test.ts:17:20\n`,
});
});
});
});
77 changes: 23 additions & 54 deletions bun/tools/findReferences.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ import { assertUnreachable } from "../utils/assertUnreachable.ts";
import { getOrOpenBuffer } from "../utils/buffers.ts";
import type { NvimBuffer } from "../nvim/buffer.ts";
import type { Nvim } from "bunvim";
import type { Lsp } from "../lsp.ts";
import path from "path";
import { getcwd } from "../nvim/nvim.ts";
import type { Lsp } from "../lsp.ts";import { getcwd } from "../nvim/nvim.ts";
import path from "path";
import { calculateStringPosition } from "../tea/util.ts";
import type { PositionString, StringIdx } from "../nvim/window.ts";

export type Model = {
type: "find_references";
autoRespond: boolean;
request: ReferencesToolUseRequest;
state:
| {
Expand Down Expand Up @@ -51,7 +55,6 @@ export function initModel(
): [Model, Thunk<Msg>] {
const model: Model = {
type: "find_references",
autoRespond: true,
request,
state: {
state: "processing",
Expand All @@ -62,7 +65,6 @@ export function initModel(
async (dispatch) => {
const { lsp, nvim } = context;
const filePath = model.request.input.filePath;
context.nvim.logger?.debug(`request: ${JSON.stringify(model.request)}`);
const bufferResult = await getOrOpenBuffer({
relativePath: filePath,
context: { nvim },
Expand All @@ -85,31 +87,11 @@ export function initModel(
});
return;
}
const symbolStart = bufferContent.indexOf(
model.request.input.symbol,
) as StringIdx;

let searchText = bufferContent;
let startOffset = 0;

// If context is provided, find it first
if (model.request.input.context) {
const contextIndex = bufferContent.indexOf(model.request.input.context);
if (contextIndex === -1) {
dispatch({
type: "finish",
result: {
type: "tool_result",
tool_use_id: model.request.id,
content: `Context not found in file.`,
is_error: true,
},
});
return;
}
searchText = model.request.input.context;
startOffset = contextIndex;
}

const symbolIndex = searchText.indexOf(model.request.input.symbol);
if (symbolIndex === -1) {
if (symbolStart === -1) {
dispatch({
type: "finish",
result: {
Expand All @@ -122,22 +104,22 @@ export function initModel(
return;
}

const absoluteSymbolIndex = startOffset + symbolIndex;
const precedingText = bufferContent.substring(0, absoluteSymbolIndex);
const row = precedingText.split("\n").length - 1;
const lastNewline = precedingText.lastIndexOf("\n");
const col =
lastNewline === -1
? absoluteSymbolIndex
: absoluteSymbolIndex - lastNewline - 1;
const symbolPos = calculateStringPosition(
{ row: 0, col: 0 } as PositionString,
bufferContent,
(symbolStart + model.request.input.symbol.length - 1) as StringIdx,
);

try {
const result = await lsp.requestReferences(buffer, row, col);
const cwd = await getcwd(nvim);
const result = await lsp.requestReferences(buffer, symbolPos);
let content = "";
for (const lspResult of result) {
if (lspResult != null && lspResult.result) {
for (const ref of lspResult.result) {
content += `${ref.uri}:${ref.range.start.line + 1}:${ref.range.start.character}\n`;
const uri = ref.uri.startsWith('file://') ? ref.uri.slice(7) : ref.uri;
const relativePath = path.relative(cwd, uri);
content += `${relativePath}:${ref.range.start.line + 1}:${ref.range.start.character}\n`;
}
}
}
Expand Down Expand Up @@ -204,14 +186,9 @@ export const spec: Anthropic.Anthropic.Tool = {
},
symbol: {
type: "string",
description:
"The symbol to find references for. We will use the first occurrence of the symbol.",
},
context: {
type: "string",
description: `Optionally, you can disambiguate which instance of the symbol you want to find references for. \
If context is provided, we will first find the first instance of context in the file, and then look for the symbol inside the context. \
This should be the literal text of the file. Regular expressions are not allowed.`,
description: `The symbol to find references for.
We will use the first occurrence of the symbol.
We will use the right-most character of this string, so if the string is "a.b.c", we will find references for c.`,
},
},
required: ["filePath", "symbol"],
Expand All @@ -224,7 +201,6 @@ export type ReferencesToolUseRequest = {
input: {
filePath: string;
symbol: string;
context?: string;
};
name: "find_references";
};
Expand Down Expand Up @@ -271,13 +247,6 @@ export function validateToolRequest(
return { status: "error", error: "expected input.symbol to be a string" };
}

if (input.context && typeof input.context != "string") {
return {
status: "error",
error: "input.context must be a string if provided",
};
}

return {
status: "ok",
value: req as ReferencesToolUseRequest,
Expand Down
Loading

0 comments on commit d477794

Please sign in to comment.