fix: use tree-sitter-markdown to extract code snippets (#1315)

Co-authored-by: yetone <yetoneful@gmail.com>
This commit is contained in:
zc he
2025-02-21 17:17:12 +08:00
committed by GitHub
parent fbda027cdc
commit caa8342508
2 changed files with 126 additions and 119 deletions

View File

@@ -548,63 +548,78 @@ end
---@field start_line_in_response_buf integer
---@field end_line_in_response_buf integer
---@field filepath string
---
---@param source string|integer
---@return TSNode[]
local function tree_sitter_markdown_parse_code_blocks(source)
local query = require("vim.treesitter.query")
local parser
if type(source) == "string" then
parser = vim.treesitter.get_string_parser(source, "markdown")
else
parser = vim.treesitter.get_parser(source, "markdown")
end
local tree = parser:parse()[1]
local root = tree:root()
local code_block_query = query.parse(
"markdown",
[[ (fenced_code_block
(info_string
(language) @language)?
(code_fence_content) @code) ]]
)
local nodes = {}
for _, node in code_block_query:iter_captures(root, source) do
table.insert(nodes, node)
end
return nodes
end
---@param response_content string
---@return table<string, AvanteCodeSnippet[]>
local function extract_cursor_planning_code_snippets_map(response_content, current_filepath, current_filetype)
local snippets = {}
local current_snippet = {}
local in_code_block = false
local lang, filepath, start_line_in_response_buf
local lines = vim.split(response_content, "\n")
local cumulated_content = ""
local idx = 1
local line_count = #lines
while idx <= line_count do
local line = lines[idx]
if line:match("^%s*```") then
if in_code_block then
in_code_block = false
if filepath == nil or filepath == "" then
if lang == current_filetype then
filepath = current_filepath
else
Utils.warn(
string.format(
"Failed to parse filepath from code block, and current_filetype `%s` is not the same as the filetype `%s` of the current code block, so ignore this code block",
current_filetype,
lang
)
-- use tree-sitter-markdown to parse all code blocks in response_content
local lang = "unknown"
for _, node in ipairs(tree_sitter_markdown_parse_code_blocks(response_content)) do
if node:type() == "language" then
lang = vim.treesitter.get_node_text(node, response_content)
lang = vim.split(lang, ":")[1]
elseif node:type() == "code_fence_content" then
local start_line, _ = node:start()
local end_line, _ = node:end_()
local filepath, skip_next_line = obtain_filepath_from_codeblock(lines, start_line)
if filepath == nil or filepath == "" then
if lang == current_filetype then
filepath = current_filepath
else
Utils.warn(
string.format(
"Failed to parse filepath from code block, and current_filetype `%s` is not the same as the filetype `%s` of the current code block, so ignore this code block",
current_filetype,
lang
)
goto continue
end
end
table.insert(snippets, {
range = { 0, 0 },
content = table.concat(current_snippet, "\n"),
lang = lang,
filepath = filepath,
start_line_in_response_buf = start_line_in_response_buf,
end_line_in_response_buf = idx,
})
else
in_code_block = true
start_line_in_response_buf = idx
local lang_ = line:match("^%s*```(%w+)")
lang = lang_ or "unknown"
local filepath_, skip_next_line = obtain_filepath_from_codeblock(lines, idx)
if filepath_ then
filepath = filepath_
if skip_next_line then idx = idx + 1 end
)
lang = "unknown"
goto continue
end
end
elseif in_code_block then
table.insert(current_snippet, line)
if skip_next_line then start_line = start_line + 1 end
local this_content = table.concat(vim.list_slice(lines, start_line + 1, end_line), "\n")
cumulated_content = cumulated_content .. "\n" .. this_content
table.insert(snippets, {
range = { 0, 0 },
content = cumulated_content,
lang = lang,
filepath = filepath,
start_line_in_response_buf = start_line,
end_line_in_response_buf = end_line + 1,
})
end
::continue::
idx = idx + 1
end
local snippets_map = {}
@@ -620,62 +635,61 @@ end
---@return table<string, AvanteCodeSnippet[]>
local function extract_code_snippets_map(response_content)
local snippets = {}
local current_snippet = {}
local in_code_block = false
local lang, start_line, end_line, start_line_in_response_buf
local explanation = ""
local lines = vim.split(response_content, "\n")
for idx, line in ipairs(lines) do
local _, start_line_str, end_line_str =
line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)")
if start_line_str ~= nil and end_line_str ~= nil then
start_line = tonumber(start_line_str)
end_line = tonumber(end_line_str)
else
_, start_line_str = line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ine:?%s*(%d+)")
if start_line_str ~= nil then
-- use tree-sitter-markdown to parse all code blocks in response_content
local lang = "text"
local explanation_start_line = 0
for _, node in ipairs(tree_sitter_markdown_parse_code_blocks(response_content)) do
local start_line_in_response_buf, _ = node:start()
local end_line_in_response_buf, _ = node:end_()
if node:type() == "language" then
lang = vim.treesitter.get_node_text(node, response_content)
elseif node:type() == "code_fence_content" and start_line_in_response_buf > 1 then
local number_line = lines[start_line_in_response_buf - 1]
local start_line, end_line
local _, start_line_str, end_line_str =
number_line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)")
if start_line_str ~= nil and end_line_str ~= nil then
start_line = tonumber(start_line_str)
end_line = tonumber(start_line_str)
end_line = tonumber(end_line_str)
else
start_line_str = line:match("[Aa]fter%s+[Ll]ine:?%s*(%d+)")
_, start_line_str = number_line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ine:?%s*(%d+)")
if start_line_str ~= nil then
start_line = tonumber(start_line_str) + 1
end_line = tonumber(start_line_str) + 1
start_line = tonumber(start_line_str)
end_line = tonumber(start_line_str)
else
start_line_str = number_line:match("[Aa]fter%s+[Ll]ine:?%s*(%d+)")
if start_line_str ~= nil then
start_line = tonumber(start_line_str) + 1
end_line = tonumber(start_line_str) + 1
end
end
end
end
if line:match("^%s*```") then
if in_code_block then
if start_line ~= nil and end_line ~= nil then
local filepath = lines[start_line_in_response_buf - 2]
if filepath:match("^[Ff]ilepath:") then filepath = filepath:match("^[Ff]ilepath:%s*(.+)") end
local snippet = {
range = { start_line, end_line },
content = table.concat(current_snippet, "\n"),
lang = lang,
explanation = explanation,
start_line_in_response_buf = start_line_in_response_buf,
end_line_in_response_buf = idx,
filepath = filepath,
}
table.insert(snippets, snippet)
if start_line ~= nil and end_line ~= nil then
local filepath = lines[start_line_in_response_buf - 2]
if filepath:match("^[Ff]ilepath:") then filepath = filepath:match("^[Ff]ilepath:%s*(.+)") end
local content = vim.treesitter.get_node_text(node, response_content)
local explanation = ""
if start_line_in_response_buf > explanation_start_line + 2 then
explanation =
table.concat(vim.list_slice(lines, explanation_start_line, start_line_in_response_buf - 3), "\n")
end
current_snippet = {}
start_line, end_line = nil, nil
explanation = ""
in_code_block = false
else
lang = line:match("^%s*```(%w+)")
if not lang or lang == "" then lang = "text" end
in_code_block = true
start_line_in_response_buf = idx
local snippet = {
range = { start_line, end_line },
content = content,
lang = lang,
explanation = explanation,
start_line_in_response_buf = start_line_in_response_buf,
end_line_in_response_buf = end_line_in_response_buf + 1,
filepath = filepath,
}
table.insert(snippets, snippet)
end
elseif in_code_block then
table.insert(current_snippet, line)
else
explanation = explanation .. line .. "\n"
lang = "text"
explanation_start_line = end_line_in_response_buf + 2
end
end
@@ -843,35 +857,24 @@ end
---@return AvanteCodeblock[]
local function parse_codeblocks(buf, current_filepath, current_filetype)
local codeblocks = {}
local in_codeblock = false
local start_line = nil
local lang = nil
local lines = Utils.get_buf_lines(0, -1, buf)
for i, line in ipairs(lines) do
if line:match("^%s*```") then
-- parse language
local lang_ = line:match("^%s*```(%w+)")
if in_codeblock and not lang_ then
table.insert(codeblocks, { start_line = start_line, end_line = i, lang = lang })
in_codeblock = false
elseif lang_ then
if Config.behaviour.enable_cursor_planning_mode then
local filepath = obtain_filepath_from_codeblock(lines, i)
if not filepath and lang_ == current_filetype then filepath = current_filepath end
if filepath then
lang = lang_
start_line = i
in_codeblock = true
end
else
if lines[i - 1]:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)") then
lang = lang_
start_line = i
in_codeblock = true
end
local lang, valid
for _, node in ipairs(tree_sitter_markdown_parse_code_blocks(buf)) do
if node:type() == "language" then
lang = vim.treesitter.get_node_text(node, buf)
elseif node:type() == "code_fence_content" then
local start_line, _ = node:start()
local end_line, _ = node:end_()
if Config.behaviour.enable_cursor_planning_mode then
local filepath = obtain_filepath_from_codeblock(lines, start_line)
if not filepath and lang == current_filetype then filepath = current_filepath end
if filepath then valid = true end
else
if lines[start_line - 1]:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)") then
valid = true
end
end
if valid then table.insert(codeblocks, { start_line = start_line, end_line = end_line + 1, lang = lang }) end
end
end