feat: streaming diff (#2107)
This commit is contained in:
@@ -150,30 +150,30 @@ function M.generate_prompts(opts)
|
||||
local tool_id_to_tool_name = {}
|
||||
local tool_id_to_path = {}
|
||||
local viewed_files = {}
|
||||
local last_modified_files = {}
|
||||
local history_messages = {}
|
||||
if opts.history_messages then
|
||||
for _, message in ipairs(opts.history_messages) do
|
||||
for idx, message in ipairs(opts.history_messages) do
|
||||
if Utils.is_tool_result_message(message) then
|
||||
local tool_use_message = Utils.get_tool_use_message(message, opts.history_messages)
|
||||
local is_replace_func_call, _, _, path = Utils.is_replace_func_call_message(tool_use_message)
|
||||
|
||||
if is_replace_func_call and path and not message.message.content[1].is_error then
|
||||
local uniformed_path = Utils.uniform_path(path)
|
||||
last_modified_files[uniformed_path] = idx
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
for idx, message in ipairs(opts.history_messages) do
|
||||
table.insert(history_messages, message)
|
||||
if Utils.is_tool_result_message(message) then
|
||||
local tool_use_message = Utils.get_tool_use_message(message, opts.history_messages)
|
||||
local is_replace_func_call = false
|
||||
local is_str_replace_editor_func_call = false
|
||||
local path = nil
|
||||
if tool_use_message then
|
||||
if tool_use_message.message.content[1].name == "replace_in_file" then
|
||||
is_replace_func_call = true
|
||||
path = tool_use_message.message.content[1].input.path
|
||||
end
|
||||
if tool_use_message.message.content[1].name == "str_replace_editor" then
|
||||
if tool_use_message.message.content[1].input.command == "str_replace" then
|
||||
is_replace_func_call = true
|
||||
is_str_replace_editor_func_call = true
|
||||
path = tool_use_message.message.content[1].input.path
|
||||
end
|
||||
end
|
||||
end
|
||||
local is_replace_func_call, is_str_replace_editor_func_call, is_str_replace_based_edit_tool_func_call, path =
|
||||
Utils.is_replace_func_call_message(tool_use_message)
|
||||
--- For models like gpt-4o, the input parameter of replace_in_file is treated as the latest file content, so here we need to insert a fake view tool call to ensure it uses the latest file content
|
||||
if is_replace_func_call and path and not message.message.content[1].is_error then
|
||||
local uniformed_path = Utils.uniform_path(path)
|
||||
local view_result, view_error = require("avante.llm_tools.view").func({ path = path }, nil, nil, nil)
|
||||
if view_error then view_result = "Error: " .. view_error end
|
||||
local get_diagnostics_tool_use_id = Utils.uuid()
|
||||
@@ -184,7 +184,10 @@ function M.generate_prompts(opts)
|
||||
view_tool_name = "str_replace_editor"
|
||||
view_tool_input = { command = "view", path = path }
|
||||
end
|
||||
local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path)
|
||||
if is_str_replace_based_edit_tool_func_call then
|
||||
view_tool_name = "str_replace_based_edit_tool"
|
||||
view_tool_input = { command = "view", path = path }
|
||||
end
|
||||
history_messages = vim.list_extend(history_messages, {
|
||||
HistoryMessage:new({
|
||||
role = "assistant",
|
||||
@@ -218,42 +221,47 @@ function M.generate_prompts(opts)
|
||||
}, {
|
||||
is_dummy = true,
|
||||
}),
|
||||
HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = string.format(
|
||||
"The file %s has been modified, let me check if there are any errors in the changes.",
|
||||
path
|
||||
),
|
||||
}, {
|
||||
is_dummy = true,
|
||||
}),
|
||||
HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
id = get_diagnostics_tool_use_id,
|
||||
name = "get_diagnostics",
|
||||
input = { path = path },
|
||||
},
|
||||
},
|
||||
}, {
|
||||
is_dummy = true,
|
||||
}),
|
||||
HistoryMessage:new({
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = get_diagnostics_tool_use_id,
|
||||
content = vim.json.encode(diagnostics),
|
||||
is_error = false,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
is_dummy = true,
|
||||
}),
|
||||
})
|
||||
if last_modified_files[uniformed_path] == idx then
|
||||
local diagnostics = Utils.lsp.get_diagnostics_from_filepath(path)
|
||||
history_messages = vim.list_extend(history_messages, {
|
||||
HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = string.format(
|
||||
"The file %s has been modified, let me check if there are any errors in the changes.",
|
||||
path
|
||||
),
|
||||
}, {
|
||||
is_dummy = true,
|
||||
}),
|
||||
HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
id = get_diagnostics_tool_use_id,
|
||||
name = "get_diagnostics",
|
||||
input = { path = path },
|
||||
},
|
||||
},
|
||||
}, {
|
||||
is_dummy = true,
|
||||
}),
|
||||
HistoryMessage:new({
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = get_diagnostics_tool_use_id,
|
||||
content = vim.json.encode(diagnostics),
|
||||
is_error = false,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
is_dummy = true,
|
||||
}),
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -418,6 +426,23 @@ function M.generate_prompts(opts)
|
||||
local messages = vim.deepcopy(context_messages)
|
||||
for _, msg in ipairs(final_history_messages) do
|
||||
local message = msg.message
|
||||
if msg.is_user_submission then
|
||||
message = vim.deepcopy(message)
|
||||
local content = message.content
|
||||
if type(content) == "string" then
|
||||
message.content = "<task>" .. content .. "</task>"
|
||||
elseif type(content) == "table" then
|
||||
for idx, item in ipairs(content) do
|
||||
if type(item) == "string" then
|
||||
item = "<task>" .. item .. "</task>"
|
||||
content[idx] = item
|
||||
elseif type(item) == "table" and item.type == "text" then
|
||||
item.content = "<task>" .. item.content .. "</task>"
|
||||
content[idx] = item
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
table.insert(messages, message)
|
||||
end
|
||||
|
||||
@@ -741,11 +766,11 @@ function M._stream(opts)
|
||||
on_start = opts.on_start,
|
||||
on_chunk = opts.on_chunk,
|
||||
on_stop = function(stop_opts)
|
||||
---@param tool_use_list AvanteLLMToolUse[]
|
||||
---@param partial_tool_use_list AvantePartialLLMToolUse[]
|
||||
---@param tool_use_index integer
|
||||
---@param tool_results AvanteLLMToolResult[]
|
||||
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_results)
|
||||
if tool_use_index > #tool_use_list then
|
||||
local function handle_next_tool_use(partial_tool_use_list, tool_use_index, tool_results, streaming_tool_use)
|
||||
if tool_use_index > #partial_tool_use_list then
|
||||
---@type avante.HistoryMessage[]
|
||||
local messages = {}
|
||||
for _, tool_result in ipairs(tool_results) do
|
||||
@@ -762,7 +787,7 @@ function M._stream(opts)
|
||||
})
|
||||
end
|
||||
if opts.on_messages_add then opts.on_messages_add(messages) end
|
||||
local the_last_tool_use = tool_use_list[#tool_use_list]
|
||||
local the_last_tool_use = partial_tool_use_list[#partial_tool_use_list]
|
||||
if the_last_tool_use and the_last_tool_use.name == "attempt_completion" then
|
||||
opts.on_stop({ reason = "complete" })
|
||||
return
|
||||
@@ -781,7 +806,7 @@ function M._stream(opts)
|
||||
M._stream(new_opts)
|
||||
return
|
||||
end
|
||||
local tool_use = tool_use_list[tool_use_index]
|
||||
local partial_tool_use = partial_tool_use_list[tool_use_index]
|
||||
---@param result string | nil
|
||||
---@param error string | nil
|
||||
local function handle_tool_result(result, error)
|
||||
@@ -802,17 +827,37 @@ function M._stream(opts)
|
||||
end
|
||||
|
||||
local tool_result = {
|
||||
tool_use_id = tool_use.id,
|
||||
tool_use_id = partial_tool_use.id,
|
||||
content = error ~= nil and error or result,
|
||||
is_error = error ~= nil,
|
||||
}
|
||||
table.insert(tool_results, tool_result)
|
||||
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_results)
|
||||
return handle_next_tool_use(partial_tool_use_list, tool_use_index + 1, tool_results)
|
||||
end
|
||||
local is_replace_func_call = Utils.is_replace_func_call_tool_use(partial_tool_use)
|
||||
if partial_tool_use.state == "generating" and not is_replace_func_call then return end
|
||||
if is_replace_func_call then
|
||||
if type(partial_tool_use.input) == "table" then partial_tool_use.input.tool_use_id = partial_tool_use.id end
|
||||
if partial_tool_use.state == "generating" then
|
||||
if type(partial_tool_use.input) == "table" then
|
||||
partial_tool_use.input.streaming = true
|
||||
LLMTools.process_tool_use(
|
||||
prompt_opts.tools,
|
||||
partial_tool_use,
|
||||
function() end,
|
||||
function() end,
|
||||
opts.session_ctx
|
||||
)
|
||||
end
|
||||
return
|
||||
else
|
||||
if streaming_tool_use then return end
|
||||
end
|
||||
end
|
||||
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
|
||||
local result, error = LLMTools.process_tool_use(
|
||||
prompt_opts.tools,
|
||||
tool_use,
|
||||
partial_tool_use,
|
||||
opts.on_tool_log,
|
||||
handle_tool_result,
|
||||
opts.session_ctx
|
||||
@@ -832,7 +877,7 @@ function M._stream(opts)
|
||||
end
|
||||
return opts.on_stop({ reason = "cancelled" })
|
||||
end
|
||||
local tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||
local partial_tool_use_list = {} ---@type AvantePartialLLMToolUse[]
|
||||
local tool_result_seen = {}
|
||||
local history_messages = opts.get_history_messages and opts.get_history_messages() or {}
|
||||
for idx = #history_messages, 1, -1 do
|
||||
@@ -843,7 +888,13 @@ function M._stream(opts)
|
||||
for _, item in ipairs(content) do
|
||||
if item.type == "tool_use" then
|
||||
if not tool_result_seen[item.id] then
|
||||
table.insert(tool_use_list, 1, item)
|
||||
local partial_tool_use = {
|
||||
name = item.name,
|
||||
id = item.id,
|
||||
input = item.input,
|
||||
state = message.state,
|
||||
}
|
||||
table.insert(partial_tool_use_list, 1, partial_tool_use)
|
||||
else
|
||||
is_break = true
|
||||
break
|
||||
@@ -855,7 +906,7 @@ function M._stream(opts)
|
||||
::continue::
|
||||
end
|
||||
if stop_opts.reason == "complete" and Config.mode == "agentic" then
|
||||
if #tool_use_list == 0 then
|
||||
if #partial_tool_use_list == 0 then
|
||||
local completed_attempt_completion_tool_use = nil
|
||||
for idx = #history_messages, 1, -1 do
|
||||
local message = history_messages[idx]
|
||||
@@ -892,7 +943,9 @@ function M._stream(opts)
|
||||
end
|
||||
end
|
||||
end
|
||||
if stop_opts.reason == "tool_use" then return handle_next_tool_use(tool_use_list, 1, {}) end
|
||||
if stop_opts.reason == "tool_use" then
|
||||
return handle_next_tool_use(partial_tool_use_list, 1, {}, stop_opts.streaming_tool_use)
|
||||
end
|
||||
if stop_opts.reason == "rate_limit" then
|
||||
local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
|
||||
if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end
|
||||
|
||||
Reference in New Issue
Block a user