feat: streaming diff (#2107)

This commit is contained in:
yetone
2025-06-02 16:44:33 +08:00
committed by GitHub
parent bc403ddcbf
commit 746f071b37
12 changed files with 1449 additions and 130 deletions

View File

@@ -150,30 +150,30 @@ function M.generate_prompts(opts)
local tool_id_to_tool_name = {}
local tool_id_to_path = {}
local viewed_files = {}
local last_modified_files = {}
local history_messages = {}
if opts.history_messages then
for _, message in ipairs(opts.history_messages) do
for idx, message in ipairs(opts.history_messages) do
if Utils.is_tool_result_message(message) then
local tool_use_message = Utils.get_tool_use_message(message, opts.history_messages)
local is_replace_func_call, _, _, path = Utils.is_replace_func_call_message(tool_use_message)
if is_replace_func_call and path and not message.message.content[1].is_error then
local uniformed_path = Utils.uniform_path(path)
last_modified_files[uniformed_path] = idx
end
end
end
for idx, message in ipairs(opts.history_messages) do
table.insert(history_messages, message)
if Utils.is_tool_result_message(message) then
local tool_use_message = Utils.get_tool_use_message(message, opts.history_messages)
local is_replace_func_call = false
local is_str_replace_editor_func_call = false
local path = nil
if tool_use_message then
if tool_use_message.message.content[1].name == "replace_in_file" then
is_replace_func_call = true
path = tool_use_message.message.content[1].input.path
end
if tool_use_message.message.content[1].name == "str_replace_editor" then
if tool_use_message.message.content[1].input.command == "str_replace" then
is_replace_func_call = true
is_str_replace_editor_func_call = true
path = tool_use_message.message.content[1].input.path
end
end
end
local is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path =
Utils.is_replace_func_call_message(tool_use_message)
--- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content
if is_replace_func_call and path and not message.message.content[1].is_error then
local uniformed_path = Utils.uniform_path(path)
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, nil, nil, nil)
if view_error then view_result = "Error: " .. view_error end
local get_diagnostics_tool_use_id = Utils.uuid()
@@ -184,7 +184,10 @@ function M.generate_prompts(opts)
view_tool_name = "str_replace_editor"
view_tool_input = { command = "view", path = path }
end
local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path)
if is_str_replace_based_edit_tool_func_call then
view_tool_name = "str_replace_based_edit_tool"
view_tool_input = { command = "view", path = path }
end
history_messages = vim.list_extend(history_messages, {
HistoryMessage:new({
role = "assistant",
@@ -218,42 +221,47 @@ function M.generate_prompts(opts)
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "assistant",
content = string.format(
"The file %s has been modified, let me check if there are any errors in the changes.",
path
),
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
id = get_diagnostics_tool_use_id,
name = "get_diagnostics",
input = { path = path },
},
},
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "user",
content = {
{
type = "tool_result",
tool_use_id = get_diagnostics_tool_use_id,
content = vim.json.encode(diagnostics),
is_error = false,
},
},
}, {
is_dummy = true,
}),
})
if last_modified_files[uniformed_path] == idx then
local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path)
history_messages = vim.list_extend(history_messages, {
HistoryMessage:new({
role = "assistant",
content = string.format(
"The file %s has been modified, let me check if there are any errors in the changes.",
path
),
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "assistant",
content = {
{
type = "tool_use",
id = get_diagnostics_tool_use_id,
name = "get_diagnostics",
input = { path = path },
},
},
}, {
is_dummy = true,
}),
HistoryMessage:new({
role = "user",
content = {
{
type = "tool_result",
tool_use_id = get_diagnostics_tool_use_id,
content = vim.json.encode(diagnostics),
is_error = false,
},
},
}, {
is_dummy = true,
}),
})
end
end
end
end
@@ -418,6 +426,23 @@ function M.generate_prompts(opts)
local messages = vim.deepcopy(context_messages)
for _, msg in ipairs(final_history_messages) do
local message = msg.message
if msg.is_user_submission then
message = vim.deepcopy(message)
local content = message.content
if type(content) == "string" then
message.content = "<task>" .. content .. "</task>"
elseif type(content) == "table" then
for idx, item in ipairs(content) do
if type(item) == "string" then
item = "<task>" .. item .. "</task>"
content[idx] = item
elseif type(item) == "table" and item.type == "text" then
item.content = "<task>" .. item.content .. "</task>"
content[idx] = item
end
end
end
end
table.insert(messages, message)
end
@@ -741,11 +766,11 @@ function M._stream(opts)
on_start = opts.on_start,
on_chunk = opts.on_chunk,
on_stop = function(stop_opts)
---@param tool_use_list AvanteLLMToolUse[]
---@param partial_tool_use_list AvantePartialLLMToolUse[]
---@param tool_use_index integer
---@param tool_results AvanteLLMToolResult[]
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_results)
if tool_use_index > #tool_use_list then
local function handle_next_tool_use(partial_tool_use_list, tool_use_index, tool_results, streaming_tool_use)
if tool_use_index > #partial_tool_use_list then
---@type avante.HistoryMessage[]
local messages = {}
for _, tool_result in ipairs(tool_results) do
@@ -762,7 +787,7 @@ function M._stream(opts)
})
end
if opts.on_messages_add then opts.on_messages_add(messages) end
local the_last_tool_use = tool_use_list[#tool_use_list]
local the_last_tool_use = partial_tool_use_list[#partial_tool_use_list]
if the_last_tool_use and the_last_tool_use.name == "attempt_completion" then
opts.on_stop({ reason = "complete" })
return
@@ -781,7 +806,7 @@ function M._stream(opts)
M._stream(new_opts)
return
end
local tool_use = tool_use_list[tool_use_index]
local partial_tool_use = partial_tool_use_list[tool_use_index]
---@param result string | nil
---@param error string | nil
local function handle_tool_result(result, error)
@@ -802,17 +827,37 @@ function M._stream(opts)
end
local tool_result = {
tool_use_id = tool_use.id,
tool_use_id = partial_tool_use.id,
content = error ~= nil and error or result,
is_error = error ~= nil,
}
table.insert(tool_results, tool_result)
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_results)
return handle_next_tool_use(partial_tool_use_list, tool_use_index + 1, tool_results)
end
local is_replace_func_call = Utils.is_replace_func_call_tool_use(partial_tool_use)
if partial_tool_use.state == "generating" and not is_replace_func_call then return end
if is_replace_func_call then
if type(partial_tool_use.input) == "table" then partial_tool_use.input.tool_use_id = partial_tool_use.id end
if partial_tool_use.state == "generating" then
if type(partial_tool_use.input) == "table" then
partial_tool_use.input.streaming = true
LLMTools.process_tool_use(
prompt_opts.tools,
partial_tool_use,
function() end,
function() end,
opts.session_ctx
)
end
return
else
if streaming_tool_use then return end
end
end
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
local result, error = LLMTools.process_tool_use(
prompt_opts.tools,
tool_use,
partial_tool_use,
opts.on_tool_log,
handle_tool_result,
opts.session_ctx
@@ -832,7 +877,7 @@ function M._stream(opts)
end
return opts.on_stop({ reason = "cancelled" })
end
local tool_use_list = {} ---@type AvanteLLMToolUse[]
local partial_tool_use_list = {} ---@type AvantePartialLLMToolUse[]
local tool_result_seen = {}
local history_messages = opts.get_history_messages and opts.get_history_messages() or {}
for idx = #history_messages, 1, -1 do
@@ -843,7 +888,13 @@ function M._stream(opts)
for _, item in ipairs(content) do
if item.type == "tool_use" then
if not tool_result_seen[item.id] then
table.insert(tool_use_list, 1, item)
local partial_tool_use = {
name = item.name,
id = item.id,
input = item.input,
state = message.state,
}
table.insert(partial_tool_use_list, 1, partial_tool_use)
else
is_break = true
break
@@ -855,7 +906,7 @@ function M._stream(opts)
::continue::
end
if stop_opts.reason == "complete" and Config.mode == "agentic" then
if #tool_use_list == 0 then
if #partial_tool_use_list == 0 then
local completed_attempt_completion_tool_use = nil
for idx = #history_messages, 1, -1 do
local message = history_messages[idx]
@@ -892,7 +943,9 @@ function M._stream(opts)
end
end
end
if stop_opts.reason == "tool_use" then return handle_next_tool_use(tool_use_list, 1, {}) end
if stop_opts.reason == "tool_use" then
return handle_next_tool_use(partial_tool_use_list, 1, {}, stop_opts.streaming_tool_use)
end
if stop_opts.reason == "rate_limit" then
local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end