feat: tokenizers (#429)
* feat: tokenizers
This reverts commit d5a4db8321.
* fix(inputs): #422
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
---------
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
@@ -10,6 +10,7 @@ local Utils = require("avante.utils")
|
||||
---@field ask fun(): boolean
|
||||
---@field edit fun(): nil
|
||||
---@field refresh fun(): nil
|
||||
---@field build fun(): boolean
|
||||
---@field toggle avante.ApiToggle
|
||||
|
||||
return setmetatable({}, {
|
||||
|
||||
@@ -285,6 +285,7 @@ M.BASE_PROVIDER_KEYS = {
|
||||
-- internal
|
||||
"local",
|
||||
"_shellenv",
|
||||
"tokenizer_id",
|
||||
}
|
||||
|
||||
---@return {width: integer, height: integer}
|
||||
|
||||
@@ -34,6 +34,9 @@ H.commands = function()
|
||||
cmd("Refresh", function()
|
||||
M.refresh()
|
||||
end, { desc = "avante: refresh windows" })
|
||||
cmd("Build", function()
|
||||
M.build()
|
||||
end, { desc = "avante: build dependencies" })
|
||||
end
|
||||
|
||||
H.keymaps = function()
|
||||
@@ -91,6 +94,34 @@ end
|
||||
H.augroup = api.nvim_create_augroup("avante_autocmds", { clear = true })
|
||||
|
||||
H.autocmds = function()
|
||||
local ok, LazyConfig = pcall(require, "lazy.core.config")
|
||||
|
||||
if ok then
|
||||
local name = "avante.nvim"
|
||||
local load_path = function()
|
||||
require("avante_lib").load()
|
||||
end
|
||||
|
||||
if LazyConfig.plugins[name] and LazyConfig.plugins[name]._.loaded then
|
||||
vim.schedule(load_path)
|
||||
else
|
||||
api.nvim_create_autocmd("User", {
|
||||
pattern = "LazyLoad",
|
||||
callback = function(event)
|
||||
if event.data == name then
|
||||
load_path()
|
||||
return true
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
api.nvim_create_autocmd("User", {
|
||||
pattern = "VeryLazy",
|
||||
callback = load_path,
|
||||
})
|
||||
end
|
||||
|
||||
api.nvim_create_autocmd("TabEnter", {
|
||||
group = H.augroup,
|
||||
pattern = "*",
|
||||
@@ -221,6 +252,53 @@ setmetatable(M.toggle, {
|
||||
end,
|
||||
})
|
||||
|
||||
local function to_windows_path(path)
|
||||
local winpath = path:gsub("/", "\\")
|
||||
|
||||
if winpath:match("^%a:") then
|
||||
winpath = winpath:sub(1, 2):upper() .. winpath:sub(3)
|
||||
end
|
||||
|
||||
winpath = winpath:gsub("\\$", "")
|
||||
|
||||
return winpath
|
||||
end
|
||||
|
||||
M.build = H.api(function()
|
||||
local dirname = Utils.trim(string.sub(debug.getinfo(1).source, 2, #"/init.lua" * -1), { suffix = "/" })
|
||||
local git_root = vim.fs.find(".git", { path = dirname, upward = true })[1]
|
||||
local build_directory = git_root and vim.fn.fnamemodify(git_root, ":h") or (dirname .. "/../../")
|
||||
|
||||
if not vim.fn.executable("cargo") then
|
||||
error("Building avante.nvim requires cargo to be installed.", 2)
|
||||
end
|
||||
|
||||
---@type string[]
|
||||
local cmd
|
||||
local os_name = Utils.get_os_name()
|
||||
|
||||
if vim.tbl_contains({ "linux", "darwin" }, os_name) then
|
||||
cmd = { "sh", "-c", ("make -C %s"):format(build_directory) }
|
||||
elseif os_name == "windows" then
|
||||
build_directory = to_windows_path(build_directory)
|
||||
cmd = {
|
||||
"powershell",
|
||||
"-ExecutionPolicy",
|
||||
"Bypass",
|
||||
"-File",
|
||||
("%s\\Build.ps1"):format(build_directory),
|
||||
"-WorkingDirectory",
|
||||
build_directory,
|
||||
}
|
||||
else
|
||||
error("Unsupported operating system: " .. os_name, 2)
|
||||
end
|
||||
|
||||
local job = vim.system(cmd, { text = true }):wait()
|
||||
|
||||
return vim.tbl_contains({ 0 }, job.code) and true or false
|
||||
end)
|
||||
|
||||
M.ask = H.api(function()
|
||||
M.toggle()
|
||||
end)
|
||||
@@ -283,7 +361,7 @@ function M.setup(opts)
|
||||
return
|
||||
end
|
||||
|
||||
require("avante.history").setup()
|
||||
require("avante.path").setup()
|
||||
require("avante.highlights").setup()
|
||||
require("avante.diff").setup()
|
||||
require("avante.providers").setup()
|
||||
|
||||
@@ -2,10 +2,14 @@ local fn, api = vim.fn, vim.api
|
||||
local Path = require("plenary.path")
|
||||
local Config = require("avante.config")
|
||||
|
||||
---@class avante.Path
|
||||
---@field history_path Path
|
||||
---@field cache_path Path
|
||||
local P = {}
|
||||
|
||||
local M = {}
|
||||
|
||||
local H = {}
|
||||
|
||||
---@param bufnr integer
|
||||
---@return string
|
||||
H.filename = function(bufnr)
|
||||
@@ -39,11 +43,20 @@ M.save = function(bufnr, history)
|
||||
history_file:write(vim.json.encode(history), "w")
|
||||
end
|
||||
|
||||
M.setup = function()
|
||||
local history_dir = Path:new(Config.history.storage_path)
|
||||
if not history_dir:exists() then
|
||||
history_dir:mkdir({ parents = true })
|
||||
P.history = M
|
||||
|
||||
P.setup = function()
|
||||
local history_path = Path:new(Config.history.storage_path)
|
||||
if not history_path:exists() then
|
||||
history_path:mkdir({ parents = true })
|
||||
end
|
||||
P.history_path = history_path
|
||||
|
||||
local cache_path = Path:new(vim.fn.stdpath("cache") .. "/avante")
|
||||
if not cache_path:exists() then
|
||||
cache_path:mkdir({ parents = true })
|
||||
end
|
||||
P.cache_path = cache_path
|
||||
end
|
||||
|
||||
return M
|
||||
return P
|
||||
@@ -12,6 +12,7 @@ local O = require("avante.providers").openai
|
||||
local M = {}
|
||||
|
||||
M.api_key_name = "AZURE_OPENAI_API_KEY"
|
||||
M.tokenizer_id = "gpt-4o"
|
||||
|
||||
M.parse_message = O.parse_message
|
||||
M.parse_response = O.parse_response
|
||||
|
||||
@@ -6,6 +6,7 @@ local P = require("avante.providers")
|
||||
local M = {}
|
||||
|
||||
M.api_key_name = "ANTHROPIC_API_KEY"
|
||||
M.tokenizer_id = "gpt-4o"
|
||||
|
||||
---@param prompt_opts AvantePromptOptions
|
||||
M.parse_message = function(prompt_opts)
|
||||
@@ -26,7 +27,7 @@ M.parse_message = function(prompt_opts)
|
||||
|
||||
local user_prompts_with_length = {}
|
||||
for idx, user_prompt in ipairs(prompt_opts.user_prompts) do
|
||||
table.insert(user_prompts_with_length, { idx = idx, length = #user_prompt })
|
||||
table.insert(user_prompts_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(user_prompt) })
|
||||
end
|
||||
|
||||
table.sort(user_prompts_with_length, function(a, b)
|
||||
|
||||
@@ -29,6 +29,7 @@ local P = require("avante.providers")
|
||||
local M = {}
|
||||
|
||||
M.api_key_name = "CO_API_KEY"
|
||||
M.tokenizer_id = "CohereForAI/c4ai-command-r-plus-08-2024"
|
||||
|
||||
M.parse_message = function(opts)
|
||||
local user_prompt = table.concat(opts.user_prompts, "\n\n")
|
||||
|
||||
@@ -127,6 +127,7 @@ end
|
||||
M.state = nil
|
||||
|
||||
M.api_key_name = P.AVANTE_INTERNAL_KEY
|
||||
M.tokenizer_id = "gpt-4o"
|
||||
|
||||
M.parse_message = function(opts)
|
||||
return {
|
||||
@@ -166,6 +167,7 @@ M.setup = function()
|
||||
M.state = { github_token = nil, oauth_token = H.get_oauth_token() }
|
||||
H.refresh_token()
|
||||
end
|
||||
require("avante.tokenizers").setup(M.tokenizer_id)
|
||||
vim.g.avante_login = true
|
||||
end
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ local Clipboard = require("avante.clipboard")
|
||||
local M = {}
|
||||
|
||||
M.api_key_name = "GEMINI_API_KEY"
|
||||
M.tokenizer_id = "google/gemma-2b"
|
||||
|
||||
M.parse_message = function(opts)
|
||||
local message_content = {}
|
||||
|
||||
@@ -69,6 +69,7 @@ local Dressing = require("avante.ui.dressing")
|
||||
---@field setup fun(): nil
|
||||
---@field has fun(): boolean
|
||||
---@field api_key_name string
|
||||
---@field tokenizer_id string | "gpt-4o"
|
||||
---@field model? string
|
||||
---@field parse_api_key fun(): string | nil
|
||||
---@field parse_stream_data? AvanteStreamParser
|
||||
@@ -269,6 +270,11 @@ M = setmetatable(M, {
|
||||
return E.parse_envvar(t[k])
|
||||
end
|
||||
|
||||
-- default to gpt-4o as tokenizer
|
||||
if t[k].tokenizer_id == nil then
|
||||
t[k].tokenizer_id = "gpt-4o"
|
||||
end
|
||||
|
||||
if t[k].has == nil then
|
||||
t[k].has = function()
|
||||
return E.parse_envvar(t[k]) ~= nil
|
||||
@@ -280,6 +286,7 @@ M = setmetatable(M, {
|
||||
if not E.is_local(k) then
|
||||
t[k].parse_api_key()
|
||||
end
|
||||
require("avante.tokenizers").setup(t[k].tokenizer_id)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ local P = require("avante.providers")
|
||||
local M = {}
|
||||
|
||||
M.api_key_name = "OPENAI_API_KEY"
|
||||
M.tokenizer_id = "gpt-4o"
|
||||
|
||||
---@param opts AvantePromptOptions
|
||||
M.get_user_message = function(opts)
|
||||
|
||||
@@ -4,7 +4,7 @@ local fn = vim.fn
|
||||
local Split = require("nui.split")
|
||||
local event = require("nui.utils.autocmd").event
|
||||
|
||||
local History = require("avante.history")
|
||||
local Path = require("avante.path")
|
||||
local Config = require("avante.config")
|
||||
local Diff = require("avante.diff")
|
||||
local Llm = require("avante.llm")
|
||||
@@ -1170,7 +1170,7 @@ function Sidebar:get_commands()
|
||||
end,
|
||||
clear = function(args, cb)
|
||||
local chat_history = {}
|
||||
History.save(self.code.bufnr, chat_history)
|
||||
Path.history.save(self.code.bufnr, chat_history)
|
||||
self:update_content("Chat history cleared", { focus = false, scroll = false })
|
||||
vim.defer_fn(function()
|
||||
self:close()
|
||||
@@ -1242,7 +1242,7 @@ function Sidebar:create_input()
|
||||
return
|
||||
end
|
||||
|
||||
local chat_history = History.load(self.code.bufnr)
|
||||
local chat_history = Path.history.load(self.code.bufnr)
|
||||
|
||||
---@param request string
|
||||
local function handle_submit(request)
|
||||
@@ -1359,7 +1359,7 @@ function Sidebar:create_input()
|
||||
request = request,
|
||||
response = full_response,
|
||||
})
|
||||
History.save(self.code.bufnr, chat_history)
|
||||
Path.history.save(self.code.bufnr, chat_history)
|
||||
end
|
||||
|
||||
Llm.stream({
|
||||
@@ -1587,7 +1587,7 @@ function Sidebar:get_selected_code_size()
|
||||
end
|
||||
|
||||
function Sidebar:render()
|
||||
local chat_history = History.load(self.code.bufnr)
|
||||
local chat_history = Path.history.load(self.code.bufnr)
|
||||
|
||||
local get_position = function()
|
||||
if Config.layout == "vertical" then
|
||||
|
||||
66
lua/avante/tokenizers.lua
Normal file
66
lua/avante/tokenizers.lua
Normal file
@@ -0,0 +1,66 @@
|
||||
local Utils = require("avante.utils")
|
||||
|
||||
---@class AvanteTokenizer
|
||||
---@field from_pretrained fun(model: string): nil
|
||||
---@field encode fun(string): integer[]
|
||||
local tokenizers = nil
|
||||
|
||||
local M = {}
|
||||
|
||||
---@param model "gpt-4o" | string
|
||||
M.setup = function(model)
|
||||
local ok, core = pcall(require, "avante_tokenizers")
|
||||
if not ok then
|
||||
return
|
||||
end
|
||||
---@cast core AvanteTokenizer
|
||||
if tokenizers == nil then
|
||||
tokenizers = core
|
||||
end
|
||||
|
||||
local HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
if HF_TOKEN == nil and model ~= "gpt-4o" then
|
||||
Utils.warn(
|
||||
"Please set HF_TOKEN environment variable to use HuggingFace tokenizer if " .. model .. " is gated",
|
||||
{ once = true }
|
||||
)
|
||||
end
|
||||
vim.env.HF_HUB_DISABLE_PROGRESS_BARS = 1
|
||||
|
||||
---@cast core AvanteTokenizer
|
||||
core.from_pretrained(model)
|
||||
end
|
||||
|
||||
M.available = function()
|
||||
return tokenizers ~= nil
|
||||
end
|
||||
|
||||
---@param prompt string
|
||||
M.encode = function(prompt)
|
||||
if not tokenizers then
|
||||
return nil
|
||||
end
|
||||
if not prompt or prompt == "" then
|
||||
return nil
|
||||
end
|
||||
if type(prompt) ~= "string" then
|
||||
error("Prompt is not type string", 2)
|
||||
end
|
||||
|
||||
return tokenizers.encode(prompt)
|
||||
end
|
||||
|
||||
---@param prompt string
|
||||
M.count = function(prompt)
|
||||
if not tokenizers then
|
||||
return math.ceil(#prompt * 0.5)
|
||||
end
|
||||
|
||||
local tokens = M.encode(prompt)
|
||||
if not tokens then
|
||||
return 0
|
||||
end
|
||||
return #tokens
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -1,4 +1,6 @@
|
||||
--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
|
||||
local Tokenizer = require("avante.tokenizers")
|
||||
|
||||
---@class avante.utils.tokens
|
||||
local Tokens = {}
|
||||
|
||||
@@ -11,6 +13,10 @@ local cost_per_token = {
|
||||
---@param text string The text to calculate the number of tokens for.
|
||||
---@return integer The number of tokens in the given text.
|
||||
function Tokens.calculate_tokens(text)
|
||||
if Tokenizer.available() then
|
||||
return Tokenizer.count(text)
|
||||
end
|
||||
|
||||
local tokens = 0
|
||||
local current_token = ""
|
||||
for char in text:gmatch(".") do
|
||||
|
||||
22
lua/avante_lib.lua
Normal file
22
lua/avante_lib.lua
Normal file
@@ -0,0 +1,22 @@
|
||||
local M = {}
|
||||
|
||||
local function get_library_path()
|
||||
local os_name = require("avante.utils").get_os_name()
|
||||
local ext = os_name == "linux" and "so" or (os_name == "darwin" and "dylib" or "dll")
|
||||
local dirname = string.sub(debug.getinfo(1).source, 2, #"/avante_lib.lua" * -1)
|
||||
return dirname .. ("../build/?.%s"):format(ext)
|
||||
end
|
||||
|
||||
---@type fun(s: string): string
|
||||
local trim_semicolon = function(s)
|
||||
return s:sub(-1) == ";" and s:sub(1, -2) or s
|
||||
end
|
||||
|
||||
M.load = function()
|
||||
local library_path = get_library_path()
|
||||
if not string.find(package.cpath, library_path, 1, true) then
|
||||
package.cpath = trim_semicolon(package.cpath) .. ";" .. library_path
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
Reference in New Issue
Block a user