From bdbbdec88c3d8c6263f81bb14889ccfe3fafa631 Mon Sep 17 00:00:00 2001
From: Aaron Pham
Date: Sun, 27 Oct 2024 02:17:35 -0400
Subject: [PATCH] feat(tokenizers): support parsing from public URL (#765)
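
Tokenizers can now be loaded from a public URL pointing directly at a
tokenizer.json file, in addition to HuggingFace repo ids. URL-sourced
files are cached under ~/.cache/avante and downloaded only on first
use. The Cohere provider switches to the v2 chat endpoint and uses the
publicly hosted command-r tokenizer, so it no longer needs HF_TOKEN.

Illustrative usage (this snippet is not part of the diff; it mirrors
the Cohere provider changes below):

    -- hypothetical provider module M
    M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
    M.setup = function()
      -- warning=false suppresses the HF_TOKEN hint; URL tokenizers
      -- are fetched with ureq, not the HuggingFace Hub API
      require("avante.tokenizers").setup(M.tokenizer_id, false)
    end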
"368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", @@ -980,9 +978,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "ring" @@ -1160,6 +1158,17 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "spin" version = "0.9.8" @@ -1216,18 +1225,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" dependencies = [ "proc-macro2", "quote", @@ -1236,16 +1245,17 @@ dependencies = [ [[package]] name = "tiktoken-rs" -version = "0.5.9" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c314e7ce51440f9e8f5a497394682a57b7c323d0f4d0a6b1b13c429056e0e234" +checksum = "44075987ee2486402f0808505dd65692163d243a337fc54363d49afac41087f6" dependencies = [ "anyhow", "base64 0.21.7", "bstr", - "fancy-regex 0.12.0", + "fancy-regex", "lazy_static", "parking_lot", + "regex", "rustc-hash 1.1.0", ] @@ -1273,7 +1283,7 @@ dependencies = [ "aho-corasick", "derive_builder", "esaxx-rs", - "fancy-regex 0.13.0", + "fancy-regex", "getrandom", "hf-hub", "itertools 0.12.1", @@ -1499,6 +1509,7 @@ dependencies = [ "rustls-pki-types", "serde", "serde_json", + "socks", "url", "webpki-roots", ] @@ -1602,6 +1613,28 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/Cargo.toml b/Cargo.toml index ced96b3..0af97c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ version = "0.1.0" avante-tokenizers = { path = "crates/avante-tokenizers" } avante-templates = { path = "crates/avante-templates" } avante-repo-map = { path = 
"crates/avante-repo-map" } -minijinja = { version = "2.2.0", features = [ +minijinja = { version = "2.4.0", features = [ "loader", "json", "fuel", @@ -21,11 +21,8 @@ minijinja = { version = "2.2.0", features = [ "custom_syntax", "loop_controls", ] } -mlua = { version = "0.10.0-beta.1", features = [ - "module", - "serialize", -], git = "https://github.com/mlua-rs/mlua.git", branch = "main" } -tiktoken-rs = { version = "0.5.9" } +mlua = { version = "0.10.0", features = ["module", "serialize"] } +tiktoken-rs = { version = "0.6.0" } tokenizers = { version = "0.20.0", features = [ "esaxx_fast", "http", diff --git a/crates/avante-tokenizers/Cargo.toml b/crates/avante-tokenizers/Cargo.toml index afe9a55..fcecb63 100644 --- a/crates/avante-tokenizers/Cargo.toml +++ b/crates/avante-tokenizers/Cargo.toml @@ -12,6 +12,10 @@ license = { workspace = true } workspace = true [dependencies] +dirs = "5.0.1" +regex = "1.11.1" +hf-hub = { version = "0.3.2", features = ["default"] } +ureq = { version = "2.10.1", features = ["json", "socks-proxy"] } mlua = { workspace = true } tiktoken-rs = { workspace = true } tokenizers = { workspace = true } diff --git a/crates/avante-tokenizers/src/lib.rs b/crates/avante-tokenizers/src/lib.rs index 533b162..f725c13 100644 --- a/crates/avante-tokenizers/src/lib.rs +++ b/crates/avante-tokenizers/src/lib.rs @@ -1,4 +1,7 @@ +use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use mlua::prelude::*; +use regex::Regex; +use std::path::PathBuf; use std::sync::{Arc, Mutex}; use tiktoken_rs::{get_bpe_from_model, CoreBPE}; use tokenizers::Tokenizer; @@ -10,10 +13,10 @@ struct Tiktoken { impl Tiktoken { fn new(model: &str) -> Self { let bpe = get_bpe_from_model(model).unwrap(); - Tiktoken { bpe } + Self { bpe } } - fn encode(&self, text: &str) -> (Vec, usize, usize) { + fn encode(&self, text: &str) -> (Vec, usize, usize) { let tokens = self.bpe.encode_with_special_tokens(text); let num_tokens = tokens.len(); let num_chars = text.chars().count(); @@ -25,23 +28,53 @@ struct HuggingFaceTokenizer { tokenizer: Tokenizer, } +fn is_valid_url(url: &str) -> bool { + let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap(); + url_regex.is_match(url) +} + impl HuggingFaceTokenizer { fn new(model: &str) -> Self { - let tokenizer = Tokenizer::from_pretrained(model, None).unwrap(); - HuggingFaceTokenizer { tokenizer } + let tokenizer_path = if is_valid_url(model) { + Self::get_cached_tokenizer(model) + } else { + // Use existing HuggingFace Hub logic for model names + let identifier = model.to_string(); + let api = ApiBuilder::new().with_progress(false).build().unwrap(); + let repo = Repo::new(identifier, RepoType::Model); + let api = api.repo(repo); + api.get("tokenizer.json").unwrap() + }; + + let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap(); + Self { tokenizer } } - fn encode(&self, text: &str) -> (Vec, usize, usize) { - let encoding = self - .tokenizer - .encode(text, false) - .map_err(LuaError::external) - .unwrap(); - let tokens: Vec = encoding.get_ids().iter().map(|x| *x as usize).collect(); + fn encode(&self, text: &str) -> (Vec, usize, usize) { + let encoding = self.tokenizer.encode(text, false).unwrap(); + let tokens = encoding.get_ids().to_vec(); let num_tokens = tokens.len(); let num_chars = encoding.get_offsets().last().unwrap().1; (tokens, num_tokens, num_chars) } + + fn get_cached_tokenizer(url: &str) -> PathBuf { + let cache_dir = dirs::home_dir() + .map(|h| h.join(".cache").join("avante")) + .unwrap(); + std::fs::create_dir_all(&cache_dir).unwrap(); + + // 
+        if !cached_path.exists() {
+            let response = ureq::get(url).call().unwrap();
+            let _ = std::fs::write(&cached_path, response.into_string().unwrap());
+        }
+        cached_path
+    }
 }
 
 enum TokenizerType {
@@ -61,7 +94,7 @@ impl State {
     }
 }
 
-fn encode(state: &State, text: &str) -> LuaResult<(Vec<usize>, usize, usize)> {
+fn encode(state: &State, text: &str) -> LuaResult<(Vec<u32>, usize, usize)> {
     let tokenizer = state.tokenizer.lock().unwrap();
     match tokenizer.as_ref() {
         Some(TokenizerType::Tiktoken(tokenizer)) => Ok(tokenizer.encode(text)),
@@ -100,3 +133,62 @@ fn avante_tokenizers(lua: &Lua) -> LuaResult<LuaTable> {
     )?;
     Ok(exports)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_tiktoken() {
+        let model = "gpt-4o";
+        let source = "Hello, world!";
+        let tokenizer = Tiktoken::new(model);
+        let (tokens, num_tokens, num_chars) = tokenizer.encode(source);
+        assert_eq!(tokens, vec![13225, 11, 2375, 0]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+
+    #[test]
+    fn test_hf() {
+        let model = "gpt2";
+        let source = "Hello, world!";
+        let tokenizer = HuggingFaceTokenizer::new(model);
+        let (tokens, num_tokens, num_chars) = tokenizer.encode(source);
+        assert_eq!(tokens, vec![15496, 11, 995, 0]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+
+    #[test]
+    fn test_roundtrip() {
+        let state = State::new();
+        let source = "Hello, world!";
+        let model = "gpt2";
+
+        from_pretrained(&state, model);
+        let (tokens, num_tokens, num_chars) = encode(&state, "Hello, world!").unwrap();
+        assert_eq!(tokens, vec![15496, 11, 995, 0]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+
+    // For example: https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json
+    // Disable testing on GitHub Actions to avoid rate limiting and file size limits
+    #[test]
+    fn test_public_url() {
+        if std::env::var("GITHUB_ACTIONS").is_ok() {
+            return;
+        }
+        let state = State::new();
+        let source = "Hello, world!";
+        let model =
+            "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json";
+
+        from_pretrained(&state, model);
+        let (tokens, num_tokens, num_chars) = encode(&state, "Hello, world!").unwrap();
+        assert_eq!(tokens, vec![28339, 19, 3845, 8]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+}
diff --git a/lua/avante/config.lua b/lua/avante/config.lua
index 95f42b8..52c6da9 100644
--- a/lua/avante/config.lua
+++ b/lua/avante/config.lua
@@ -76,7 +76,7 @@ Respect and use existing conventions, libraries, etc that are already present in
   },
   ---@type AvanteSupportedProvider
   cohere = {
-    endpoint = "https://api.cohere.com/v1",
+    endpoint = "https://api.cohere.com/v2",
     model = "command-r-plus-08-2024",
     timeout = 30000, -- Timeout in milliseconds
     temperature = 0,
diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua
index eab6d00..db020de 100644
--- a/lua/avante/providers/cohere.lua
+++ b/lua/avante/providers/cohere.lua
@@ -2,57 +2,69 @@ local Utils = require("avante.utils")
 local P = require("avante.providers")
 
 ---@alias CohereFinishReason "COMPLETE" | "LENGTH" | "ERROR"
+---@alias CohereStreamType "message-start" | "content-start" | "content-delta" | "content-end" | "message-end"
 ---
----@class CohereChatStreamResponse
----@field event_type "stream-start" | "text-generation" | "stream-end"
----@field is_finished boolean
----
----@class CohereTextGenerationResponse: CohereChatStreamResponse
+---@class CohereChatContent
+---@field type? CohereStreamType
 ---@field text string
 ---
----@class CohereStreamEndResponse: CohereChatStreamResponse
----@field response CohereChatResponse
----@field finish_reason CohereFinishReason
+---@class CohereChatMessage
+---@field content CohereChatContent
 ---
----@class CohereChatResponse
----@field text string
----@field generation_id string
----@field chat_history CohereMessage[]
----@field finish_reason CohereFinishReason
----@field meta {api_version: {version: integer}, billed_units: {input_tokens: integer, output_tokens: integer}, tokens: {input_tokens: integer, output_tokens: integer}}
+---@class CohereChatStreamBase
+---@field type CohereStreamType
+---@field index integer
+---
+---@class CohereChatContentDelta: CohereChatStreamBase
+---@field type "content-delta" | "content-start" | "content-end"
+---@field delta? { message: CohereChatMessage }
+---
+---@class CohereChatMessageStart: CohereChatStreamBase
+---@field type "message-start"
+---@field delta { message: { role: "assistant" } }
+---
+---@class CohereChatMessageEnd: CohereChatStreamBase
+---@field type "message-end"
+---@field delta { finish_reason: CohereFinishReason, usage: CohereChatUsage }
+---
+---@class CohereChatUsage
+---@field billed_units { input_tokens: integer, output_tokens: integer }
+---@field tokens { input_tokens: integer, output_tokens: integer }
+---
+---@alias CohereChatResponse CohereChatContentDelta | CohereChatMessageStart | CohereChatMessageEnd
 ---
 ---@class CohereMessage
----@field role? "USER" | "SYSTEM" | "CHATBOT"
----@field message string
+---@field type "text"
+---@field text string
 ---
 ---@class AvanteProviderFunctor
 local M = {}
 
 M.api_key_name = "CO_API_KEY"
-M.tokenizer_id = "CohereForAI/c4ai-command-r-plus-08-2024"
+M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
 
 M.parse_message = function(opts)
-  return {
-    preamble = opts.system_prompt,
-    message = table.concat(opts.user_prompts, "\n"),
+  ---@type CohereMessage[]
+  local user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
+    table.insert(acc, { type = "text", text = prompt })
+    return acc
+  end)
+  local messages = {
+    { role = "system", content = opts.system_prompt },
+    { role = "user", content = user_content },
   }
+  return { messages = messages }
 end
 
 M.parse_stream_data = function(data, opts)
-  ---@type CohereChatStreamResponse
+  ---@type CohereChatResponse
   local json = vim.json.decode(data)
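+  -- v2 emits typed SSE events; only "message-end" (completion) and
+  -- "content-delta" (text chunks) need handling here.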
-  if json.is_finished then
-    opts.on_complete(nil)
-    return
-  end
-  if json.event_type ~= nil then
-    ---@cast json CohereStreamEndResponse
-    if json.event_type == "stream-end" and json.finish_reason == "COMPLETE" then
+  if json.type ~= nil then
+    if json.type == "message-end" and json.delta.finish_reason == "COMPLETE" then
       opts.on_complete(nil)
       return
     end
-    ---@cast json CohereTextGenerationResponse
-    if json.event_type == "text-generation" then opts.on_chunk(json.text) end
+    if json.type == "content-delta" then opts.on_chunk(json.delta.message.content.text) end
   end
 end
 
@@ -83,4 +95,10 @@ M.parse_curl_args = function(provider, code_opts)
   }
 end
 
+M.setup = function()
+  P.env.parse_envvar(M)
+  require("avante.tokenizers").setup(M.tokenizer_id, false)
+  vim.g.avante_login = true
+end
+
 return M
diff --git a/lua/avante/tokenizers.lua b/lua/avante/tokenizers.lua
index 35a4fd2..6dc2513 100644
--- a/lua/avante/tokenizers.lua
+++ b/lua/avante/tokenizers.lua
@@ -8,7 +8,9 @@ local tokenizers = nil
 local M = {}
 
 ---@param model "gpt-4o" | string
-M.setup = function(model)
+---@param warning? boolean
+M.setup = function(model, warning)
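+  -- warning defaults to true; callers like the Cohere provider pass false
+  -- to skip the HF_TOKEN hint when the tokenizer comes from a public URL.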
+  if warning == nil then warning = true end
   vim.defer_fn(function()
     local ok, core = pcall(require, "avante_tokenizers")
     if not ok then return end
@@ -19,14 +21,15 @@ M.setup = function(model)
     core.from_pretrained(model)
   end, 1000)
 
-  local HF_TOKEN = os.getenv("HF_TOKEN")
-  if HF_TOKEN == nil and model ~= "gpt-4o" then
-    Utils.warn(
-      "Please set HF_TOKEN environment variable to use HuggingFace tokenizer if " .. model .. " is gated",
-      { once = true }
-    )
+  if warning then
+    local HF_TOKEN = os.getenv("HF_TOKEN")
+    if HF_TOKEN == nil and model ~= "gpt-4o" then
+      Utils.warn(
+        "Please set HF_TOKEN environment variable to use HuggingFace tokenizer if " .. model .. " is gated",
+        { once = true }
+      )
+    end
   end
-
   vim.env.HF_HUB_DISABLE_PROGRESS_BARS = 1
 end
 
 M.available = function() return tokenizers ~= nil end