fix: tool calling (#2297)
This commit is contained in:
@@ -244,6 +244,7 @@ function M.select_history()
|
|||||||
Path.history.save_latest_filename(buf, filename)
|
Path.history.save_latest_filename(buf, filename)
|
||||||
local sidebar = require("avante").get()
|
local sidebar = require("avante").get()
|
||||||
sidebar:update_content_with_history()
|
sidebar:update_content_with_history()
|
||||||
|
sidebar:create_todos_container()
|
||||||
vim.schedule(function() sidebar:focus_input() end)
|
vim.schedule(function() sidebar:focus_input() end)
|
||||||
end)
|
end)
|
||||||
end)
|
end)
|
||||||
|
|||||||
@@ -272,9 +272,10 @@ M._defaults = {
|
|||||||
endpoint = "https://api.anthropic.com",
|
endpoint = "https://api.anthropic.com",
|
||||||
model = "claude-sonnet-4-20250514",
|
model = "claude-sonnet-4-20250514",
|
||||||
timeout = 30000, -- Timeout in milliseconds
|
timeout = 30000, -- Timeout in milliseconds
|
||||||
|
context_window = 200000,
|
||||||
extra_request_body = {
|
extra_request_body = {
|
||||||
temperature = 0.75,
|
temperature = 0.75,
|
||||||
max_tokens = 20480,
|
max_tokens = 64000,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
---@type AvanteSupportedProvider
|
---@type AvanteSupportedProvider
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ local M = {}
|
|||||||
M.__index = M
|
M.__index = M
|
||||||
|
|
||||||
---@param message AvanteLLMMessage
|
---@param message AvanteLLMMessage
|
||||||
---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean, session_id?: string}
|
---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean, turn_id?: string}
|
||||||
---@return avante.HistoryMessage
|
---@return avante.HistoryMessage
|
||||||
function M:new(message, opts)
|
function M:new(message, opts)
|
||||||
opts = opts or {}
|
opts = opts or {}
|
||||||
@@ -23,7 +23,7 @@ function M:new(message, opts)
|
|||||||
if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end
|
if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end
|
||||||
if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end
|
if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end
|
||||||
if opts.is_dummy ~= nil then obj.is_dummy = opts.is_dummy end
|
if opts.is_dummy ~= nil then obj.is_dummy = opts.is_dummy end
|
||||||
obj.session_id = opts.session_id
|
obj.turn_id = opts.turn_id
|
||||||
return obj
|
return obj
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -416,13 +416,24 @@ function M.curl(opts)
|
|||||||
local prompt_opts = opts.prompt_opts
|
local prompt_opts = opts.prompt_opts
|
||||||
local handler_opts = opts.handler_opts
|
local handler_opts = opts.handler_opts
|
||||||
|
|
||||||
|
local orig_on_stop = handler_opts.on_stop
|
||||||
|
local stopped = false
|
||||||
|
---@param stop_opts AvanteLLMStopCallbackOptions
|
||||||
|
handler_opts.on_stop = function(stop_opts)
|
||||||
|
if stop_opts and not stop_opts.streaming_tool_use then
|
||||||
|
if stopped then return end
|
||||||
|
stopped = true
|
||||||
|
end
|
||||||
|
if orig_on_stop then return orig_on_stop(stop_opts) end
|
||||||
|
end
|
||||||
|
|
||||||
---@type AvanteCurlOutput
|
---@type AvanteCurlOutput
|
||||||
local spec = provider:parse_curl_args(prompt_opts)
|
local spec = provider:parse_curl_args(prompt_opts)
|
||||||
|
|
||||||
---@type string
|
---@type string
|
||||||
local current_event_state = nil
|
local current_event_state = nil
|
||||||
local resp_ctx = {}
|
local turn_ctx = {}
|
||||||
resp_ctx.session_id = Utils.uuid()
|
turn_ctx.turn_id = Utils.uuid()
|
||||||
|
|
||||||
local response_body = ""
|
local response_body = ""
|
||||||
---@param line string
|
---@param line string
|
||||||
@@ -435,7 +446,7 @@ function M.curl(opts)
|
|||||||
local data_match = line:match("^data:%s*(.+)$")
|
local data_match = line:match("^data:%s*(.+)$")
|
||||||
if data_match then
|
if data_match then
|
||||||
response_body = ""
|
response_body = ""
|
||||||
provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts)
|
provider:parse_response(turn_ctx, data_match, current_event_state, handler_opts)
|
||||||
else
|
else
|
||||||
response_body = response_body .. line
|
response_body = response_body .. line
|
||||||
local ok, jsn = pcall(vim.json.decode, response_body)
|
local ok, jsn = pcall(vim.json.decode, response_body)
|
||||||
@@ -443,7 +454,7 @@ function M.curl(opts)
|
|||||||
if jsn.error then
|
if jsn.error then
|
||||||
handler_opts.on_stop({ reason = "error", error = jsn.error })
|
handler_opts.on_stop({ reason = "error", error = jsn.error })
|
||||||
else
|
else
|
||||||
provider:parse_response(resp_ctx, response_body, current_event_state, handler_opts)
|
provider:parse_response(turn_ctx, response_body, current_event_state, handler_opts)
|
||||||
end
|
end
|
||||||
response_body = ""
|
response_body = ""
|
||||||
end
|
end
|
||||||
@@ -509,7 +520,7 @@ function M.curl(opts)
|
|||||||
end
|
end
|
||||||
vim.schedule(function()
|
vim.schedule(function()
|
||||||
if provider.parse_stream_data ~= nil then
|
if provider.parse_stream_data ~= nil then
|
||||||
provider:parse_stream_data(resp_ctx, data, handler_opts)
|
provider:parse_stream_data(turn_ctx, data, handler_opts)
|
||||||
else
|
else
|
||||||
parse_stream_data(data)
|
parse_stream_data(data)
|
||||||
end
|
end
|
||||||
@@ -843,6 +854,7 @@ function M._stream(opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
if stop_opts.reason == "tool_use" then
|
if stop_opts.reason == "tool_use" then
|
||||||
|
opts.session_ctx.user_reminder_count = 0
|
||||||
return handle_next_tool_use(uncalled_tool_uses, 1, {}, stop_opts.streaming_tool_use)
|
return handle_next_tool_use(uncalled_tool_uses, 1, {}, stop_opts.streaming_tool_use)
|
||||||
end
|
end
|
||||||
if stop_opts.reason == "rate_limit" then
|
if stop_opts.reason == "rate_limit" then
|
||||||
|
|||||||
@@ -133,6 +133,8 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
|||||||
|
|
||||||
local is_streaming = opts.streaming or false
|
local is_streaming = opts.streaming or false
|
||||||
|
|
||||||
|
if is_streaming then return end
|
||||||
|
|
||||||
session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {}
|
session_ctx.prev_streaming_diff_timestamp_map = session_ctx.prev_streaming_diff_timestamp_map or {}
|
||||||
local current_timestamp = os.time()
|
local current_timestamp = os.time()
|
||||||
if is_streaming then
|
if is_streaming then
|
||||||
|
|||||||
@@ -58,7 +58,10 @@ M.returns = {
|
|||||||
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
|
--- IMPORTANT: Using "the_content" instead of "content" is to avoid LLM streaming generating function parameters in alphabetical order, which would result in generating "path" after "content", making it impossible to achieve a stream diff view.
|
||||||
---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string, streaming?: boolean, tool_use_id?: string }>
|
---@type AvanteLLMToolFunc<{ path: string, content: string, the_content?: string, streaming?: boolean, tool_use_id?: string }>
|
||||||
function M.func(opts, on_log, on_complete, session_ctx)
|
function M.func(opts, on_log, on_complete, session_ctx)
|
||||||
if opts.the_content ~= nil then opts.content = opts.the_content end
|
if opts.the_content ~= nil then
|
||||||
|
opts.content = opts.the_content
|
||||||
|
opts.the_content = nil
|
||||||
|
end
|
||||||
if not on_complete then return false, "on_complete not provided" end
|
if not on_complete then return false, "on_complete not provided" end
|
||||||
local abs_path = Helpers.get_abs_path(opts.path)
|
local abs_path = Helpers.get_abs_path(opts.path)
|
||||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
content = content_block.text,
|
content = content_block.text,
|
||||||
}, {
|
}, {
|
||||||
state = "generating",
|
state = "generating",
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
content_block.uuid = msg.uuid
|
content_block.uuid = msg.uuid
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
@@ -185,7 +185,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
},
|
},
|
||||||
}, {
|
}, {
|
||||||
state = "generating",
|
state = "generating",
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
content_block.uuid = msg.uuid
|
content_block.uuid = msg.uuid
|
||||||
opts.on_messages_add({ msg })
|
opts.on_messages_add({ msg })
|
||||||
@@ -205,11 +205,11 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
},
|
},
|
||||||
}, {
|
}, {
|
||||||
state = "generating",
|
state = "generating",
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
content_block.uuid = msg.uuid
|
content_block.uuid = msg.uuid
|
||||||
opts.on_messages_add({ msg })
|
opts.on_messages_add({ msg })
|
||||||
opts.on_stop({ reason = "tool_use", streaming_tool_use = true })
|
-- opts.on_stop({ reason = "tool_use", streaming_tool_use = true })
|
||||||
end
|
end
|
||||||
elseif event_state == "content_block_delta" then
|
elseif event_state == "content_block_delta" then
|
||||||
local ok, jsn = pcall(vim.json.decode, data_stream)
|
local ok, jsn = pcall(vim.json.decode, data_stream)
|
||||||
@@ -234,7 +234,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = "generating",
|
state = "generating",
|
||||||
uuid = content_block.uuid,
|
uuid = content_block.uuid,
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
elseif jsn.delta.type == "text_delta" then
|
elseif jsn.delta.type == "text_delta" then
|
||||||
@@ -246,7 +246,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = "generating",
|
state = "generating",
|
||||||
uuid = content_block.uuid,
|
uuid = content_block.uuid,
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
elseif jsn.delta.type == "signature_delta" then
|
elseif jsn.delta.type == "signature_delta" then
|
||||||
@@ -265,7 +265,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = "generated",
|
state = "generated",
|
||||||
uuid = content_block.uuid,
|
uuid = content_block.uuid,
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
end
|
end
|
||||||
@@ -284,7 +284,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = "generated",
|
state = "generated",
|
||||||
uuid = content_block.uuid,
|
uuid = content_block.uuid,
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
end
|
end
|
||||||
@@ -308,7 +308,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = "generated",
|
state = "generated",
|
||||||
uuid = content_block.uuid,
|
uuid = content_block.uuid,
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
end
|
end
|
||||||
@@ -317,6 +317,8 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
|||||||
if not ok then return end
|
if not ok then return end
|
||||||
if jsn.delta.stop_reason == "end_turn" then
|
if jsn.delta.stop_reason == "end_turn" then
|
||||||
opts.on_stop({ reason = "complete", usage = jsn.usage })
|
opts.on_stop({ reason = "complete", usage = jsn.usage })
|
||||||
|
elseif jsn.delta.stop_reason == "max_tokens" then
|
||||||
|
opts.on_stop({ reason = "max_tokens", usage = jsn.usage })
|
||||||
elseif jsn.delta.stop_reason == "tool_use" then
|
elseif jsn.delta.stop_reason == "tool_use" then
|
||||||
opts.on_stop({
|
opts.on_stop({
|
||||||
reason = "tool_use",
|
reason = "tool_use",
|
||||||
|
|||||||
@@ -243,7 +243,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
if not ctx.function_call_id then ctx.function_call_id = 0 end
|
if not ctx.function_call_id then ctx.function_call_id = 0 end
|
||||||
ctx.function_call_id = ctx.function_call_id + 1
|
ctx.function_call_id = ctx.function_call_id + 1
|
||||||
local tool_use = {
|
local tool_use = {
|
||||||
id = ctx.session_id .. "-" .. tostring(ctx.function_call_id),
|
id = ctx.turn_id .. "-" .. tostring(ctx.function_call_id),
|
||||||
name = part.functionCall.name,
|
name = part.functionCall.name,
|
||||||
input_json = vim.json.encode(part.functionCall.args),
|
input_json = vim.json.encode(part.functionCall.args),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -264,7 +264,6 @@ function M:add_text_message(ctx, text, state, opts)
|
|||||||
local cleaned_xml_content = table.concat(cleaned_xml_lines, "\n")
|
local cleaned_xml_content = table.concat(cleaned_xml_lines, "\n")
|
||||||
local stream_parser = XMLParser.createStreamParser()
|
local stream_parser = XMLParser.createStreamParser()
|
||||||
stream_parser:addData(cleaned_xml_content)
|
stream_parser:addData(cleaned_xml_content)
|
||||||
local has_tool_use = false
|
|
||||||
local xml = stream_parser:getAllElements()
|
local xml = stream_parser:getAllElements()
|
||||||
if xml then
|
if xml then
|
||||||
local new_content_list = {}
|
local new_content_list = {}
|
||||||
@@ -318,7 +317,7 @@ function M:add_text_message(ctx, text, state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = state,
|
state = state,
|
||||||
uuid = msg_uuid,
|
uuid = msg_uuid,
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
msgs[#msgs + 1] = msg_
|
msgs[#msgs + 1] = msg_
|
||||||
ctx.tool_use_list = ctx.tool_use_list or {}
|
ctx.tool_use_list = ctx.tool_use_list or {}
|
||||||
@@ -327,14 +326,13 @@ function M:add_text_message(ctx, text, state, opts)
|
|||||||
name = item._name,
|
name = item._name,
|
||||||
input_json = input,
|
input_json = input,
|
||||||
}
|
}
|
||||||
has_tool_use = true
|
|
||||||
end
|
end
|
||||||
if #new_content_list > 0 then msg.message.content = table.concat(new_content_list, "\n") end
|
if #new_content_list > 0 then msg.message.content = table.concat(new_content_list, "\n") end
|
||||||
::continue::
|
::continue::
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if opts.on_messages_add then opts.on_messages_add(msgs) end
|
if opts.on_messages_add then opts.on_messages_add(msgs) end
|
||||||
if has_tool_use and state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
|
-- if has_tool_use and state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M:add_thinking_message(ctx, text, state, opts)
|
function M:add_thinking_message(ctx, text, state, opts)
|
||||||
@@ -352,7 +350,7 @@ function M:add_thinking_message(ctx, text, state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = state,
|
state = state,
|
||||||
uuid = ctx.reasonging_content_uuid,
|
uuid = ctx.reasonging_content_uuid,
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
ctx.reasonging_content_uuid = msg.uuid
|
ctx.reasonging_content_uuid = msg.uuid
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
@@ -373,38 +371,25 @@ function M:add_tool_use_message(ctx, tool_use, state, opts)
|
|||||||
}, {
|
}, {
|
||||||
state = state,
|
state = state,
|
||||||
uuid = tool_use.uuid,
|
uuid = tool_use.uuid,
|
||||||
session_id = ctx.session_id,
|
turn_id = ctx.turn_id,
|
||||||
})
|
})
|
||||||
tool_use.uuid = msg.uuid
|
tool_use.uuid = msg.uuid
|
||||||
tool_use.state = state
|
tool_use.state = state
|
||||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||||
if state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
|
-- if state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M:parse_response(ctx, data_stream, _, opts)
|
function M:parse_response(ctx, data_stream, _, opts)
|
||||||
local orig_on_stop = opts.on_stop
|
if data_stream:match('"%[DONE%]":') or data_stream == "[DONE]" then
|
||||||
local stopped = false
|
|
||||||
---@param stop_opts AvanteLLMStopCallbackOptions
|
|
||||||
opts.on_stop = function(stop_opts)
|
|
||||||
if stop_opts and not stop_opts.streaming_tool_use then
|
|
||||||
if stopped then return end
|
|
||||||
stopped = true
|
|
||||||
end
|
|
||||||
return orig_on_stop(stop_opts)
|
|
||||||
end
|
|
||||||
if data_stream:match('"%[DONE%]":') then
|
|
||||||
self:finish_pending_messages(ctx, opts)
|
self:finish_pending_messages(ctx, opts)
|
||||||
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
|
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
|
||||||
|
ctx.tool_use_list = {}
|
||||||
opts.on_stop({ reason = "tool_use" })
|
opts.on_stop({ reason = "tool_use" })
|
||||||
else
|
else
|
||||||
opts.on_stop({ reason = "complete" })
|
opts.on_stop({ reason = "complete" })
|
||||||
end
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
if data_stream == "[DONE]" then
|
|
||||||
opts.on_stop({ reason = "complete" })
|
|
||||||
return
|
|
||||||
end
|
|
||||||
local jsn = vim.json.decode(data_stream)
|
local jsn = vim.json.decode(data_stream)
|
||||||
---@cast jsn AvanteOpenAIChatResponse
|
---@cast jsn AvanteOpenAIChatResponse
|
||||||
if not jsn.choices then return end
|
if not jsn.choices then return end
|
||||||
@@ -453,7 +438,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
|||||||
else
|
else
|
||||||
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
||||||
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
||||||
self:add_tool_use_message(ctx, tool_use, "generating", opts)
|
-- self:add_tool_use_message(ctx, tool_use, "generating", opts)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
elseif delta.content then
|
elseif delta.content then
|
||||||
|
|||||||
@@ -2003,7 +2003,6 @@ end
|
|||||||
function Sidebar:add_chat_history(messages, options)
|
function Sidebar:add_chat_history(messages, options)
|
||||||
options = options or {}
|
options = options or {}
|
||||||
messages = vim.islist(messages) and messages or { messages }
|
messages = vim.islist(messages) and messages or { messages }
|
||||||
self:reload_chat_history()
|
|
||||||
local is_first_user = true
|
local is_first_user = true
|
||||||
local history_messages = {}
|
local history_messages = {}
|
||||||
for _, message in ipairs(messages) do
|
for _, message in ipairs(messages) do
|
||||||
@@ -2191,44 +2190,6 @@ function Sidebar:get_history_messages_for_api(opts)
|
|||||||
history_messages0 = picked_messages
|
history_messages0 = picked_messages
|
||||||
end
|
end
|
||||||
|
|
||||||
local picked_messages = {}
|
|
||||||
local max_tool_use_count = 15
|
|
||||||
local tool_use_count = 0
|
|
||||||
for idx = #history_messages0, 1, -1 do
|
|
||||||
local msg = history_messages0[idx]
|
|
||||||
if tool_use_count > max_tool_use_count then
|
|
||||||
if Utils.is_tool_result_message(msg) then
|
|
||||||
local tool_use_message = Utils.get_tool_use_message(msg, history_messages0)
|
|
||||||
if tool_use_message then
|
|
||||||
local msg_content = {}
|
|
||||||
table.insert(
|
|
||||||
msg_content,
|
|
||||||
string.format(
|
|
||||||
"Tool use %s(%s)",
|
|
||||||
tool_use_message.message.content[1].name,
|
|
||||||
vim.json.encode(tool_use_message.message.content[1].input)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
table.insert(msg_content, string.format("Result: %s", msg.message.content[1].content))
|
|
||||||
table.insert(
|
|
||||||
picked_messages,
|
|
||||||
1,
|
|
||||||
HistoryMessage:new({ role = "user", content = msg_content }, { is_dummy = true })
|
|
||||||
)
|
|
||||||
end
|
|
||||||
elseif Utils.is_tool_use_message(history_messages0[idx]) then
|
|
||||||
tool_use_count = tool_use_count + 1
|
|
||||||
goto continue
|
|
||||||
else
|
|
||||||
table.insert(picked_messages, 1, msg)
|
|
||||||
end
|
|
||||||
else
|
|
||||||
if Utils.is_tool_use_message(history_messages0[idx]) then tool_use_count = tool_use_count + 1 end
|
|
||||||
table.insert(picked_messages, 1, msg)
|
|
||||||
end
|
|
||||||
::continue::
|
|
||||||
end
|
|
||||||
|
|
||||||
local tool_id_to_tool_name = {}
|
local tool_id_to_tool_name = {}
|
||||||
local tool_id_to_path = {}
|
local tool_id_to_path = {}
|
||||||
local tool_id_to_start_line = {}
|
local tool_id_to_start_line = {}
|
||||||
@@ -2241,6 +2202,7 @@ function Sidebar:get_history_messages_for_api(opts)
|
|||||||
for idx, message in ipairs(history_messages0) do
|
for idx, message in ipairs(history_messages0) do
|
||||||
if Utils.is_tool_result_message(message) then
|
if Utils.is_tool_result_message(message) then
|
||||||
local tool_use_message = Utils.get_tool_use_message(message, history_messages0)
|
local tool_use_message = Utils.get_tool_use_message(message, history_messages0)
|
||||||
|
|
||||||
local is_edit_func_call, _, _, path = Utils.is_edit_func_call_message(tool_use_message)
|
local is_edit_func_call, _, _, path = Utils.is_edit_func_call_message(tool_use_message)
|
||||||
|
|
||||||
local tool_result = message.message.content[1]
|
local tool_result = message.message.content[1]
|
||||||
@@ -2264,8 +2226,8 @@ function Sidebar:get_history_messages_for_api(opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
for idx, message in ipairs(history_messages0) do
|
for idx, message in ipairs(history_messages0) do
|
||||||
if Utils.is_tool_use_message(message) and failed_edit_tool_ids[message.message.content[1].id] then
|
if Utils.is_tool_use_message(message) then
|
||||||
goto continue
|
if failed_edit_tool_ids[message.message.content[1].id] then goto continue end
|
||||||
end
|
end
|
||||||
table.insert(history_messages, message)
|
table.insert(history_messages, message)
|
||||||
if Utils.is_tool_result_message(message) then
|
if Utils.is_tool_result_message(message) then
|
||||||
@@ -2422,6 +2384,62 @@ function Sidebar:get_history_messages_for_api(opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local picked_messages = {}
|
||||||
|
local max_tool_use_count = 10
|
||||||
|
local tool_use_count = 0
|
||||||
|
for idx = #history_messages, 1, -1 do
|
||||||
|
local msg = history_messages[idx]
|
||||||
|
if tool_use_count > max_tool_use_count then
|
||||||
|
if Utils.is_tool_result_message(msg) then
|
||||||
|
local tool_use_message = Utils.get_tool_use_message(msg, history_messages)
|
||||||
|
if tool_use_message then
|
||||||
|
table.insert(
|
||||||
|
picked_messages,
|
||||||
|
1,
|
||||||
|
HistoryMessage:new({
|
||||||
|
role = "user",
|
||||||
|
content = {
|
||||||
|
{
|
||||||
|
type = "text",
|
||||||
|
text = string.format(
|
||||||
|
"Tool use [%s] is successful: %s",
|
||||||
|
tool_use_message.message.content[1].name,
|
||||||
|
tostring(not msg.message.content[1].is_error)
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, { is_dummy = true })
|
||||||
|
)
|
||||||
|
local msg_content = {}
|
||||||
|
table.insert(msg_content, {
|
||||||
|
type = "text",
|
||||||
|
text = string.format(
|
||||||
|
"Tool use %s(%s)",
|
||||||
|
tool_use_message.message.content[1].name,
|
||||||
|
vim.json.encode(tool_use_message.message.content[1].input)
|
||||||
|
),
|
||||||
|
})
|
||||||
|
table.insert(
|
||||||
|
picked_messages,
|
||||||
|
1,
|
||||||
|
HistoryMessage:new({ role = "assistant", content = msg_content }, { is_dummy = true })
|
||||||
|
)
|
||||||
|
end
|
||||||
|
elseif Utils.is_tool_use_message(msg) then
|
||||||
|
tool_use_count = tool_use_count + 1
|
||||||
|
goto continue
|
||||||
|
else
|
||||||
|
table.insert(picked_messages, 1, msg)
|
||||||
|
end
|
||||||
|
else
|
||||||
|
if Utils.is_tool_use_message(msg) then tool_use_count = tool_use_count + 1 end
|
||||||
|
table.insert(picked_messages, 1, msg)
|
||||||
|
end
|
||||||
|
::continue::
|
||||||
|
end
|
||||||
|
|
||||||
|
history_messages = picked_messages
|
||||||
|
|
||||||
local final_history_messages = {}
|
local final_history_messages = {}
|
||||||
for _, msg in ipairs(history_messages) do
|
for _, msg in ipairs(history_messages) do
|
||||||
local tool_result_message
|
local tool_result_message
|
||||||
@@ -2704,9 +2722,6 @@ function Sidebar:create_input_container()
|
|||||||
Path.history.save(self.code.bufnr, self.chat_history)
|
Path.history.save(self.code.bufnr, self.chat_history)
|
||||||
end
|
end
|
||||||
|
|
||||||
local history_messages = Utils.get_history_messages(self.chat_history)
|
|
||||||
local is_first_request = #history_messages == 0
|
|
||||||
|
|
||||||
if request and request ~= "" then
|
if request and request ~= "" then
|
||||||
self:add_history_messages({
|
self:add_history_messages({
|
||||||
HistoryMessage:new({
|
HistoryMessage:new({
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field is_dummy boolean | nil
|
---@field is_dummy boolean | nil
|
||||||
---@field is_compacted boolean | nil
|
---@field is_compacted boolean | nil
|
||||||
---@field is_deleted boolean | nil
|
---@field is_deleted boolean | nil
|
||||||
---@field session_id string | nil
|
---@field turn_id string | nil
|
||||||
---
|
---
|
||||||
---@class AvanteLLMToolResult
|
---@class AvanteLLMToolResult
|
||||||
---@field tool_name string
|
---@field tool_name string
|
||||||
@@ -278,7 +278,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field usage? AvanteLLMUsage
|
---@field usage? AvanteLLMUsage
|
||||||
---
|
---
|
||||||
---@class AvanteLLMStopCallbackOptions
|
---@class AvanteLLMStopCallbackOptions
|
||||||
---@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled"
|
---@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled" | "max_tokens"
|
||||||
---@field error? string | table
|
---@field error? string | table
|
||||||
---@field usage? AvanteLLMUsage
|
---@field usage? AvanteLLMUsage
|
||||||
---@field retry_after? integer
|
---@field retry_after? integer
|
||||||
|
|||||||
@@ -993,10 +993,14 @@ function M.open_buffer(path, set_current_buf)
|
|||||||
|
|
||||||
local abs_path = M.join_paths(M.get_project_root(), path)
|
local abs_path = M.join_paths(M.get_project_root(), path)
|
||||||
|
|
||||||
local bufnr = vim.fn.bufnr(abs_path, true)
|
local bufnr
|
||||||
vim.fn.bufload(bufnr)
|
if set_current_buf then
|
||||||
|
vim.cmd("noautocmd edit " .. abs_path)
|
||||||
if set_current_buf then vim.api.nvim_set_current_buf(bufnr) end
|
bufnr = vim.api.nvim_get_current_buf()
|
||||||
|
else
|
||||||
|
bufnr = vim.fn.bufnr(abs_path, true)
|
||||||
|
pcall(vim.fn.bufload, bufnr)
|
||||||
|
end
|
||||||
|
|
||||||
vim.cmd("filetype detect")
|
vim.cmd("filetype detect")
|
||||||
|
|
||||||
@@ -1480,10 +1484,6 @@ function M.is_edit_func_call_tool_use(tool_use)
|
|||||||
local is_str_replace_editor_func_call = false
|
local is_str_replace_editor_func_call = false
|
||||||
local is_str_replace_based_edit_tool_func_call = false
|
local is_str_replace_based_edit_tool_func_call = false
|
||||||
local path = nil
|
local path = nil
|
||||||
if tool_use.name == "write_to_file" then
|
|
||||||
is_replace_func_call = true
|
|
||||||
path = tool_use.input.path
|
|
||||||
end
|
|
||||||
if tool_use.name == "replace_in_file" then
|
if tool_use.name == "replace_in_file" then
|
||||||
is_replace_func_call = true
|
is_replace_func_call = true
|
||||||
path = tool_use.input.path
|
path = tool_use.input.path
|
||||||
@@ -1711,10 +1711,17 @@ end
|
|||||||
---@param history_messages avante.HistoryMessage[]
|
---@param history_messages avante.HistoryMessage[]
|
||||||
---@return AvantePartialLLMToolUse[]
|
---@return AvantePartialLLMToolUse[]
|
||||||
function M.get_uncalled_tool_uses(history_messages)
|
function M.get_uncalled_tool_uses(history_messages)
|
||||||
local partial_tool_use_list = {} ---@type AvantePartialLLMToolUse[]
|
local last_turn_id = nil
|
||||||
|
if #history_messages > 0 then last_turn_id = history_messages[#history_messages].turn_id end
|
||||||
|
local uncalled_tool_use_list = {} ---@type AvantePartialLLMToolUse[]
|
||||||
local tool_result_seen = {}
|
local tool_result_seen = {}
|
||||||
for idx = #history_messages, 1, -1 do
|
for idx = #history_messages, 1, -1 do
|
||||||
local message = history_messages[idx]
|
local message = history_messages[idx]
|
||||||
|
if last_turn_id then
|
||||||
|
if message.turn_id ~= last_turn_id then break end
|
||||||
|
else
|
||||||
|
if not M.is_tool_use_message(message) and not M.is_tool_result_message(message) then break end
|
||||||
|
end
|
||||||
local content = message.message.content
|
local content = message.message.content
|
||||||
if type(content) ~= "table" or #content == 0 then goto continue end
|
if type(content) ~= "table" or #content == 0 then goto continue end
|
||||||
local is_break = false
|
local is_break = false
|
||||||
@@ -1727,7 +1734,7 @@ function M.get_uncalled_tool_uses(history_messages)
|
|||||||
input = item.input,
|
input = item.input,
|
||||||
state = message.state,
|
state = message.state,
|
||||||
}
|
}
|
||||||
table.insert(partial_tool_use_list, 1, partial_tool_use)
|
table.insert(uncalled_tool_use_list, 1, partial_tool_use)
|
||||||
else
|
else
|
||||||
is_break = true
|
is_break = true
|
||||||
break
|
break
|
||||||
@@ -1738,7 +1745,7 @@ function M.get_uncalled_tool_uses(history_messages)
|
|||||||
if is_break then break end
|
if is_break then break end
|
||||||
::continue::
|
::continue::
|
||||||
end
|
end
|
||||||
return partial_tool_use_list
|
return uncalled_tool_use_list
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.call_once(func)
|
function M.call_once(func)
|
||||||
|
|||||||
Reference in New Issue
Block a user