diff --git a/lua/avante/config.lua b/lua/avante/config.lua
index 3c44b96..2a15385 100644
--- a/lua/avante/config.lua
+++ b/lua/avante/config.lua
@@ -237,6 +237,7 @@ M._defaults = {
       endpoint = "https://api.openai.com/v1",
       model = "gpt-4o",
       timeout = 30000, -- Timeout in milliseconds, increase this for reasoning models
+      context_window = 128000, -- Number of tokens to send to the model for context
       extra_request_body = {
         temperature = 0.75,
         max_completion_tokens = 16384, -- Increase this to include reasoning tokens (for reasoning models)
@@ -250,6 +251,7 @@ M._defaults = {
       proxy = nil, -- [protocol://]host[:port] Use this proxy
       allow_insecure = false, -- Allow insecure server connections
       timeout = 30000, -- Timeout in milliseconds
+      context_window = 128000, -- Number of tokens to send to the model for context
       extra_request_body = {
         temperature = 0.75,
         max_tokens = 20480,
@@ -294,6 +296,7 @@ M._defaults = {
       endpoint = "https://generativelanguage.googleapis.com/v1beta/models",
       model = "gemini-2.0-flash",
       timeout = 30000, -- Timeout in milliseconds
+      context_window = 1048576,
       extra_request_body = {
         generationConfig = {
           temperature = 0.75,
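
-- The new `context_window` default above is a plain per-provider field, so it
-- can be overridden like any other provider option. A minimal sketch, assuming
-- the standard require("avante").setup() entry point and the provider table
-- layout shown in the config.lua hunks; the 200000 value is purely
-- illustrative:
require("avante").setup({
  providers = {
    openai = {
      context_window = 200000, -- hypothetical larger-window model
    },
  },
})
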
diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua
index 5b70da4..87c63ce 100644
--- a/lua/avante/llm.lua
+++ b/lua/avante/llm.lua
@@ -69,18 +69,22 @@ function M.summarize_memory(prev_memory, history_messages, cb)
     cb(nil)
     return
   end
-  local latest_timestamp = history_messages[#history_messages].timestamp
-  local latest_message_uuid = history_messages[#history_messages].uuid
+  local latest_timestamp = nil
+  local latest_message_uuid = nil
+  for idx = #history_messages, 1, -1 do
+    local message = history_messages[idx]
+    if not message.is_dummy then
+      latest_timestamp = message.timestamp
+      latest_message_uuid = message.uuid
+      break
+    end
+  end
+  if not latest_timestamp or not latest_message_uuid then
+    cb(nil)
+    return
+  end
   local conversation_items = vim
     .iter(history_messages)
-    :filter(function(msg)
-      if msg.just_for_display then return false end
-      if msg.message.role ~= "assistant" and msg.message.role ~= "user" then return false end
-      local content = msg.message.content
-      if type(content) == "table" and content[1].type == "tool_result" then return false end
-      if type(content) == "table" and content[1].type == "tool_use" then return false end
-      return true
-    end)
     :map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg, history_messages) end)
     :totable()
   local conversation_text = table.concat(conversation_items, "\n")
@@ -200,9 +204,6 @@ end
 function M.generate_prompts(opts)
   local provider = opts.provider or Providers[Config.provider]
   local mode = opts.mode or Config.mode
-  ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
-  local _, request_body = Providers.parse_config(provider)
-  local max_tokens = request_body.max_tokens or 4096
 
   -- Check if the instructions contains an image path
   local image_paths = {}
@@ -322,18 +323,27 @@ end
   end
-  local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
-
-  for _, message in ipairs(context_messages) do
-    remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
-  end
-
   local pending_compaction_history_messages = {}
   if opts.prompt_opts and opts.prompt_opts.pending_compaction_history_messages then
     pending_compaction_history_messages =
      vim.list_extend(pending_compaction_history_messages, opts.prompt_opts.pending_compaction_history_messages)
   end
 
+  local context_window = provider.context_window
+
+  if context_window and context_window > 0 then
+    Utils.debug("Context window", context_window)
+    if opts.get_tokens_usage then
+      local tokens_usage = opts.get_tokens_usage()
+      if tokens_usage then
+        local target_tokens = context_window * 0.9
+        local tokens_count = tokens_usage.prompt_tokens + tokens_usage.completion_tokens
+        Utils.debug("Tokens count", tokens_count)
+        if tokens_count > target_tokens then pending_compaction_history_messages = opts.history_messages end
+      end
+    end
+  end
+
   ---@type AvanteLLMMessage[]
   local messages = vim.deepcopy(context_messages)
   for _, msg in ipairs(opts.history_messages or {}) do
@@ -674,9 +684,12 @@ function M._stream(opts)
   local handler_opts = {
     on_messages_add = opts.on_messages_add,
     on_state_change = opts.on_state_change,
+    update_tokens_usage = opts.update_tokens_usage,
     on_start = opts.on_start,
     on_chunk = opts.on_chunk,
     on_stop = function(stop_opts)
+      if stop_opts.usage and opts.update_tokens_usage then opts.update_tokens_usage(stop_opts.usage) end
+
       ---@param partial_tool_use_list AvantePartialLLMToolUse[]
       ---@param tool_use_index integer
       ---@param tool_results AvanteLLMToolResult[]
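
-- The compaction trigger added to generate_prompts above is a pure threshold
-- check, shown here as a self-contained sketch (the function name is
-- hypothetical; only the arithmetic mirrors the hunk): once the reported
-- prompt + completion tokens exceed 90% of the provider's context window, the
-- whole history is handed over as pending_compaction_history_messages instead
-- of being trimmed by the old per-message max_tokens estimate.
local function needs_compaction(context_window, tokens_usage)
  if not context_window or context_window <= 0 or not tokens_usage then return false end
  local target_tokens = context_window * 0.9
  return tokens_usage.prompt_tokens + tokens_usage.completion_tokens > target_tokens
end

assert(needs_compaction(128000, { prompt_tokens = 110000, completion_tokens = 8000 }))      -- 118000 > 115200
assert(not needs_compaction(128000, { prompt_tokens = 100000, completion_tokens = 10000 })) -- 110000 < 115200
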
diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua
index 1b3f69b..4253bb9 100644
--- a/lua/avante/providers/claude.lua
+++ b/lua/avante/providers/claude.lua
@@ -133,6 +133,18 @@ function M:parse_messages(opts)
   return messages
 end
 
+---@param usage avante.AnthropicTokenUsage | nil
+---@return avante.LLMTokenUsage | nil
+function M.transform_anthropic_usage(usage)
+  if not usage then return nil end
+  ---@type avante.LLMTokenUsage
+  local res = {
+    prompt_tokens = usage.input_tokens + usage.cache_creation_input_tokens,
+    completion_tokens = usage.output_tokens + usage.cache_read_input_tokens,
+  }
+  return res
+end
+
 function M:parse_response(ctx, data_stream, event_state, opts)
   if event_state == nil then
     if data_stream:match('"message_start"') then
@@ -153,7 +165,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
   if event_state == "message_start" then
     local ok, jsn = pcall(vim.json.decode, data_stream)
     if not ok then return end
-    opts.on_start(jsn.message.usage)
+    ctx.usage = jsn.message.usage
   elseif event_state == "content_block_start" then
     local ok, jsn = pcall(vim.json.decode, data_stream)
     if not ok then return end
@@ -315,14 +327,15 @@ function M:parse_response(ctx, data_stream, event_state, opts)
   elseif event_state == "message_delta" then
     local ok, jsn = pcall(vim.json.decode, data_stream)
     if not ok then return end
+    if jsn.usage and ctx.usage then ctx.usage.output_tokens = ctx.usage.output_tokens + jsn.usage.output_tokens end
     if jsn.delta.stop_reason == "end_turn" then
-      opts.on_stop({ reason = "complete", usage = jsn.usage })
+      opts.on_stop({ reason = "complete", usage = self.transform_anthropic_usage(ctx.usage) })
     elseif jsn.delta.stop_reason == "max_tokens" then
-      opts.on_stop({ reason = "max_tokens", usage = jsn.usage })
+      opts.on_stop({ reason = "max_tokens", usage = self.transform_anthropic_usage(ctx.usage) })
     elseif jsn.delta.stop_reason == "tool_use" then
       opts.on_stop({
         reason = "tool_use",
-        usage = jsn.usage,
+        usage = self.transform_anthropic_usage(ctx.usage),
       })
     end
     return
diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua
index 9799d63..de195ac 100644
--- a/lua/avante/providers/gemini.lua
+++ b/lua/avante/providers/gemini.lua
@@ -209,16 +209,33 @@ function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, request_body)
   return vim.tbl_deep_extend("force", {}, provider_instance:parse_messages(prompt_opts), request_body)
 end
 
+---@param usage avante.GeminiTokenUsage | nil
+---@return avante.LLMTokenUsage | nil
+function M.transform_gemini_usage(usage)
+  if not usage then return nil end
+  ---@type avante.LLMTokenUsage
+  local res = {
+    prompt_tokens = usage.promptTokenCount,
+    completion_tokens = usage.candidatesTokenCount,
+  }
+  return res
+end
+
 function M:parse_response(ctx, data_stream, _, opts)
-  local ok, json = pcall(vim.json.decode, data_stream)
+  local ok, jsn = pcall(vim.json.decode, data_stream)
   if not ok then
-    opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(json) })
+    opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(jsn) })
     return
   end
 
+  if opts.update_tokens_usage and jsn.usageMetadata and jsn.usageMetadata ~= vim.NIL then
+    local usage = M.transform_gemini_usage(jsn.usageMetadata)
+    if usage ~= nil then opts.update_tokens_usage(usage) end
+  end
+
   -- Handle prompt feedback first, as it might indicate an overall issue with the prompt
-  if json.promptFeedback and json.promptFeedback.blockReason then
-    local feedback = json.promptFeedback
+  if jsn.promptFeedback and jsn.promptFeedback.blockReason then
+    local feedback = jsn.promptFeedback
     OpenAI:finish_pending_messages(ctx, opts) -- Ensure any pending messages are cleared
     opts.on_stop({
       reason = "error",
@@ -228,8 +245,8 @@ function M:parse_response(ctx, data_stream, _, opts)
     return
   end
 
-  if json.candidates and #json.candidates > 0 then
-    local candidate = json.candidates[1]
+  if jsn.candidates and #jsn.candidates > 0 then
+    local candidate = jsn.candidates[1]
     ---@type AvanteLLMToolUse[]
     ctx.tool_use_list = ctx.tool_use_list or {}
@@ -258,6 +275,7 @@ function M:parse_response(ctx, data_stream, _, opts)
       OpenAI:finish_pending_messages(ctx, opts)
       local reason_str = candidate.finishReason
       local stop_details = { finish_reason = reason_str }
+      stop_details.usage = M.transform_gemini_usage(jsn.usageMetadata)
       if reason_str == "TOOL_CODE" then
         -- Model indicates a tool-related stop.
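
-- Both providers above reduce their native usage payloads to the shared
-- avante.LLMTokenUsage shape ({ prompt_tokens, completion_tokens }), so the
-- context-window check in llm.lua never sees provider-specific field names.
-- A standalone sketch of the two mappings (field names are taken from the
-- diffs; the helper names are hypothetical). For Anthropic, the cache token
-- counts are folded into the two buckets, so the prompt + completion sum
-- still reflects the total context consumed, which is all the 90% check uses:
local function anthropic_usage(u)
  return {
    prompt_tokens = u.input_tokens + u.cache_creation_input_tokens,
    completion_tokens = u.output_tokens + u.cache_read_input_tokens,
  }
end
local function gemini_usage(u)
  return { prompt_tokens = u.promptTokenCount, completion_tokens = u.candidatesTokenCount }
end

local a = anthropic_usage({ input_tokens = 1200, cache_creation_input_tokens = 300, cache_read_input_tokens = 500, output_tokens = 400 })
assert(a.prompt_tokens == 1500 and a.completion_tokens == 900)
local g = gemini_usage({ promptTokenCount = 2000, candidatesTokenCount = 150 })
assert(g.prompt_tokens == 2000 and g.completion_tokens == 150)
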
diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua
index d4b7c64..d511b01 100644
--- a/lua/avante/providers/openai.lua
+++ b/lua/avante/providers/openai.lua
@@ -379,6 +379,19 @@ function M:add_tool_use_message(ctx, tool_use, state, opts)
   -- if state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
 end
 
+---@param usage avante.OpenAITokenUsage | nil
+---@return avante.LLMTokenUsage | nil
+function M.transform_openai_usage(usage)
+  if not usage then return nil end
+  if usage == vim.NIL then return nil end
+  ---@type avante.LLMTokenUsage
+  local res = {
+    prompt_tokens = usage.prompt_tokens,
+    completion_tokens = usage.completion_tokens,
+  }
+  return res
+end
+
 function M:parse_response(ctx, data_stream, _, opts)
   if data_stream:match('"%[DONE%]":') or data_stream == "[DONE]" then
     self:finish_pending_messages(ctx, opts)
@@ -391,6 +404,12 @@ function M:parse_response(ctx, data_stream, _, opts)
     return
   end
   local jsn = vim.json.decode(data_stream)
+  if jsn.usage and jsn.usage ~= vim.NIL then
+    if opts.update_tokens_usage then
+      local usage = self.transform_openai_usage(jsn.usage)
+      if usage then opts.update_tokens_usage(usage) end
+    end
+  end
   ---@cast jsn AvanteOpenAIChatResponse
   if not jsn.choices then return end
   local choice = jsn.choices[1]
@@ -463,16 +482,16 @@ function M:parse_response(ctx, data_stream, _, opts)
   if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then
     self:finish_pending_messages(ctx, opts)
     if ctx.tool_use_list and #ctx.tool_use_list > 0 then
-      opts.on_stop({ reason = "tool_use", usage = jsn.usage })
+      opts.on_stop({ reason = "tool_use", usage = self.transform_openai_usage(jsn.usage) })
     else
-      opts.on_stop({ reason = "complete", usage = jsn.usage })
+      opts.on_stop({ reason = "complete", usage = self.transform_openai_usage(jsn.usage) })
     end
   end
   if choice.finish_reason == "tool_calls" then
     self:finish_pending_messages(ctx, opts)
     opts.on_stop({
       reason = "tool_use",
-      usage = jsn.usage,
+      usage = self.transform_openai_usage(jsn.usage),
     })
   end
 end
@@ -536,6 +555,9 @@ function M:parse_curl_args(prompt_opts)
       model = provider_conf.model,
       messages = self:parse_messages(prompt_opts),
       stream = true,
+      stream_options = {
+        include_usage = true,
+      },
       tools = tools,
     }, request_body),
   }
diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua
index 9347057..12fde4e 100644
--- a/lua/avante/sidebar.lua
+++ b/lua/avante/sidebar.lua
@@ -2386,7 +2386,7 @@ function Sidebar:get_history_messages_for_api(opts)
   end
 
   local picked_messages = {}
-  local max_tool_use_count = 10
+  local max_tool_use_count = 25
   local tool_use_count = 0
   for idx = #history_messages, 1, -1 do
     local msg = history_messages[idx]
@@ -2753,6 +2753,11 @@ function Sidebar:create_input_container()
       return history and history.todos or {}
     end,
     session_ctx = {},
+    update_tokens_usage = function(usage)
+      self.chat_history.tokens_usage = usage
+      self:save_history()
+    end,
+    get_tokens_usage = function() return self.chat_history.tokens_usage end,
   })
 
   ---@param pending_compaction_history_messages avante.HistoryMessage[]
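
-- Why transform_openai_usage above guards against vim.NIL as well as nil:
-- with stream_options.include_usage enabled (see parse_curl_args), OpenAI's
-- documented streaming behaviour is to send `"usage": null` on every chunk
-- except the final one, and Neovim's vim.json.decode maps JSON null to the
-- truthy sentinel vim.NIL rather than Lua nil. A small decode sketch (runs
-- inside Neovim):
local chunk = vim.json.decode('{"choices":[],"usage":null}')
assert(chunk.usage == vim.NIL) -- truthy, so a bare `if jsn.usage` would pass
assert(chunk.usage ~= nil)     -- hence the explicit vim.NIL checks above
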
diff --git a/lua/avante/types.lua b/lua/avante/types.lua
index db90670..ba4be87 100644
--- a/lua/avante/types.lua
+++ b/lua/avante/types.lua
@@ -77,6 +77,7 @@ vim.g.avante_login = vim.g.avante_login
 ---@field on_stop AvanteLLMStopCallback
 ---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil
 ---@field on_state_change? fun(state: avante.GenerateState): nil
+---@field update_tokens_usage? fun(usage: avante.LLMTokenUsage): nil
 ---
 ---@alias AvanteLLMMessageContentItem string | { type: "text", text: string, cache_control: { type: string } | nil } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string }
 ---
@@ -242,21 +243,32 @@ vim.g.avante_login = vim.g.avante_login
 ---@field entra? boolean
 ---@field hide_in_model_selector? boolean
 ---@field use_ReAct_prompt? boolean
+---@field context_window? integer
 ---
 ---@class AvanteSupportedProvider: AvanteDefaultBaseProvider
 ---@field __inherited_from? string
----@field temperature? number
----@field max_tokens? number
----@field max_completion_tokens? number
----@field reasoning_effort? string
 ---@field display_name? string
 ---
----@class AvanteLLMUsage
+---@class avante.OpenAITokenUsage
+---@field total_tokens number
+---@field prompt_tokens number
+---@field completion_tokens number
+---@field prompt_tokens_details {cached_tokens: number}
+---
+---@class avante.AnthropicTokenUsage
 ---@field input_tokens number
 ---@field cache_creation_input_tokens number
 ---@field cache_read_input_tokens number
 ---@field output_tokens number
 ---
+---@class avante.GeminiTokenUsage
+---@field promptTokenCount number
+---@field candidatesTokenCount number
+---
+---@class avante.LLMTokenUsage
+---@field prompt_tokens number
+---@field completion_tokens number
+---
 ---@class AvanteLLMThinkingBlock
 ---@field thinking string
 ---@field signature string
@@ -275,12 +287,12 @@ vim.g.avante_login = vim.g.avante_login
 ---@field state avante.HistoryMessageState
 ---
 ---@class AvanteLLMStartCallbackOptions
----@field usage? AvanteLLMUsage
+---@field usage? avante.LLMTokenUsage
 ---
 ---@class AvanteLLMStopCallbackOptions
----@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled" | "max_tokens"
+---@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled" | "max_tokens" | "usage"
 ---@field error? string | table
----@field usage? AvanteLLMUsage
+---@field usage? avante.LLMTokenUsage
 ---@field retry_after? integer
 ---@field headers? table
 ---@field streaming_tool_use? boolean
@@ -323,6 +335,7 @@ vim.g.avante_login = vim.g.avante_login
 ---@field api_key_name string
 ---@field tokenizer_id string | "gpt-4o"
 ---@field model? string
+---@field context_window? integer
 ---@field parse_api_key fun(): string | nil
 ---@field parse_stream_data? AvanteStreamParser
 ---@field on_error? fun(result: table): nil
@@ -366,6 +379,7 @@ vim.g.avante_login = vim.g.avante_login
 ---@field history_messages avante.HistoryMessage[] | nil
 ---@field get_todos? fun(): avante.TODO[]
 ---@field memory string | nil
+---@field get_tokens_usage? fun(): avante.LLMTokenUsage | nil
 ---
 ---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions
 ---@field instructions? string
@@ -395,6 +409,7 @@ vim.g.avante_login = vim.g.avante_login
 ---@field get_history_messages? fun(opts?: { all?: boolean }): avante.HistoryMessage[]
 ---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil
 ---@field on_state_change? fun(state: avante.GenerateState): nil
+---@field update_tokens_usage? fun(usage: avante.LLMTokenUsage): nil
 ---
 ---@alias AvanteLLMToolFunc fun(
 ---  input: T,
@@ -461,6 +476,7 @@ vim.g.avante_login = vim.g.avante_login
 ---@field memory avante.ChatMemory | nil
 ---@field filename string
 ---@field system_prompt string | nil
+---@field tokens_usage avante.LLMTokenUsage | nil
 ---
 ---@class avante.ChatMemory
 ---@field content string
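
-- End-to-end, the new pieces tie together as follows: each provider reports
-- usage through update_tokens_usage, sidebar.lua persists it on
-- chat_history.tokens_usage (the new avante.ChatHistory field above), and the
-- next generate_prompts call reads it back via get_tokens_usage to decide on
-- compaction. A minimal sketch of that round-trip, with a plain table standing
-- in for the real Sidebar/save_history machinery:
local chat_history = {}
local opts = {
  update_tokens_usage = function(usage) chat_history.tokens_usage = usage end,
  get_tokens_usage = function() return chat_history.tokens_usage end,
}
opts.update_tokens_usage({ prompt_tokens = 9000, completion_tokens = 1200 }) -- e.g. from on_stop
local usage = opts.get_tokens_usage()
assert(usage.prompt_tokens == 9000 and usage.completion_tokens == 1200)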