diff --git a/lua/dial/augend/common.lua b/lua/dial/augend/common.lua index a499c32..37d5fd8 100644 --- a/lua/dial/augend/common.lua +++ b/lua/dial/augend/common.lua @@ -6,12 +6,14 @@ local M = {} ---augend の find field を簡単に実装する。 ---@param ptn string +---@param allow_match_before_cursor? boolean ---@return findf -function M.find_pattern(ptn) +function M.find_pattern(ptn, allow_match_before_cursor) ---@param line string ---@param cursor? integer ---@return textrange? local function f(line, cursor) + local match_before_cursor = nil local idx_start = 1 while idx_start <= #line do local s, e = line:find(ptn, idx_start) @@ -21,6 +23,7 @@ function M.find_pattern(ptn) -- cursor が終了文字より後ろにあったら終了 return { from = s, to = e } else + match_before_cursor = { from = s, to = e } -- 終了文字の後ろから探し始める idx_start = e + 1 end @@ -29,6 +32,9 @@ function M.find_pattern(ptn) break end end + if allow_match_before_cursor then + return match_before_cursor + end return nil end return f @@ -36,12 +42,14 @@ end -- augend の find field を簡単に実装する。 ---@param ptn string +---@param allow_match_before_cursor? boolean ---@return findf -function M.find_pattern_regex(ptn) +function M.find_pattern_regex(ptn, allow_match_before_cursor) ---@param line string ---@param cursor? integer ---@return textrange? local function f(line, cursor) + local match_before_cursor = nil local idx_start = 1 while idx_start <= #line do local s, e = vim.regex(ptn):match_str(line:sub(idx_start)) @@ -55,6 +63,7 @@ function M.find_pattern_regex(ptn) -- cursor が終了文字より後ろにあったら終了 return { from = s, to = e } else + match_before_cursor = { from = s, to = e } -- 終了文字の後ろから探し始める idx_start = e + 1 end @@ -63,6 +72,9 @@ function M.find_pattern_regex(ptn) break end end + if allow_match_before_cursor then + return match_before_cursor + end return nil end return f diff --git a/lua/dial/augend/constant.lua b/lua/dial/augend/constant.lua index 5b181ed..1aec1ec 100644 --- a/lua/dial/augend/constant.lua +++ b/lua/dial/augend/constant.lua @@ -1,7 +1,7 @@ local util = require "dial.util" local common = require "dial.augend.common" ----@alias AugendConstantConfig { elements: string[], cyclic: boolean, pattern_regexp: string, preserve_case: boolean } +---@alias AugendConstantConfig { elements: string[], cyclic: boolean, pattern_regexp: string, preserve_case: boolean, match_before_cursor: boolean } ---@class AugendConstant ---@implement Augend @@ -33,7 +33,7 @@ local function preserve_case(word) return nil end ----@param config { elements: string[], word?: boolean, cyclic?: boolean, pattern_regexp?: string, preserve_case?: boolean } +---@param config { elements: string[], word?: boolean, cyclic?: boolean, pattern_regexp?: string, preserve_case?: boolean, match_before_cursor?: boolean } ---@return Augend function M.new(config) util.validate_list("config.elements", config.elements, "string") @@ -43,10 +43,14 @@ function M.new(config) cyclic = { config.cyclic, "boolean", true }, pattern_regexp = { config.pattern_regexp, "string", true }, preserve_case = { config.preserve_case, "boolean", true }, + match_before_cursor = { config.match_before_cursor, "boolean", true }, } if config.preserve_case == nil then config.preserve_case = false end + if config.match_before_cursor == nil then + config.match_before_cursor = false + end if config.pattern_regexp == nil then local case_sensitive_flag = util.if_expr(config.preserve_case, [[\c]], [[\C]]) local word = util.unwrap_or(config.word, true) @@ -70,7 +74,7 @@ function AugendConstant:find(line, cursor) return vim.fn.escape(e, [[/\]]) end, self.config.elements) local vim_regex_ptn = self.config.pattern_regexp:format(table.concat(escaped_elements, [[\|]])) - return common.find_pattern_regex(vim_regex_ptn)(line, cursor) + return common.find_pattern_regex(vim_regex_ptn, self.config.match_before_cursor)(line, cursor) end ---@param text string diff --git a/tests/dial/augend/constant_spec.lua b/tests/dial/augend/constant_spec.lua index 6850513..c49fce6 100644 --- a/tests/dial/augend/constant_spec.lua +++ b/tests/dial/augend/constant_spec.lua @@ -2,4 +2,25 @@ local constant = require("dial.augend").constant describe("Test of constant between two words", function() local augend = constant.new { elements = { "true", "false" } } + + describe("find function", function() + it("can find a completely matching word", function() + assert.are.same(augend:find("enable = true", 1), { from = 10, to = 13 }) + assert.are.same(augend:find("enable = false", 1), { from = 10, to = 14 }) + end) + it("does not find a word including element words", function() + assert.are.same(augend:find("mistakenly construed", 1), nil) + end) + it("does not find a word before the cursor when match_before_cursor = false", function() + assert.are.same(augend:find("true negative", 5), nil) + end) + end) + + augend = constant.new { elements = { "true", "false" }, match_before_cursor = true } + + describe("find function", function() + it("does find a word before the cursor when match_before_cursor = true", function() + assert.are.same(augend:find("true positive", 5), { from = 1, to = 4 }) + end) + end) end)