diff --git a/lua/avante/history_message.lua b/lua/avante/history_message.lua index e9e1340..4d1e6d0 100644 --- a/lua/avante/history_message.lua +++ b/lua/avante/history_message.lua @@ -5,7 +5,7 @@ local M = {} M.__index = M ---@param message AvanteLLMMessage ----@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean} +---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean, is_dummy?: boolean, session_id?: string} ---@return avante.HistoryMessage function M:new(message, opts) opts = opts or {} @@ -23,6 +23,7 @@ function M:new(message, opts) if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end if opts.is_dummy ~= nil then obj.is_dummy = opts.is_dummy end + obj.session_id = opts.session_id return obj end diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index fa3d448..df04054 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -166,6 +166,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) content = content_block.text, }, { state = "generating", + session_id = ctx.session_id, }) content_block.uuid = msg.uuid if opts.on_messages_add then opts.on_messages_add({ msg }) end @@ -184,6 +185,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, }, { state = "generating", + session_id = ctx.session_id, }) content_block.uuid = msg.uuid opts.on_messages_add({ msg }) @@ -203,6 +205,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, }, { state = "generating", + session_id = ctx.session_id, }) content_block.uuid = msg.uuid opts.on_messages_add({ msg }) @@ -231,6 +234,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, { state = "generating", uuid = content_block.uuid, + session_id = ctx.session_id, }) if opts.on_messages_add then opts.on_messages_add({ msg }) end elseif jsn.delta.type == "text_delta" then @@ -242,6 +246,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, { state = "generating", uuid = content_block.uuid, + session_id = ctx.session_id, }) if opts.on_messages_add then opts.on_messages_add({ msg }) end elseif jsn.delta.type == "signature_delta" then @@ -260,6 +265,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, { state = "generated", uuid = content_block.uuid, + session_id = ctx.session_id, }) if opts.on_messages_add then opts.on_messages_add({ msg }) end end @@ -278,6 +284,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, { state = "generated", uuid = content_block.uuid, + session_id = ctx.session_id, }) if opts.on_messages_add then opts.on_messages_add({ msg }) end end @@ -301,6 +308,7 @@ function M:parse_response(ctx, data_stream, event_state, opts) }, { state = "generated", uuid = content_block.uuid, + session_id = ctx.session_id, }) if opts.on_messages_add then opts.on_messages_add({ msg }) end end diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 866ece6..13fbcf5 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -248,7 +248,7 @@ function M:parse_response(ctx, data_stream, _, opts) input_json = vim.json.encode(part.functionCall.args), } table.insert(ctx.tool_use_list, tool_use) - OpenAI:add_tool_use_message(tool_use, "generated", opts) + OpenAI:add_tool_use_message(ctx, tool_use, "generated", opts) end end end diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 61145be..b79dcac 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -215,7 +215,7 @@ function M:finish_pending_messages(ctx, opts) if ctx.content ~= nil and ctx.content ~= "" then self:add_text_message(ctx, "", "generated", opts) end if ctx.tool_use_list then for _, tool_use in ipairs(ctx.tool_use_list) do - if tool_use.state == "generating" then self:add_tool_use_message(tool_use, "generated", opts) end + if tool_use.state == "generating" then self:add_tool_use_message(ctx, tool_use, "generated", opts) end end end end @@ -318,6 +318,7 @@ function M:add_text_message(ctx, text, state, opts) }, { state = state, uuid = msg_uuid, + session_id = ctx.session_id, }) msgs[#msgs + 1] = msg_ ctx.tool_use_list = ctx.tool_use_list or {} @@ -351,12 +352,13 @@ function M:add_thinking_message(ctx, text, state, opts) }, { state = state, uuid = ctx.reasonging_content_uuid, + session_id = ctx.session_id, }) ctx.reasonging_content_uuid = msg.uuid if opts.on_messages_add then opts.on_messages_add({ msg }) end end -function M:add_tool_use_message(tool_use, state, opts) +function M:add_tool_use_message(ctx, tool_use, state, opts) local jsn = JsonParser.parse(tool_use.input_json) local msg = HistoryMessage:new({ role = "assistant", @@ -371,6 +373,7 @@ function M:add_tool_use_message(tool_use, state, opts) }, { state = state, uuid = tool_use.uuid, + session_id = ctx.session_id, }) tool_use.uuid = msg.uuid tool_use.state = state @@ -438,7 +441,7 @@ function M:parse_response(ctx, data_stream, _, opts) if not ctx.tool_use_list[tool_call.index + 1] then if tool_call.index > 0 and ctx.tool_use_list[tool_call.index] then local prev_tool_use = ctx.tool_use_list[tool_call.index] - self:add_tool_use_message(prev_tool_use, "generated", opts) + self:add_tool_use_message(ctx, prev_tool_use, "generated", opts) end local tool_use = { name = tool_call["function"].name, @@ -446,11 +449,11 @@ function M:parse_response(ctx, data_stream, _, opts) input_json = type(tool_call["function"].arguments) == "string" and tool_call["function"].arguments or "", } ctx.tool_use_list[tool_call.index + 1] = tool_use - self:add_tool_use_message(tool_use, "generating", opts) + self:add_tool_use_message(ctx, tool_use, "generating", opts) else local tool_use = ctx.tool_use_list[tool_call.index + 1] tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments - self:add_tool_use_message(tool_use, "generating", opts) + self:add_tool_use_message(ctx, tool_use, "generating", opts) end end elseif delta.content then diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index d7fb6bb..b48d5aa 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2176,6 +2176,11 @@ function Sidebar:get_history_messages_for_api(opts) if opts.all then return history_messages0 end + history_messages0 = vim + .iter(history_messages0) + :filter(function(message) return message.state ~= "generating" end) + :totable() + if self.chat_history and self.chat_history.memory then local picked_messages = {} for idx = #history_messages0, 1, -1 do @@ -2186,6 +2191,44 @@ function Sidebar:get_history_messages_for_api(opts) history_messages0 = picked_messages end + local picked_messages = {} + local max_tool_use_count = 15 + local tool_use_count = 0 + for idx = #history_messages0, 1, -1 do + local msg = history_messages0[idx] + if tool_use_count > max_tool_use_count then + if Utils.is_tool_result_message(msg) then + local tool_use_message = Utils.get_tool_use_message(msg, history_messages0) + if tool_use_message then + local msg_content = {} + table.insert( + msg_content, + string.format( + "Tool use %s(%s)", + tool_use_message.message.content[1].name, + vim.json.encode(tool_use_message.message.content[1].input) + ) + ) + table.insert(msg_content, string.format("Result: %s", msg.message.content[1].content)) + table.insert( + picked_messages, + 1, + HistoryMessage:new({ role = "user", content = msg_content }, { is_dummy = true }) + ) + end + elseif Utils.is_tool_use_message(history_messages0[idx]) then + tool_use_count = tool_use_count + 1 + goto continue + else + table.insert(picked_messages, 1, msg) + end + else + if Utils.is_tool_use_message(history_messages0[idx]) then tool_use_count = tool_use_count + 1 end + table.insert(picked_messages, 1, msg) + end + ::continue:: + end + local tool_id_to_tool_name = {} local tool_id_to_path = {} local tool_id_to_start_line = {} diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 7f975be..76ca830 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -110,6 +110,7 @@ vim.g.avante_login = vim.g.avante_login ---@field is_dummy boolean | nil ---@field is_compacted boolean | nil ---@field is_deleted boolean | nil +---@field session_id string | nil --- ---@class AvanteLLMToolResult ---@field tool_name string