feat: tokens usage (#2300)
This commit is contained in:
@@ -133,6 +133,18 @@ function M:parse_messages(opts)
|
||||
return messages
|
||||
end
|
||||
|
||||
--- Normalize an Anthropic `usage` payload into the provider-agnostic shape.
---
--- Anthropic reports prompt-cache activity separately from `input_tokens`:
--- both `cache_creation_input_tokens` and `cache_read_input_tokens` are
--- *input* (prompt-side) tokens per the Anthropic Messages API, so they
--- belong in `prompt_tokens` — not in `completion_tokens`, which should
--- count only generated output. The cache fields are absent when prompt
--- caching is not in use, so they default to 0.
---@param usage avante.AnthropicTokenUsage | nil
---@return avante.LLMTokenUsage | nil
function M.transform_anthropic_usage(usage)
  if not usage then return nil end
  ---@type avante.LLMTokenUsage
  local res = {
    prompt_tokens = (usage.input_tokens or 0)
      + (usage.cache_creation_input_tokens or 0)
      + (usage.cache_read_input_tokens or 0),
    completion_tokens = usage.output_tokens or 0,
  }
  return res
end
|
||||
|
||||
function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
if event_state == nil then
|
||||
if data_stream:match('"message_start"') then
|
||||
@@ -153,7 +165,7 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
if event_state == "message_start" then
|
||||
local ok, jsn = pcall(vim.json.decode, data_stream)
|
||||
if not ok then return end
|
||||
opts.on_start(jsn.message.usage)
|
||||
ctx.usage = jsn.message.usage
|
||||
elseif event_state == "content_block_start" then
|
||||
local ok, jsn = pcall(vim.json.decode, data_stream)
|
||||
if not ok then return end
|
||||
@@ -315,14 +327,15 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
elseif event_state == "message_delta" then
|
||||
local ok, jsn = pcall(vim.json.decode, data_stream)
|
||||
if not ok then return end
|
||||
if jsn.usage and ctx.usage then ctx.usage.output_tokens = ctx.usage.output_tokens + jsn.usage.output_tokens end
|
||||
if jsn.delta.stop_reason == "end_turn" then
|
||||
opts.on_stop({ reason = "complete", usage = jsn.usage })
|
||||
opts.on_stop({ reason = "complete", usage = self.transform_anthropic_usage(ctx.usage) })
|
||||
elseif jsn.delta.stop_reason == "max_tokens" then
|
||||
opts.on_stop({ reason = "max_tokens", usage = jsn.usage })
|
||||
opts.on_stop({ reason = "max_tokens", usage = self.transform_anthropic_usage(ctx.usage) })
|
||||
elseif jsn.delta.stop_reason == "tool_use" then
|
||||
opts.on_stop({
|
||||
reason = "tool_use",
|
||||
usage = jsn.usage,
|
||||
usage = self.transform_anthropic_usage(ctx.usage),
|
||||
})
|
||||
end
|
||||
return
|
||||
|
||||
@@ -209,16 +209,33 @@ function M.prepare_request_body(provider_instance, prompt_opts, provider_conf, r
|
||||
return vim.tbl_deep_extend("force", {}, provider_instance:parse_messages(prompt_opts), request_body)
|
||||
end
|
||||
|
||||
--- Normalize a Gemini `usageMetadata` payload into the provider-agnostic shape.
---
--- `vim.json.decode` maps JSON null to `vim.NIL`, which is truthy, so it is
--- rejected explicitly (matching `transform_openai_usage`). Streaming chunks
--- may omit `candidatesTokenCount` (e.g. before any candidate tokens are
--- produced), so missing counts default to 0 rather than leaving nil fields
--- that would break downstream arithmetic.
---@param usage avante.GeminiTokenUsage | nil
---@return avante.LLMTokenUsage | nil
function M.transform_gemini_usage(usage)
  if not usage or usage == vim.NIL then return nil end
  ---@type avante.LLMTokenUsage
  local res = {
    prompt_tokens = usage.promptTokenCount or 0,
    completion_tokens = usage.candidatesTokenCount or 0,
  }
  return res
end
|
||||
|
||||
function M:parse_response(ctx, data_stream, _, opts)
|
||||
local ok, json = pcall(vim.json.decode, data_stream)
|
||||
local ok, jsn = pcall(vim.json.decode, data_stream)
|
||||
if not ok then
|
||||
opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(json) })
|
||||
opts.on_stop({ reason = "error", error = "Failed to parse JSON response: " .. tostring(jsn) })
|
||||
return
|
||||
end
|
||||
|
||||
if opts.update_tokens_usage and jsn.usageMetadata and jsn.usageMetadata ~= nil then
|
||||
local usage = M.transform_gemini_usage(jsn.usageMetadata)
|
||||
if usage ~= nil then opts.update_tokens_usage(usage) end
|
||||
end
|
||||
|
||||
-- Handle prompt feedback first, as it might indicate an overall issue with the prompt
|
||||
if json.promptFeedback and json.promptFeedback.blockReason then
|
||||
local feedback = json.promptFeedback
|
||||
if jsn.promptFeedback and jsn.promptFeedback.blockReason then
|
||||
local feedback = jsn.promptFeedback
|
||||
OpenAI:finish_pending_messages(ctx, opts) -- Ensure any pending messages are cleared
|
||||
opts.on_stop({
|
||||
reason = "error",
|
||||
@@ -228,8 +245,8 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
return
|
||||
end
|
||||
|
||||
if json.candidates and #json.candidates > 0 then
|
||||
local candidate = json.candidates[1]
|
||||
if jsn.candidates and #jsn.candidates > 0 then
|
||||
local candidate = jsn.candidates[1]
|
||||
---@type AvanteLLMToolUse[]
|
||||
ctx.tool_use_list = ctx.tool_use_list or {}
|
||||
|
||||
@@ -258,6 +275,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
OpenAI:finish_pending_messages(ctx, opts)
|
||||
local reason_str = candidate.finishReason
|
||||
local stop_details = { finish_reason = reason_str }
|
||||
stop_details.usage = M.transform_gemini_usage(jsn.usageMetadata)
|
||||
|
||||
if reason_str == "TOOL_CODE" then
|
||||
-- Model indicates a tool-related stop.
|
||||
|
||||
@@ -379,6 +379,19 @@ function M:add_tool_use_message(ctx, tool_use, state, opts)
|
||||
-- if state == "generating" then opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) end
|
||||
end
|
||||
|
||||
--- Normalize an OpenAI `usage` payload into the provider-agnostic shape.
---
--- Both an absent value and a JSON null (`vim.NIL`, produced by
--- `vim.json.decode`) mean "no usage data" and yield nil.
---@param usage avante.OpenAITokenUsage | nil
---@return avante.LLMTokenUsage | nil
function M.transform_openai_usage(usage)
  if not usage or usage == vim.NIL then return nil end
  ---@type avante.LLMTokenUsage
  return {
    prompt_tokens = usage.prompt_tokens,
    completion_tokens = usage.completion_tokens,
  }
end
|
||||
|
||||
function M:parse_response(ctx, data_stream, _, opts)
|
||||
if data_stream:match('"%[DONE%]":') or data_stream == "[DONE]" then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
@@ -391,6 +404,12 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
return
|
||||
end
|
||||
local jsn = vim.json.decode(data_stream)
|
||||
if jsn.usage and jsn.usage ~= vim.NIL then
|
||||
if opts.update_tokens_usage then
|
||||
local usage = self.transform_openai_usage(jsn.usage)
|
||||
if usage then opts.update_tokens_usage(usage) end
|
||||
end
|
||||
end
|
||||
---@cast jsn AvanteOpenAIChatResponse
|
||||
if not jsn.choices then return end
|
||||
local choice = jsn.choices[1]
|
||||
@@ -463,16 +482,16 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" or choice.finish_reason == "length" then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
if ctx.tool_use_list and #ctx.tool_use_list > 0 then
|
||||
opts.on_stop({ reason = "tool_use", usage = jsn.usage })
|
||||
opts.on_stop({ reason = "tool_use", usage = self.transform_openai_usage(jsn.usage) })
|
||||
else
|
||||
opts.on_stop({ reason = "complete", usage = jsn.usage })
|
||||
opts.on_stop({ reason = "complete", usage = self.transform_openai_usage(jsn.usage) })
|
||||
end
|
||||
end
|
||||
if choice.finish_reason == "tool_calls" then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({
|
||||
reason = "tool_use",
|
||||
usage = jsn.usage,
|
||||
usage = self.transform_openai_usage(jsn.usage),
|
||||
})
|
||||
end
|
||||
end
|
||||
@@ -536,6 +555,9 @@ function M:parse_curl_args(prompt_opts)
|
||||
model = provider_conf.model,
|
||||
messages = self:parse_messages(prompt_opts),
|
||||
stream = true,
|
||||
stream_options = {
|
||||
include_usage = true,
|
||||
},
|
||||
tools = tools,
|
||||
}, request_body),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user