feat: tokens usage (#2300)

yetone
2025-06-23 03:13:37 +08:00
committed by GitHub
parent 7daf169228
commit 6830f2d8b9
7 changed files with 131 additions and 41 deletions

lua/avante/providers/claude.lua

@@ -133,6 +133,18 @@ function M:parse_messages(opts)
   return messages
 end
 
+---@param usage avante.AnthropicTokenUsage | nil
+---@return avante.LLMTokenUsage | nil
+function M.transform_anthropic_usage(usage)
+  if not usage then return nil end
+  ---@type avante.LLMTokenUsage
+  local res = {
+    prompt_tokens = usage.input_tokens + usage.cache_creation_input_tokens + usage.cache_read_input_tokens,
+    completion_tokens = usage.output_tokens,
+  }
+  return res
+end
+
 function M:parse_response(ctx, data_stream, event_state, opts)
   if event_state == nil then
     if data_stream:match('"message_start"') then
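
The transform gives downstream consumers one provider-agnostic shape, collapsing Anthropic's three input-side counters (plain, cache-creation, and cache-read tokens) into a single prompt figure. Roughly what it does to a typical message_start usage block (the numbers are illustrative, not from this commit):

    -- Illustrative Anthropic usage payload; all three input-side counters
    -- collapse into prompt_tokens, output_tokens becomes completion_tokens.
    local anthropic_usage = {
      input_tokens = 1200,
      cache_creation_input_tokens = 300,
      cache_read_input_tokens = 4500,
      output_tokens = 2,
    }
    local normalized = M.transform_anthropic_usage(anthropic_usage)
    -- normalized.prompt_tokens     --> 6000
    -- normalized.completion_tokens --> 2
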
@@ -153,7 +165,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
   if event_state == "message_start" then
     local ok, jsn = pcall(vim.json.decode, data_stream)
     if not ok then return end
-    opts.on_start(jsn.message.usage)
+    ctx.usage = jsn.message.usage
   elseif event_state == "content_block_start" then
     local ok, jsn = pcall(vim.json.decode, data_stream)
     if not ok then return end
@@ -315,14 +327,15 @@ function M:parse_response(ctx, data_stream, event_state, opts)
   elseif event_state == "message_delta" then
     local ok, jsn = pcall(vim.json.decode, data_stream)
     if not ok then return end
+    if jsn.usage and ctx.usage then ctx.usage.output_tokens = jsn.usage.output_tokens end
     if jsn.delta.stop_reason == "end_turn" then
-      opts.on_stop({ reason = "complete", usage = jsn.usage })
+      opts.on_stop({ reason = "complete", usage = self.transform_anthropic_usage(ctx.usage) })
     elseif jsn.delta.stop_reason == "max_tokens" then
-      opts.on_stop({ reason = "max_tokens", usage = jsn.usage })
+      opts.on_stop({ reason = "max_tokens", usage = self.transform_anthropic_usage(ctx.usage) })
     elseif jsn.delta.stop_reason == "tool_use" then
       opts.on_stop({
         reason = "tool_use",
-        usage = jsn.usage,
+        usage = self.transform_anthropic_usage(ctx.usage),
       })
     end
     return
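
Taken together, the Anthropic changes move usage tracking from a one-shot on_start argument to state threaded through ctx. A rough sketch of the lifecycle, assuming a single trailing message_delta (payloads abbreviated):

    -- message_start: seed ctx.usage with the initial counts (output is tiny here)
    ctx.usage = jsn.message.usage -- e.g. { input_tokens = 1200, ..., output_tokens = 1 }

    -- message_delta: Anthropic reports output_tokens cumulatively, so the
    -- latest value simply overwrites the seed
    ctx.usage.output_tokens = jsn.usage.output_tokens

    -- stop: every stop_reason branch now hands the UI the normalized shape
    opts.on_stop({ reason = "complete", usage = M.transform_anthropic_usage(ctx.usage) })
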

lua/avante/providers/gemini.lua

@@ -209,16 +209,33 @@ function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, r
   return vim.tbl_deep_extend("force", {}, provider_instance:parse_messages(prompt_opts), request_body)
 end
 
+---@param usage avante.GeminiTokenUsage | nil
+---@return avante.LLMTokenUsage | nil
+function M.transform_gemini_usage(usage)
+  if not usage then return nil end
+  ---@type avante.LLMTokenUsage
+  local res = {
+    prompt_tokens = usage.promptTokenCount,
+    completion_tokens = usage.candidatesTokenCount,
+  }
+  return res
+end
+
 function M:parse_response(ctx, data_stream, _, opts)
-  local ok, json = pcall(vim.json.decode, data_stream)
+  local ok, jsn = pcall(vim.json.decode, data_stream)
   if not ok then
-    opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(json) })
+    opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(jsn) })
     return
   end
 
+  if opts.update_tokens_usage and jsn.usageMetadata ~= nil and jsn.usageMetadata ~= vim.NIL then
+    local usage = M.transform_gemini_usage(jsn.usageMetadata)
+    if usage ~= nil then opts.update_tokens_usage(usage) end
+  end
+
   -- Handle prompt feedback first, as it might indicate an overall issue with the prompt
-  if json.promptFeedback and json.promptFeedback.blockReason then
-    local feedback = json.promptFeedback
+  if jsn.promptFeedback and jsn.promptFeedback.blockReason then
+    local feedback = jsn.promptFeedback
     OpenAI:finish_pending_messages(ctx, opts) -- Ensure any pending messages are cleared
     opts.on_stop({
       reason = "error",
@@ -228,8 +245,8 @@ function M:parse_response(ctx, data_stream, _, opts)
     return
   end
 
-  if json.candidates and #json.candidates > 0 then
-    local candidate = json.candidates[1]
+  if jsn.candidates and #jsn.candidates > 0 then
+    local candidate = jsn.candidates[1]
 
     ---@type AvanteLLMToolUse[]
     ctx.tool_use_list = ctx.tool_use_list or {}
@@ -258,6 +275,7 @@ function M:parse_response(ctx, data_stream, _, opts)
     OpenAI:finish_pending_messages(ctx, opts)
     local reason_str = candidate.finishReason
     local stop_details = { finish_reason = reason_str }
+    stop_details.usage = M.transform_gemini_usage(jsn.usageMetadata)
     if reason_str == "TOOL_CODE" then
       -- Model indicates a tool-related stop.
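
Gemini needs no event bookkeeping: usageMetadata rides along on ordinary streamed chunks, so the provider pushes live counts through opts.update_tokens_usage as they arrive and attaches the last snapshot to stop_details. What the transform sees, with illustrative values:

    local usage_metadata = {
      promptTokenCount = 980,
      candidatesTokenCount = 412,
      totalTokenCount = 1392,
    }
    local normalized = M.transform_gemini_usage(usage_metadata)
    -- normalized.prompt_tokens     --> 980
    -- normalized.completion_tokens --> 412
    -- totalTokenCount is dropped: avante.LLMTokenUsage tracks only the two sides
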

lua/avante/providers/openai.lua

@@ -379,6 +379,19 @@ function M:add_tool_use_message(ctx, tool_use, state, opts)
   -- if state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
 end
 
+---@param usage avante.OpenAITokenUsage | nil
+---@return avante.LLMTokenUsage | nil
+function M.transform_openai_usage(usage)
+  if not usage then return nil end
+  if usage == vim.NIL then return nil end
+  ---@type avante.LLMTokenUsage
+  local res = {
+    prompt_tokens = usage.prompt_tokens,
+    completion_tokens = usage.completion_tokens,
+  }
+  return res
+end
+
 function M:parse_response(ctx, data_stream, _, opts)
   if data_stream:match('"%[DONE%]":') or data_stream == "[DONE]" then
     self:finish_pending_messages(ctx, opts)
@@ -391,6 +404,12 @@ function M:parse_response(ctx, data_stream, _, opts)
     return
   end
   local jsn = vim.json.decode(data_stream)
+  if jsn.usage and jsn.usage ~= vim.NIL then
+    if opts.update_tokens_usage then
+      local usage = self.transform_openai_usage(jsn.usage)
+      if usage then opts.update_tokens_usage(usage) end
+    end
+  end
   ---@cast jsn AvanteOpenAIChatResponse
   if not jsn.choices then return end
   local choice = jsn.choices[1]
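
The vim.NIL guard above is load-bearing: vim.json.decode maps a JSON null to the vim.NIL sentinel, and vim.NIL is truthy in Lua, so a bare `if jsn.usage then` would pass for chunks carrying "usage": null. For example:

    local jsn = vim.json.decode('{"usage": null}')
    print(jsn.usage == nil)     --> false (null decodes to vim.NIL, not nil)
    print(jsn.usage == vim.NIL) --> true
    if jsn.usage then print("truthy") end --> prints, hence the explicit guard
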
@@ -463,16 +482,16 @@ function M:parse_response(ctx, data_stream, _, opts)
   if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then
     self:finish_pending_messages(ctx, opts)
     if ctx.tool_use_list and #ctx.tool_use_list > 0 then
-      opts.on_stop({ reason = "tool_use", usage = jsn.usage })
+      opts.on_stop({ reason = "tool_use", usage = self.transform_openai_usage(jsn.usage) })
     else
-      opts.on_stop({ reason = "complete", usage = jsn.usage })
+      opts.on_stop({ reason = "complete", usage = self.transform_openai_usage(jsn.usage) })
     end
   end
   if choice.finish_reason == "tool_calls" then
     self:finish_pending_messages(ctx, opts)
     opts.on_stop({
       reason = "tool_use",
-      usage = jsn.usage,
+      usage = self.transform_openai_usage(jsn.usage),
     })
   end
 end
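
On the caller side, opts.update_tokens_usage gives a live counter during the stream while on_stop still delivers the final normalized usage. A hypothetical consumer sketch (render_token_counter is invented for illustration; only the opts fields come from this commit):

    provider:parse_response(ctx, data_stream, event_state, {
      update_tokens_usage = function(usage)
        -- usage is an avante.LLMTokenUsage; fires whenever a chunk carries usage
        render_token_counter(usage.prompt_tokens, usage.completion_tokens)
      end,
      on_stop = function(stop_opts)
        -- stop_opts.usage holds the final normalized count, or nil
        if stop_opts.usage then
          render_token_counter(stop_opts.usage.prompt_tokens, stop_opts.usage.completion_tokens)
        end
      end,
    })
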
@@ -536,6 +555,9 @@ function M:parse_curl_args(prompt_opts)
       model = provider_conf.model,
       messages = self:parse_messages(prompt_opts),
       stream = true,
+      stream_options = {
+        include_usage = true,
+      },
       tools = tools,
     }, request_body),
   }
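
The stream_options addition is what makes the OpenAI usage branch fire at all: without include_usage, OpenAI-compatible streaming never sends token counts. With it, the API emits one extra chunk right before [DONE] whose choices array is empty and whose usage field is populated; earlier chunks carry "usage": null and are skipped by the vim.NIL guard. The final chunk, sketched as a Lua table with illustrative values:

    local final_chunk = {
      object = "chat.completion.chunk",
      choices = {}, -- empty on the usage-only chunk
      usage = { prompt_tokens = 980, completion_tokens = 412, total_tokens = 1392 },
    }
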