fix: use tree-sitter-markdown to extract code snippets (#1315)
Co-authored-by: yetone <yetoneful@gmail.com>
This commit is contained in:
@@ -73,6 +73,7 @@ For building binary if you wish to build from source, then `cargo` is required.
|
||||
build = "make",
|
||||
-- build = "powershell -ExecutionPolicy Bypass -File Build.ps1 -BuildFromSource false" -- for windows
|
||||
dependencies = {
|
||||
"nvim-treesitter/nvim-treesitter",
|
||||
"stevearc/dressing.nvim",
|
||||
"nvim-lua/plenary.nvim",
|
||||
"MunifTanjim/nui.nvim",
|
||||
@@ -121,6 +122,7 @@ For building binary if you wish to build from source, then `cargo` is required.
|
||||
```vim
|
||||
|
||||
" Deps
|
||||
Plug 'nvim-treesitter/nvim-treesitter'
|
||||
Plug 'stevearc/dressing.nvim'
|
||||
Plug 'nvim-lua/plenary.nvim'
|
||||
Plug 'MunifTanjim/nui.nvim'
|
||||
@@ -153,6 +155,7 @@ add({
|
||||
source = 'yetone/avante.nvim',
|
||||
monitor = 'main',
|
||||
depends = {
|
||||
'nvim-treesitter/nvim-treesitter',
|
||||
'stevearc/dressing.nvim',
|
||||
'nvim-lua/plenary.nvim',
|
||||
'MunifTanjim/nui.nvim',
|
||||
@@ -184,6 +187,7 @@ end)
|
||||
```lua
|
||||
|
||||
-- Required plugins
|
||||
use 'nvim-treesitter/nvim-treesitter'
|
||||
use 'stevearc/dressing.nvim'
|
||||
use 'nvim-lua/plenary.nvim'
|
||||
use 'MunifTanjim/nui.nvim'
|
||||
|
||||
@@ -548,63 +548,78 @@ end
|
||||
---@field start_line_in_response_buf integer
|
||||
---@field end_line_in_response_buf integer
|
||||
---@field filepath string
|
||||
---
|
||||
|
||||
---Collect the tree-sitter nodes captured as `@language` and `@code` for every
---fenced code block found in a markdown document.
---
---The nodes are returned in the order `iter_captures` yields them. Callers in
---this file distinguish the two kinds with `node:type()` ("language" vs
---"code_fence_content").
---
---If the markdown parser cannot be created, or parsing yields no tree, an empty
---list is returned (with a one-time warning in the former case) instead of
---raising, so callers simply see "no code blocks".
---@param source string|integer markdown text, or a buffer handle whose content is markdown
---@return TSNode[]
local function tree_sitter_markdown_parse_code_blocks(source)
  local query = require("vim.treesitter.query")

  -- Creating the parser raises when the markdown grammar is not installed;
  -- NOTE(review): newer Neovim releases are documented to return nil here
  -- instead of raising - both cases are handled below.
  local ok, parser = pcall(function()
    if type(source) == "string" then return vim.treesitter.get_string_parser(source, "markdown") end
    return vim.treesitter.get_parser(source, "markdown")
  end)
  if not ok or parser == nil then
    vim.notify_once(
      "avante.nvim: tree-sitter markdown parser is unavailable, code blocks cannot be extracted",
      vim.log.levels.WARN
    )
    return {}
  end

  local tree = parser:parse()[1]
  if tree == nil then return {} end
  local root = tree:root()

  -- The info string (and therefore the language) is optional on a fence, so a
  -- block without a language only produces a `@code` capture.
  local code_block_query = query.parse(
    "markdown",
    [[ (fenced_code_block
      (info_string
        (language) @language)?
      (code_fence_content) @code) ]]
  )

  local nodes = {}
  for _, node in code_block_query:iter_captures(root, source) do
    nodes[#nodes + 1] = node
  end
  return nodes
end
|
||||
|
||||
---@param response_content string
|
||||
---@return table<string, AvanteCodeSnippet[]>
|
||||
local function extract_cursor_planning_code_snippets_map(response_content, current_filepath, current_filetype)
|
||||
local snippets = {}
|
||||
local current_snippet = {}
|
||||
local in_code_block = false
|
||||
local lang, filepath, start_line_in_response_buf
|
||||
|
||||
local lines = vim.split(response_content, "\n")
|
||||
local cumulated_content = ""
|
||||
|
||||
local idx = 1
|
||||
local line_count = #lines
|
||||
|
||||
while idx <= line_count do
|
||||
local line = lines[idx]
|
||||
if line:match("^%s*```") then
|
||||
if in_code_block then
|
||||
in_code_block = false
|
||||
if filepath == nil or filepath == "" then
|
||||
if lang == current_filetype then
|
||||
filepath = current_filepath
|
||||
else
|
||||
Utils.warn(
|
||||
string.format(
|
||||
"Failed to parse filepath from code block, and current_filetype `%s` is not the same as the filetype `%s` of the current code block, so ignore this code block",
|
||||
current_filetype,
|
||||
lang
|
||||
)
|
||||
-- use tree-sitter-markdown to parse all code blocks in response_content
|
||||
local lang = "unknown"
|
||||
for _, node in ipairs(tree_sitter_markdown_parse_code_blocks(response_content)) do
|
||||
if node:type() == "language" then
|
||||
lang = vim.treesitter.get_node_text(node, response_content)
|
||||
lang = vim.split(lang, ":")[1]
|
||||
elseif node:type() == "code_fence_content" then
|
||||
local start_line, _ = node:start()
|
||||
local end_line, _ = node:end_()
|
||||
local filepath, skip_next_line = obtain_filepath_from_codeblock(lines, start_line)
|
||||
if filepath == nil or filepath == "" then
|
||||
if lang == current_filetype then
|
||||
filepath = current_filepath
|
||||
else
|
||||
Utils.warn(
|
||||
string.format(
|
||||
"Failed to parse filepath from code block, and current_filetype `%s` is not the same as the filetype `%s` of the current code block, so ignore this code block",
|
||||
current_filetype,
|
||||
lang
|
||||
)
|
||||
goto continue
|
||||
end
|
||||
end
|
||||
table.insert(snippets, {
|
||||
range = { 0, 0 },
|
||||
content = table.concat(current_snippet, "\n"),
|
||||
lang = lang,
|
||||
filepath = filepath,
|
||||
start_line_in_response_buf = start_line_in_response_buf,
|
||||
end_line_in_response_buf = idx,
|
||||
})
|
||||
else
|
||||
in_code_block = true
|
||||
start_line_in_response_buf = idx
|
||||
local lang_ = line:match("^%s*```(%w+)")
|
||||
lang = lang_ or "unknown"
|
||||
local filepath_, skip_next_line = obtain_filepath_from_codeblock(lines, idx)
|
||||
if filepath_ then
|
||||
filepath = filepath_
|
||||
if skip_next_line then idx = idx + 1 end
|
||||
)
|
||||
lang = "unknown"
|
||||
goto continue
|
||||
end
|
||||
end
|
||||
elseif in_code_block then
|
||||
table.insert(current_snippet, line)
|
||||
if skip_next_line then start_line = start_line + 1 end
|
||||
local this_content = table.concat(vim.list_slice(lines, start_line + 1, end_line), "\n")
|
||||
cumulated_content = cumulated_content .. "\n" .. this_content
|
||||
table.insert(snippets, {
|
||||
range = { 0, 0 },
|
||||
content = cumulated_content,
|
||||
lang = lang,
|
||||
filepath = filepath,
|
||||
start_line_in_response_buf = start_line,
|
||||
end_line_in_response_buf = end_line + 1,
|
||||
})
|
||||
end
|
||||
::continue::
|
||||
idx = idx + 1
|
||||
end
|
||||
|
||||
local snippets_map = {}
|
||||
@@ -620,62 +635,61 @@ end
|
||||
---@return table<string, AvanteCodeSnippet[]>
|
||||
local function extract_code_snippets_map(response_content)
|
||||
local snippets = {}
|
||||
local current_snippet = {}
|
||||
local in_code_block = false
|
||||
local lang, start_line, end_line, start_line_in_response_buf
|
||||
local explanation = ""
|
||||
|
||||
local lines = vim.split(response_content, "\n")
|
||||
|
||||
for idx, line in ipairs(lines) do
|
||||
local _, start_line_str, end_line_str =
|
||||
line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)")
|
||||
if start_line_str ~= nil and end_line_str ~= nil then
|
||||
start_line = tonumber(start_line_str)
|
||||
end_line = tonumber(end_line_str)
|
||||
else
|
||||
_, start_line_str = line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ine:?%s*(%d+)")
|
||||
if start_line_str ~= nil then
|
||||
-- use tree-sitter-markdown to parse all code blocks in response_content
|
||||
local lang = "text"
|
||||
local explanation_start_line = 0
|
||||
for _, node in ipairs(tree_sitter_markdown_parse_code_blocks(response_content)) do
|
||||
local start_line_in_response_buf, _ = node:start()
|
||||
local end_line_in_response_buf, _ = node:end_()
|
||||
if node:type() == "language" then
|
||||
lang = vim.treesitter.get_node_text(node, response_content)
|
||||
elseif node:type() == "code_fence_content" and start_line_in_response_buf > 1 then
|
||||
local number_line = lines[start_line_in_response_buf - 1]
|
||||
local start_line, end_line
|
||||
|
||||
local _, start_line_str, end_line_str =
|
||||
number_line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)")
|
||||
if start_line_str ~= nil and end_line_str ~= nil then
|
||||
start_line = tonumber(start_line_str)
|
||||
end_line = tonumber(start_line_str)
|
||||
end_line = tonumber(end_line_str)
|
||||
else
|
||||
start_line_str = line:match("[Aa]fter%s+[Ll]ine:?%s*(%d+)")
|
||||
_, start_line_str = number_line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ine:?%s*(%d+)")
|
||||
if start_line_str ~= nil then
|
||||
start_line = tonumber(start_line_str) + 1
|
||||
end_line = tonumber(start_line_str) + 1
|
||||
start_line = tonumber(start_line_str)
|
||||
end_line = tonumber(start_line_str)
|
||||
else
|
||||
start_line_str = number_line:match("[Aa]fter%s+[Ll]ine:?%s*(%d+)")
|
||||
if start_line_str ~= nil then
|
||||
start_line = tonumber(start_line_str) + 1
|
||||
end_line = tonumber(start_line_str) + 1
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
if line:match("^%s*```") then
|
||||
if in_code_block then
|
||||
if start_line ~= nil and end_line ~= nil then
|
||||
local filepath = lines[start_line_in_response_buf - 2]
|
||||
if filepath:match("^[Ff]ilepath:") then filepath = filepath:match("^[Ff]ilepath:%s*(.+)") end
|
||||
local snippet = {
|
||||
range = { start_line, end_line },
|
||||
content = table.concat(current_snippet, "\n"),
|
||||
lang = lang,
|
||||
explanation = explanation,
|
||||
start_line_in_response_buf = start_line_in_response_buf,
|
||||
end_line_in_response_buf = idx,
|
||||
filepath = filepath,
|
||||
}
|
||||
table.insert(snippets, snippet)
|
||||
|
||||
if start_line ~= nil and end_line ~= nil then
|
||||
local filepath = lines[start_line_in_response_buf - 2]
|
||||
if filepath:match("^[Ff]ilepath:") then filepath = filepath:match("^[Ff]ilepath:%s*(.+)") end
|
||||
local content = vim.treesitter.get_node_text(node, response_content)
|
||||
local explanation = ""
|
||||
if start_line_in_response_buf > explanation_start_line + 2 then
|
||||
explanation =
|
||||
table.concat(vim.list_slice(lines, explanation_start_line, start_line_in_response_buf - 3), "\n")
|
||||
end
|
||||
current_snippet = {}
|
||||
start_line, end_line = nil, nil
|
||||
explanation = ""
|
||||
in_code_block = false
|
||||
else
|
||||
lang = line:match("^%s*```(%w+)")
|
||||
if not lang or lang == "" then lang = "text" end
|
||||
in_code_block = true
|
||||
start_line_in_response_buf = idx
|
||||
local snippet = {
|
||||
range = { start_line, end_line },
|
||||
content = content,
|
||||
lang = lang,
|
||||
explanation = explanation,
|
||||
start_line_in_response_buf = start_line_in_response_buf,
|
||||
end_line_in_response_buf = end_line_in_response_buf + 1,
|
||||
filepath = filepath,
|
||||
}
|
||||
table.insert(snippets, snippet)
|
||||
end
|
||||
elseif in_code_block then
|
||||
table.insert(current_snippet, line)
|
||||
else
|
||||
explanation = explanation .. line .. "\n"
|
||||
lang = "text"
|
||||
explanation_start_line = end_line_in_response_buf + 2
|
||||
end
|
||||
end
|
||||
|
||||
@@ -843,35 +857,24 @@ end
|
||||
---@return AvanteCodeblock[]
|
||||
local function parse_codeblocks(buf, current_filepath, current_filetype)
|
||||
local codeblocks = {}
|
||||
local in_codeblock = false
|
||||
local start_line = nil
|
||||
local lang = nil
|
||||
|
||||
local lines = Utils.get_buf_lines(0, -1, buf)
|
||||
for i, line in ipairs(lines) do
|
||||
if line:match("^%s*```") then
|
||||
-- parse language
|
||||
local lang_ = line:match("^%s*```(%w+)")
|
||||
if in_codeblock and not lang_ then
|
||||
table.insert(codeblocks, { start_line = start_line, end_line = i, lang = lang })
|
||||
in_codeblock = false
|
||||
elseif lang_ then
|
||||
if Config.behaviour.enable_cursor_planning_mode then
|
||||
local filepath = obtain_filepath_from_codeblock(lines, i)
|
||||
if not filepath and lang_ == current_filetype then filepath = current_filepath end
|
||||
if filepath then
|
||||
lang = lang_
|
||||
start_line = i
|
||||
in_codeblock = true
|
||||
end
|
||||
else
|
||||
if lines[i - 1]:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)") then
|
||||
lang = lang_
|
||||
start_line = i
|
||||
in_codeblock = true
|
||||
end
|
||||
local lang, valid
|
||||
for _, node in ipairs(tree_sitter_markdown_parse_code_blocks(buf)) do
|
||||
if node:type() == "language" then
|
||||
lang = vim.treesitter.get_node_text(node, buf)
|
||||
elseif node:type() == "code_fence_content" then
|
||||
local start_line, _ = node:start()
|
||||
local end_line, _ = node:end_()
|
||||
if Config.behaviour.enable_cursor_planning_mode then
|
||||
local filepath = obtain_filepath_from_codeblock(lines, start_line)
|
||||
if not filepath and lang == current_filetype then filepath = current_filepath end
|
||||
if filepath then valid = true end
|
||||
else
|
||||
if lines[start_line - 1]:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)") then
|
||||
valid = true
|
||||
end
|
||||
end
|
||||
if valid then table.insert(codeblocks, { start_line = start_line, end_line = end_line + 1, lang = lang }) end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user