diff --git a/lua/rzls/handlers/htmlformatting.lua b/lua/rzls/handlers/htmlformatting.lua new file mode 100644 index 0000000..8f08be7 --- /dev/null +++ b/lua/rzls/handlers/htmlformatting.lua @@ -0,0 +1,61 @@ +local documentstore = require("rzls.documentstore") +local format = require("rzls.utils.format") +local razor = require("rzls.razor") + +---@class rzls.htmlFormattingParams +---@field textDocument lsp.TextDocumentIdentifier +---@field _razor_hostDocumentVersion integer +---@field options lsp.FormattingOptions + +---@param err lsp.ResponseError +---@param result rzls.htmlFormattingParams +---@param _ctx lsp.HandlerContext +---@param _config table +return function(err, result, _ctx, _config) + if err then + vim.notify("Error in razor/htmlFormatting", vim.log.levels.ERROR) + return {}, nil + end + + local virtual_document = documentstore.get_virtual_document( + result.textDocument.uri, + result._razor_hostDocumentVersion, + razor.language_kinds.html + ) + assert(virtual_document, "Could not find html virtual document") + + local client = virtual_document:get_lsp_client() + if not client then + return {}, nil + end + + local lines = virtual_document:lines() + local line_count = #lines + local last_line = lines[line_count] + local range_formatting_response = client.request_sync("textDocument/rangeFormatting", { + textDocument = vim.lsp.util.make_text_document_params(virtual_document.buf), + range = { + start = { + line = 0, + character = 0, + }, + ["end"] = { + line = line_count - 1, + character = last_line:len(), + }, + }, + options = result.options, + }, nil, virtual_document.buf) + assert(range_formatting_response, "textDocument/rangeFormatting from virtual LSP return no error or result") + + if range_formatting_response.err ~= nil then + return nil, err + end + + local edits = {} + for _, html_edit in ipairs(range_formatting_response.result) do + vim.list_extend(edits, format.compute_minimal_edits(virtual_document:lines(), html_edit)) + end + + return { edits = edits } +end diff --git a/lua/rzls/handlers/init.lua b/lua/rzls/handlers/init.lua index b0c42cc..6123782 100644 --- a/lua/rzls/handlers/init.lua +++ b/lua/rzls/handlers/init.lua @@ -47,7 +47,7 @@ return { ["razor/provideSemanticTokensRange"] = require("rzls.handlers.providesemantictokensrange"), ["razor/foldingRange"] = not_implemented, - ["razor/htmlFormatting"] = not_implemented, + ["razor/htmlFormatting"] = require("rzls.handlers.htmlformatting"), ["razor/htmlOnTypeFormatting"] = not_implemented, ["razor/simplifyMethod"] = not_implemented, ["razor/formatNewFile"] = not_implemented, diff --git a/lua/rzls/utils/format.lua b/lua/rzls/utils/format.lua new file mode 100644 index 0000000..271efbf --- /dev/null +++ b/lua/rzls/utils/format.lua @@ -0,0 +1,91 @@ +local lcs = require("rzls.utils.lcs") +local M = {} + +---@param lines string[] +---@param range lsp.Range +local function extract_lines_from_range(lines, range) + local start_row = range.start.line + 1 + local start_col = range.start.character + 1 + local end_row = range["end"].line + 1 + local end_col = range["end"].character + 1 + + local source_lines = {} + -- Loop through the zero-indexed range [source_start_row, source_end_row) + for i = start_row, end_row do + local line = lines[i] + + if i == start_row then + line = line:sub(start_col, -1) + elseif i == end_row then + line = line:sub(1, end_col - 1) + end + + -- strip CR characters when neovim fails to identify the correct file format + if vim.endswith(line, "\r") then + table.insert(source_lines, line:sub(1, -2)) + else + table.insert(source_lines, line) + end + end + + return source_lines +end + +---@param source_buf string[] +---@param target_edit lsp.TextEdit +---@return lsp.TextEdit[] +function M.compute_minimal_edits(source_buf, target_edit) + local source_lines = extract_lines_from_range(source_buf, target_edit.range) + local target_lines = vim.split(target_edit.newText, "\r?\n") + + local source_text = table.concat(source_lines, "\n") + local target_text = table.concat(target_lines, "\n") + + local indices = vim.diff(source_text, target_text, { + algorithm = "histogram", + result_type = "indices", + }) + assert(type(indices) == "table") + + ---@type lsp.TextEdit[] + local edits = {} + + for _, idx in ipairs(indices) do + local source_line_start, source_line_count, target_line_start, target_line_count = unpack(idx) + local source_line_end = source_line_start + source_line_count - 1 + local target_line_end = target_line_start + target_line_count - 1 + + local source = table.concat(source_lines, "\n", source_line_start, source_line_end) + local target = table.concat(target_lines, "\n", target_line_start, target_line_end) + + local text_edits = lcs.to_lsp_edits( + lcs.diff(source, target), + source_line_start + target_edit.range.start.line - 1, + target_edit.range.start.character + ) + + vim.list_extend(edits, text_edits) + end + + local contains_non_whitespace_edit = vim.iter(edits):any(function(edit) + return edit.newText:find("%S") ~= nil + end) + + -- Diff the whole text if we encounter a non whitespace character in the edit. + -- This might happen when the formatted document deletes many lines + -- and `vim.diff` split those deletions into multiple hunks. + -- + -- This is rare but it might happen. + if contains_non_whitespace_edit then + vim.print("Performing slow formatting diff") + edits = lcs.to_lsp_edits( + lcs.diff(source_text, target_text), + target_edit.range.start.line, + target_edit.range.start.character + ) + end + + return edits +end + +return M diff --git a/lua/rzls/utils/lcs.lua b/lua/rzls/utils/lcs.lua new file mode 100644 index 0000000..69df6aa --- /dev/null +++ b/lua/rzls/utils/lcs.lua @@ -0,0 +1,173 @@ +local M = {} + +---@class rzls.lcs.Edit +---@field kind rzls.lcs.EditKind +---@field text string + +---@enum rzls.lcs.EditKind +M.edit_kind = { + addition = "addition", + removal = "removal", + unchanged = "unchanged", +} + +--- Computes the Long Common Subequence table. +--- Reference: [https://en.wikipedia.org/wiki/Longest_common_subsequence#Computing_the_length_of_the_LCS] +---@param source string +---@param target string +function M.generate_table(source, target) + local n = source:len() + 1 + local m = target:len() + 1 + + ---@type integer[][] + local lcs = {} + for i = 1, n do + lcs[i] = {} + for j = 1, m do + lcs[i][j] = 0 + end + end + + for i = 1, n do + for j = 1, m do + if i == 1 or j == 1 then + lcs[i][j] = 0 + elseif source:sub(i - 1, i - 1) == target:sub(j - 1, j - 1) then + lcs[i][j] = 1 + lcs[i - 1][j - 1] + else + lcs[i][j] = math.max(lcs[i - 1][j], lcs[i][j - 1]) + end + end + end + + return lcs +end + +---@generic T +---@param tbl T[] +---@return T[] +local function reverse_table(tbl) + local ret = {} + for i = #tbl, 1, -1 do + table.insert(ret, tbl[i]) + end + return ret +end + +--- Calculates a diff between two strings using LCS +---@param source string +---@param target string +---@return rzls.lcs.Edit[] +function M.diff(source, target) + local lcs = M.generate_table(source, target) + + local src_idx = source:len() + 1 + local trt_idx = target:len() + 1 + + ---@type rzls.lcs.Edit[] + local edits = {} + + while src_idx ~= 1 or trt_idx ~= 1 do + if src_idx == 1 then + table.insert(edits, { + kind = M.edit_kind.addition, + text = target:sub(trt_idx - 1, trt_idx - 1), + }) + trt_idx = trt_idx - 1 + elseif trt_idx == 1 then + table.insert(edits, { + kind = M.edit_kind.removal, + text = source:sub(src_idx - 1, src_idx - 1), + }) + src_idx = src_idx - 1 + elseif source:sub(src_idx - 1, src_idx - 1) == target:sub(trt_idx - 1, trt_idx - 1) then + table.insert(edits, { + kind = M.edit_kind.unchanged, + text = source:sub(src_idx - 1, src_idx - 1), + }) + src_idx = src_idx - 1 + trt_idx = trt_idx - 1 + elseif lcs[src_idx - 1][trt_idx] <= lcs[src_idx][trt_idx - 1] then + table.insert(edits, { + kind = M.edit_kind.addition, + text = target:sub(trt_idx - 1, trt_idx - 1), + }) + trt_idx = trt_idx - 1 + else + table.insert(edits, { + kind = M.edit_kind.removal, + text = source:sub(src_idx - 1, src_idx - 1), + }) + src_idx = src_idx - 1 + end + end + + return reverse_table(edits) +end + +---@param edits rzls.lcs.Edit[] +---@param line integer +---@param character integer +---@return lsp.TextEdit[] +function M.to_lsp_edits(edits, line, character) + local function advance_cursor(edit) + if edit.text == "\n" then + line = line + 1 + character = 0 + else + character = character + 1 + end + end + + ---@type lsp.TextEdit[] + local lsp_edits = {} + local i = 1 + while i < #edits do + -- Skip all unchanged edits and advance cursor + while i < #edits and edits[i].kind == M.edit_kind.unchanged do + advance_cursor(edits[i]) + i = i + 1 + end + + -- No more edits to compute + if i >= #edits then + break + end + + local new_text = "" + local start_line, start_character = line, character + + -- Collect consecutive additions and removals + while i < #edits and edits[i].kind ~= M.edit_kind.unchanged do + if edits[i].kind == M.edit_kind.addition then + new_text = new_text .. edits[i].text + elseif edits[i].kind == M.edit_kind.removal then + advance_cursor(edits[i]) + else + error("unexcepted edit kind " .. edits[i].kind) + end + i = i + 1 + end + + ---@type lsp.TextEdit + local lsp_edit = { + newText = new_text, + range = { + start = { + line = start_line, + character = start_character, + }, + ["end"] = { + line = line, + character = character, + }, + }, + } + + table.insert(lsp_edits, lsp_edit) + end + + return lsp_edits +end + +return M diff --git a/lua/rzls/virtual_document.lua b/lua/rzls/virtual_document.lua index 3dd3e21..847f608 100644 --- a/lua/rzls/virtual_document.lua +++ b/lua/rzls/virtual_document.lua @@ -59,4 +59,18 @@ function VirtualDocument:get_lsp_client() return vim.lsp.get_clients({ bufnr = self.buf, name = razor.lsp_names[self.kind] })[1] end +function VirtualDocument:line_count() + local lines = vim.split(self.content, "\r?\n", { trimempty = false }) + return #lines +end + +function VirtualDocument:lines() + return vim.split(self.content, "\r?\n", { trimempty = false }) +end + +function VirtualDocument:line_at(line) + local lines = vim.split(self.content, "\r?\n", { trimempty = false }) + return lines[line] +end + return VirtualDocument diff --git a/tests/rzls/utils/format_spec.lua b/tests/rzls/utils/format_spec.lua new file mode 100644 index 0000000..7e1d051 --- /dev/null +++ b/tests/rzls/utils/format_spec.lua @@ -0,0 +1,52 @@ +local format = require("rzls.utils.format") +---@diagnostic disable-next-line: undefined-field +local eq = assert.are.same + +---@return lsp.TextEdit +local function lsp_edit(new_text, start_line, start_char, end_line, end_char) + return { + newText = new_text, + range = { + start = { + line = start_line, + character = start_char, + }, + ["end"] = { + line = end_line, + character = end_char, + }, + }, + } +end + +describe("format", function() + it("computes minimal text edits for a buffer", function() + local source_text = [[ +foo + bar + baz +]] + local source_lines = vim.split(source_text, "\n") + + local targe_text = [[ +foo +bar +baz +]] + local target_lines = vim.split(targe_text, "\n") + + -- This edit replaces the full document + local full_replacement_edit = + lsp_edit(table.concat(target_lines, "\n"), 0, 0, #source_lines - 1, source_lines[#source_lines]:len()) + + local minimal_edits = format.compute_minimal_edits(source_lines, full_replacement_edit) + + -- Only contain edits that remove spaces of the source document to match the target + local expected = { + lsp_edit("", 1, 0, 1, 4), + lsp_edit("", 2, 0, 2, 8), + } + + eq(expected, minimal_edits) + end) +end) diff --git a/tests/rzls/utils/lcs_spec.lua b/tests/rzls/utils/lcs_spec.lua new file mode 100644 index 0000000..5e6293b --- /dev/null +++ b/tests/rzls/utils/lcs_spec.lua @@ -0,0 +1,129 @@ +local lcs = require("rzls.utils.lcs") +local kind = lcs.edit_kind +---@diagnostic disable-next-line: undefined-field +local eq = assert.are.same + +describe("lcs", function() + it("calculates diff for saturday -> sunday", function() + local edits = lcs.diff("sunday", "saturday") + + ---@type rzls.lcs.Edit[] + local expected = { + { text = "s", kind = kind.unchanged }, + { text = "a", kind = kind.addition }, + { text = "t", kind = kind.addition }, + { text = "u", kind = kind.unchanged }, + { text = "n", kind = kind.removal }, + { text = "r", kind = kind.addition }, + { text = "d", kind = kind.unchanged }, + { text = "a", kind = kind.unchanged }, + { text = "y", kind = kind.unchanged }, + } + eq(expected, edits) + end) + + ---@return lsp.TextEdit + local function lsp_edit(new_text, start_line, start_char, end_line, end_char) + return { + newText = new_text, + range = { + start = { + line = start_line, + character = start_char, + }, + ["end"] = { + line = end_line, + character = end_char, + }, + }, + } + end + + it("converts edits to lsp.TextEdit's", function() + local source = '' + local target = '
' + + local edits = lcs.diff(source, target) + local text_edits = lcs.to_lsp_edits(edits, 0, 0) + + local expected = { + -- Replaces "\n\n" with " " + lsp_edit(" ", 0, 4, 2, 0), + -- Replaces "foo" with "bar" + lsp_edit("bar", 2, 7, 2, 10), + } + + eq(expected, text_edits) + end) + + it("applies converted lsp.TextEdit's to buffer", function() + local source = '
' + local target = '' + + local edits = lcs.diff(source, target) + local text_edits = lcs.to_lsp_edits(edits, 0, 0) + + local buf = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(buf, 0, -1, true, vim.split(source, "\n")) + vim.lsp.util.apply_text_edits(text_edits, buf, "utf-8") + + local lines = vim.api.nvim_buf_get_lines(buf, 0, -1, true) + eq(target, table.concat(lines, "\n")) + end) + + it("applies converted lsp.TextEdit's to buffer with CRLF line endings", function() + local source = '
' + local target = '' + + local edits = lcs.diff(source, target) + local text_edits = lcs.to_lsp_edits(edits, 0, 0) + + local buf = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(buf, 0, -1, true, vim.split(source, "\r\n")) + vim.lsp.util.apply_text_edits(text_edits, buf, "utf-8") + + local lines = vim.api.nvim_buf_get_lines(buf, 0, -1, true) + eq(target, table.concat(lines, "\r\n")) + end) + + it("applies converted edits to buffer with multiple errors", function() + local source = [[ +
+

Intentional Leading Space

+ + +
+]] + local target = [[ +
+

Intentional Leading Space

+
+]] + + local edits = lcs.diff(source, target) + local text_edits = lcs.to_lsp_edits(edits, 0, 0) + + local buf = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(buf, 0, -1, true, vim.split(source, "\n")) + vim.lsp.util.apply_text_edits(text_edits, buf, "utf-8") + + local lines = vim.api.nvim_buf_get_lines(buf, 0, -1, true) + eq(target, table.concat(lines, "\n")) + end) + + it("applies edits to unicode characters", function() + local source = "

💩

" + local target = "

:💩

" + + local edits = lcs.diff(source, target) + local text_edits = lcs.to_lsp_edits(edits, 0, 0) + + local buf = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(buf, 0, -1, true, vim.split(source, "\n")) + vim.lsp.util.apply_text_edits(text_edits, buf, "utf-16") + + local lines = vim.api.nvim_buf_get_lines(buf, 0, -1, true) + eq(target, table.concat(lines, "\n")) + end) +end)