feat(tokens): add token count display to sidebar (#956)

* feat (tokens) add token count display to sidebar

* refactor: calculate the real tokens and reuse input hints to avoid occlusion

---------

Co-authored-by: yetone <yetoneful@gmail.com>
This commit is contained in:
Michael Gendy
2024-12-17 14:43:25 +02:00
committed by GitHub
parent e612ad7566
commit e98fa46bec
6 changed files with 133 additions and 93 deletions

View File

@@ -18,10 +18,10 @@ M.CANCEL_PATTERN = "AvanteLLMEscape"
local group = api.nvim_create_augroup("avante_llm", { clear = true })
---@param opts StreamOptions
---@param Provider AvanteProviderFunctor
M._stream = function(opts, Provider)
-- print opts
---@param opts GeneratePromptsOptions
---@return AvantePromptOptions
M.generate_prompts = function(opts)
local Provider = opts.provider or P[Config.provider]
local mode = opts.mode or "planning"
---@type AvanteProviderFunctor
local _, body_opts = P.parse_config(Provider)
@@ -42,7 +42,8 @@ M._stream = function(opts, Provider)
instructions = table.concat(lines, "\n")
end
Path.prompts.initialize(Path.prompts.get(opts.bufnr))
local project_root = Utils.root.get()
Path.prompts.initialize(Path.prompts.get(project_root))
local template_opts = {
use_xml_format = Provider.use_xml_format,
@@ -104,11 +105,30 @@ M._stream = function(opts, Provider)
end
---@type AvantePromptOptions
local code_opts = {
return {
system_prompt = system_prompt,
messages = messages,
image_paths = image_paths,
}
end
---@param opts GeneratePromptsOptions
---@return integer
M.calculate_tokens = function(opts)
local code_opts = M.generate_prompts(opts)
local tokens = Utils.tokens.calculate_tokens(code_opts.system_prompt)
for _, message in ipairs(code_opts.messages) do
tokens = tokens + Utils.tokens.calculate_tokens(message.content)
end
return tokens
end
---@param opts StreamOptions
M._stream = function(opts)
local Provider = opts.provider or P[Config.provider]
local code_opts = M.generate_prompts(opts)
---@type string
local current_event_state = nil
@@ -248,7 +268,7 @@ M._stream = function(opts, Provider)
return active_job
end
local function _merge_response(first_response, second_response, opts, Provider)
local function _merge_response(first_response, second_response, opts)
local prompt = "\n" .. Config.dual_boost.prompt
prompt = prompt
:gsub("{{[%s]*provider1_output[%s]*}}", first_response)
@@ -259,28 +279,28 @@ local function _merge_response(first_response, second_response, opts, Provider)
-- append this reference prompt to the code_opts messages at last
opts.instructions = opts.instructions .. prompt
M._stream(opts, Provider)
M._stream(opts)
end
local function _collector_process_responses(collector, opts, Provider)
local function _collector_process_responses(collector, opts)
if not collector[1] or not collector[2] then
Utils.error("One or both responses failed to complete")
return
end
_merge_response(collector[1], collector[2], opts, Provider)
_merge_response(collector[1], collector[2], opts)
end
local function _collector_add_response(collector, index, response, opts, Provider)
local function _collector_add_response(collector, index, response, opts)
collector[index] = response
collector.count = collector.count + 1
if collector.count == 2 then
collector.timer:stop()
_collector_process_responses(collector, opts, Provider)
_collector_process_responses(collector, opts)
end
end
M._dual_boost_stream = function(opts, Provider, Provider1, Provider2)
M._dual_boost_stream = function(opts, Provider1, Provider2)
Utils.debug("Starting Dual Boost Stream")
local collector = {
@@ -299,7 +319,7 @@ M._dual_boost_stream = function(opts, Provider, Provider1, Provider2)
Utils.warn("Dual boost stream timeout reached")
collector.timer:stop()
-- Process whatever responses we have
_collector_process_responses(collector, opts, Provider)
_collector_process_responses(collector, opts)
end
end)
)
@@ -317,15 +337,19 @@ M._dual_boost_stream = function(opts, Provider, Provider1, Provider2)
return
end
Utils.debug(string.format("Response %d completed", index))
_collector_add_response(collector, index, response, opts, Provider)
_collector_add_response(collector, index, response, opts)
end,
})
end
-- Start both streams
local success, err = xpcall(function()
M._stream(create_stream_opts(1), Provider1)
M._stream(create_stream_opts(2), Provider2)
local opts1 = create_stream_opts(1)
opts1.provider = Provider1
M._stream(opts1)
local opts2 = create_stream_opts(2)
opts2.provider = Provider2
M._stream(opts2)
end, function(err) return err end)
if not success then Utils.error("Failed to start dual_boost streams: " .. tostring(err)) end
end
@@ -348,12 +372,13 @@ end
---@field diagnostics string | nil
---@field history_messages AvanteLLMMessage[]
---
---@class StreamOptions: TemplateOptions
---@class GeneratePromptsOptions: TemplateOptions
---@field ask boolean
---@field bufnr integer
---@field instructions string
---@field mode LlmMode
---@field provider AvanteProviderFunctor | nil
---
---@class StreamOptions: GeneratePromptsOptions
---@field on_chunk AvanteChunkParser
---@field on_complete AvanteCompleteParser
@@ -375,11 +400,10 @@ M.stream = function(opts)
return original_on_complete(err)
end)
end
local Provider = opts.provider or P[Config.provider]
if Config.dual_boost.enabled then
M._dual_boost_stream(opts, Provider, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider])
M._dual_boost_stream(opts, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider])
else
M._stream(opts, Provider)
M._stream(opts)
end
end