feat: tokenizers (#429)

* feat: tokenizers

This reverts commit d5a4db8321.

* fix(inputs): #422

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

---------

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham
2024-08-31 13:39:50 -04:00
committed by GitHub
parent 534b1e6bec
commit 0557deeab7
28 changed files with 3553 additions and 27 deletions

View File

@@ -10,6 +10,7 @@ local Utils = require("avante.utils")
---@field ask fun(): boolean
---@field edit fun(): nil
---@field refresh fun(): nil
---@field build fun(): boolean
---@field toggle avante.ApiToggle
return setmetatable({}, {

View File

@@ -285,6 +285,7 @@ M.BASE_PROVIDER_KEYS = {
-- internal
"local",
"_shellenv",
"tokenizer_id",
}
---@return {width: integer, height: integer}

View File

@@ -34,6 +34,9 @@ H.commands = function()
cmd("Refresh", function()
M.refresh()
end, { desc = "avante: refresh windows" })
cmd("Build", function()
M.build()
end, { desc = "avante: build dependencies" })
end
H.keymaps = function()
@@ -91,6 +94,34 @@ end
H.augroup = api.nvim_create_augroup("avante_autocmds", { clear = true })
H.autocmds = function()
local ok, LazyConfig = pcall(require, "lazy.core.config")
if ok then
local name = "avante.nvim"
local load_path = function()
require("avante_lib").load()
end
if LazyConfig.plugins[name] and LazyConfig.plugins[name]._.loaded then
vim.schedule(load_path)
else
api.nvim_create_autocmd("User", {
pattern = "LazyLoad",
callback = function(event)
if event.data == name then
load_path()
return true
end
end,
})
end
api.nvim_create_autocmd("User", {
pattern = "VeryLazy",
callback = load_path,
})
end
api.nvim_create_autocmd("TabEnter", {
group = H.augroup,
pattern = "*",
@@ -221,6 +252,53 @@ setmetatable(M.toggle, {
end,
})
--- Rewrite a path in Windows form.
--- Every "/" becomes "\", a leading drive letter ("c:") is upper-cased, and a
--- single trailing backslash is dropped.
---@param path string
---@return string
local function to_windows_path(path)
  -- Parentheses keep only gsub's first result (the string, not the match count).
  local converted = (path:gsub("/", "\\"))
  if converted:find("^%a:") then
    local drive, rest = converted:sub(1, 2), converted:sub(3)
    converted = drive:upper() .. rest
  end
  return (converted:gsub("\\$", ""))
end
--- Build the plugin's native dependencies.
--- Runs `make` in the build directory on linux/darwin, or `Build.ps1` through
--- powershell on windows, and waits for the process to finish.
--- Raises when `cargo` is not executable or the operating system is unsupported.
---@return boolean true when the build process exited with code 0
M.build = H.api(function()
  -- Directory of this file: drop the first character of the chunk source (the
  -- "@" Lua puts before file names) and the trailing "init.lua", then the
  -- remaining trailing slash.
  local dirname = Utils.trim(string.sub(debug.getinfo(1).source, 2, #"/init.lua" * -1), { suffix = "/" })
  -- Prefer the directory that holds ".git"; otherwise fall back to two levels up.
  local git_root = vim.fs.find(".git", { path = dirname, upward = true })[1]
  local build_directory = git_root and vim.fn.fnamemodify(git_root, ":h") or (dirname .. "/../../")

  -- vim.fn.executable() returns 0 or 1 and both are truthy in Lua, so the
  -- result has to be compared explicitly.
  if vim.fn.executable("cargo") == 0 then
    error("Building avante.nvim requires cargo to be installed.", 2)
  end

  ---@type string[]
  local cmd
  local os_name = Utils.get_os_name()
  if os_name == "linux" or os_name == "darwin" then
    -- Hand the directory to the shell as "$1" instead of splicing it into the
    -- command string, so paths with spaces or shell metacharacters stay intact.
    cmd = { "sh", "-c", 'make -C "$1"', "sh", build_directory }
  elseif os_name == "windows" then
    build_directory = to_windows_path(build_directory)
    cmd = {
      "powershell",
      "-ExecutionPolicy",
      "Bypass",
      "-File",
      ("%s\\Build.ps1"):format(build_directory),
      "-WorkingDirectory",
      build_directory,
    }
  else
    error("Unsupported operating system: " .. os_name, 2)
  end

  local job = vim.system(cmd, { text = true }):wait()
  return job.code == 0
end)
M.ask = H.api(function()
M.toggle()
end)
@@ -283,7 +361,7 @@ function M.setup(opts)
return
end
require("avante.history").setup()
require("avante.path").setup()
require("avante.highlights").setup()
require("avante.diff").setup()
require("avante.providers").setup()

View File

@@ -2,10 +2,14 @@ local fn, api = vim.fn, vim.api
local Path = require("plenary.path")
local Config = require("avante.config")
---@class avante.Path
---@field history_path Path
---@field cache_path Path
local P = {}
local M = {}
local H = {}
---@param bufnr integer
---@return string
H.filename = function(bufnr)
@@ -39,11 +43,20 @@ M.save = function(bufnr, history)
history_file:write(vim.json.encode(history), "w")
end
M.setup = function()
local history_dir = Path:new(Config.history.storage_path)
if not history_dir:exists() then
history_dir:mkdir({ parents = true })
P.history = M
P.setup = function()
local history_path = Path:new(Config.history.storage_path)
if not history_path:exists() then
history_path:mkdir({ parents = true })
end
P.history_path = history_path
local cache_path = Path:new(vim.fn.stdpath("cache") .. "/avante")
if not cache_path:exists() then
cache_path:mkdir({ parents = true })
end
P.cache_path = cache_path
end
return M
return P

View File

@@ -12,6 +12,7 @@ local O = require("avante.providers").openai
local M = {}
M.api_key_name = "AZURE_OPENAI_API_KEY"
M.tokenizer_id = "gpt-4o"
M.parse_message = O.parse_message
M.parse_response = O.parse_response

View File

@@ -6,6 +6,7 @@ local P = require("avante.providers")
local M = {}
M.api_key_name = "ANTHROPIC_API_KEY"
M.tokenizer_id = "gpt-4o"
---@param prompt_opts AvantePromptOptions
M.parse_message = function(prompt_opts)
@@ -26,7 +27,7 @@ M.parse_message = function(prompt_opts)
local user_prompts_with_length = {}
for idx, user_prompt in ipairs(prompt_opts.user_prompts) do
table.insert(user_prompts_with_length, { idx = idx, length = #user_prompt })
table.insert(user_prompts_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(user_prompt) })
end
table.sort(user_prompts_with_length, function(a, b)

View File

@@ -29,6 +29,7 @@ local P = require("avante.providers")
local M = {}
M.api_key_name = "CO_API_KEY"
M.tokenizer_id = "CohereForAI/c4ai-command-r-plus-08-2024"
M.parse_message = function(opts)
local user_prompt = table.concat(opts.user_prompts, "\n\n")

View File

@@ -127,6 +127,7 @@ end
M.state = nil
M.api_key_name = P.AVANTE_INTERNAL_KEY
M.tokenizer_id = "gpt-4o"
M.parse_message = function(opts)
return {
@@ -166,6 +167,7 @@ M.setup = function()
M.state = { github_token = nil, oauth_token = H.get_oauth_token() }
H.refresh_token()
end
require("avante.tokenizers").setup(M.tokenizer_id)
vim.g.avante_login = true
end

View File

@@ -6,6 +6,7 @@ local Clipboard = require("avante.clipboard")
local M = {}
M.api_key_name = "GEMINI_API_KEY"
M.tokenizer_id = "google/gemma-2b"
M.parse_message = function(opts)
local message_content = {}

View File

@@ -69,6 +69,7 @@ local Dressing = require("avante.ui.dressing")
---@field setup fun(): nil
---@field has fun(): boolean
---@field api_key_name string
---@field tokenizer_id string | "gpt-4o"
---@field model? string
---@field parse_api_key fun(): string | nil
---@field parse_stream_data? AvanteStreamParser
@@ -269,6 +270,11 @@ M = setmetatable(M, {
return E.parse_envvar(t[k])
end
-- default to gpt-4o as tokenizer
if t[k].tokenizer_id == nil then
t[k].tokenizer_id = "gpt-4o"
end
if t[k].has == nil then
t[k].has = function()
return E.parse_envvar(t[k]) ~= nil
@@ -280,6 +286,7 @@ M = setmetatable(M, {
if not E.is_local(k) then
t[k].parse_api_key()
end
require("avante.tokenizers").setup(t[k].tokenizer_id)
end
end

View File

@@ -26,6 +26,7 @@ local P = require("avante.providers")
local M = {}
M.api_key_name = "OPENAI_API_KEY"
M.tokenizer_id = "gpt-4o"
---@param opts AvantePromptOptions
M.get_user_message = function(opts)

View File

@@ -4,7 +4,7 @@ local fn = vim.fn
local Split = require("nui.split")
local event = require("nui.utils.autocmd").event
local History = require("avante.history")
local Path = require("avante.path")
local Config = require("avante.config")
local Diff = require("avante.diff")
local Llm = require("avante.llm")
@@ -1170,7 +1170,7 @@ function Sidebar:get_commands()
end,
clear = function(args, cb)
local chat_history = {}
History.save(self.code.bufnr, chat_history)
Path.history.save(self.code.bufnr, chat_history)
self:update_content("Chat history cleared", { focus = false, scroll = false })
vim.defer_fn(function()
self:close()
@@ -1242,7 +1242,7 @@ function Sidebar:create_input()
return
end
local chat_history = History.load(self.code.bufnr)
local chat_history = Path.history.load(self.code.bufnr)
---@param request string
local function handle_submit(request)
@@ -1359,7 +1359,7 @@ function Sidebar:create_input()
request = request,
response = full_response,
})
History.save(self.code.bufnr, chat_history)
Path.history.save(self.code.bufnr, chat_history)
end
Llm.stream({
@@ -1587,7 +1587,7 @@ function Sidebar:get_selected_code_size()
end
function Sidebar:render()
local chat_history = History.load(self.code.bufnr)
local chat_history = Path.history.load(self.code.bufnr)
local get_position = function()
if Config.layout == "vertical" then

66
lua/avante/tokenizers.lua Normal file
View File

@@ -0,0 +1,66 @@
local Utils = require("avante.utils")
-- Module state for the tokenizer wrapper.
-- `tokenizers` holds the native "avante_tokenizers" module once M.setup has
-- managed to require it; while it is nil, M.available() reports false.
-- `M` is the table returned to callers of require("avante.tokenizers").
---@class AvanteTokenizer
---@field from_pretrained fun(model: string): nil
---@field encode fun(string): integer[]
local tokenizers = nil
local M = {}
--- Load the native tokenizer module and initialise it for `model`.
--- Returns silently when "avante_tokenizers" cannot be required.
---@param model "gpt-4o" | string
M.setup = function(model)
  local loaded, native = pcall(require, "avante_tokenizers")
  if not loaded then
    return
  end

  ---@cast native AvanteTokenizer
  -- Remember the module the first time it loads; later calls keep the same one.
  tokenizers = tokenizers or native

  -- Any model other than "gpt-4o" triggers a hint when HF_TOKEN is missing.
  local needs_hf_hint = model ~= "gpt-4o" and os.getenv("HF_TOKEN") == nil
  if needs_hf_hint then
    Utils.warn(
      "Please set HF_TOKEN environment variable to use HuggingFace tokenizer if " .. model .. " is gated",
      { once = true }
    )
  end

  vim.env.HF_HUB_DISABLE_PROGRESS_BARS = 1
  native.from_pretrained(model)
end
--- Report whether the native tokenizer module has been loaded by M.setup.
---@return boolean
M.available = function()
  return not (tokenizers == nil)
end
--- Encode `prompt` with the loaded tokenizer.
--- Returns nil when no tokenizer is loaded or when `prompt` is nil/false/empty;
--- raises when `prompt` is any other non-string value.
---@param prompt string
M.encode = function(prompt)
  local nothing_to_encode = not tokenizers or not prompt or prompt == ""
  if nothing_to_encode then
    return nil
  end

  if type(prompt) ~= "string" then
    error("Prompt is not type string", 2)
  end

  return tokenizers.encode(prompt)
end
--- Count the tokens in `prompt`.
--- A nil or empty prompt counts as 0 whether or not a tokenizer is loaded.
--- Without a loaded tokenizer the count is estimated as half the byte length,
--- rounded up; otherwise it is the length of the encoded token list.
---@param prompt string
---@return integer
M.count = function(prompt)
  -- Guard first so the fallback below never applies `#` to a nil prompt.
  if not prompt or prompt == "" then
    return 0
  end

  if not tokenizers then
    return math.ceil(#prompt * 0.5)
  end

  local tokens = M.encode(prompt)
  if not tokens then
    return 0
  end
  return #tokens
end
return M

View File

@@ -1,4 +1,6 @@
--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
local Tokenizer = require("avante.tokenizers")
---@class avante.utils.tokens
local Tokens = {}
@@ -11,6 +13,10 @@ local cost_per_token = {
---@param text string The text to calculate the number of tokens for.
---@return integer The number of tokens in the given text.
function Tokens.calculate_tokens(text)
if Tokenizer.available() then
return Tokenizer.count(text)
end
local tokens = 0
local current_token = ""
for char in text:gmatch(".") do

22
lua/avante_lib.lua Normal file
View File

@@ -0,0 +1,22 @@
local M = {}
--- Build the `package.cpath` template for the plugin's compiled libraries.
--- The template points at "../build/" relative to this file's directory and
--- uses the shared-library extension matching the current OS name.
---@return string
local function get_library_path()
  local os_name = require("avante.utils").get_os_name()

  local ext
  if os_name == "linux" then
    ext = "so"
  elseif os_name == "darwin" then
    ext = "dylib"
  else
    -- Every other OS name falls through to the Windows extension.
    ext = "dll"
  end

  -- Skip the first character of the chunk source and cut the file name,
  -- keeping the trailing slash of the directory.
  local source = debug.getinfo(1).source
  local dirname = source:sub(2, -#"/avante_lib.lua")
  return dirname .. string.format("../build/?.%s", ext)
end
--- Remove one trailing ";" from `s`, if there is one.
---@param s string
---@return string
local function trim_semicolon(s)
  if s:sub(-1) == ";" then
    return s:sub(1, -2)
  end
  return s
end
--- Append the plugin's build directory template to `package.cpath`.
--- Does nothing when the template is already part of `package.cpath`.
M.load = function()
  local template = get_library_path()
  -- Plain-text search: the template contains "?" and ".", which are pattern magic.
  local already_present = package.cpath:find(template, 1, true) ~= nil
  if already_present then
    return
  end
  package.cpath = trim_semicolon(package.cpath) .. ";" .. template
end
return M