feat(tokenizers): support parsing from public URL (#765)
@@ -1,4 +1,7 @@
+use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use mlua::prelude::*;
+use regex::Regex;
+use std::path::PathBuf;
 use std::sync::{Arc, Mutex};
 use tiktoken_rs::{get_bpe_from_model, CoreBPE};
 use tokenizers::Tokenizer;
@@ -10,10 +13,10 @@ struct Tiktoken {
 impl Tiktoken {
     fn new(model: &str) -> Self {
         let bpe = get_bpe_from_model(model).unwrap();
-        Tiktoken { bpe }
+        Self { bpe }
     }
 
-    fn encode(&self, text: &str) -> (Vec<usize>, usize, usize) {
+    fn encode(&self, text: &str) -> (Vec<u32>, usize, usize) {
         let tokens = self.bpe.encode_with_special_tokens(text);
         let num_tokens = tokens.len();
         let num_chars = text.chars().count();
@@ -25,23 +28,53 @@ struct HuggingFaceTokenizer {
     tokenizer: Tokenizer,
 }
 
+fn is_valid_url(url: &str) -> bool {
+    let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap();
+    url_regex.is_match(url)
+}
+
 impl HuggingFaceTokenizer {
     fn new(model: &str) -> Self {
-        let tokenizer = Tokenizer::from_pretrained(model, None).unwrap();
-        HuggingFaceTokenizer { tokenizer }
+        let tokenizer_path = if is_valid_url(model) {
+            Self::get_cached_tokenizer(model)
+        } else {
+            // Use existing HuggingFace Hub logic for model names
+            let identifier = model.to_string();
+            let api = ApiBuilder::new().with_progress(false).build().unwrap();
+            let repo = Repo::new(identifier, RepoType::Model);
+            let api = api.repo(repo);
+            api.get("tokenizer.json").unwrap()
+        };
+
+        let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
+        Self { tokenizer }
     }
 
-    fn encode(&self, text: &str) -> (Vec<usize>, usize, usize) {
-        let encoding = self
-            .tokenizer
-            .encode(text, false)
-            .map_err(LuaError::external)
-            .unwrap();
-        let tokens: Vec<usize> = encoding.get_ids().iter().map(|x| *x as usize).collect();
+    fn encode(&self, text: &str) -> (Vec<u32>, usize, usize) {
+        let encoding = self.tokenizer.encode(text, false).unwrap();
+        let tokens = encoding.get_ids().to_vec();
         let num_tokens = tokens.len();
         let num_chars = encoding.get_offsets().last().unwrap().1;
         (tokens, num_tokens, num_chars)
     }
+
+    fn get_cached_tokenizer(url: &str) -> PathBuf {
+        let cache_dir = dirs::home_dir()
+            .map(|h| h.join(".cache").join("avante"))
+            .unwrap();
+        std::fs::create_dir_all(&cache_dir).unwrap();
+
+        // Extract filename from URL
+        let filename = url.split('/').last().unwrap();
+
+        let cached_path = cache_dir.join(filename);
+
+        if !cached_path.exists() {
+            let response = ureq::get(url).call().unwrap();
+            let _ = std::fs::write(&cached_path, response.into_string().unwrap());
+        }
+        cached_path
+    }
 }
 
 enum TokenizerType {
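Note on the hunk above: URL-style model identifiers are detected with a regex and resolved to a file cached under ~/.cache/avante, while plain model names keep going through the HuggingFace Hub API. Below is a minimal standalone sketch of that dispatch, assuming the regex and dirs crates as used in the commit; cached_tokenizer_path is a hypothetical helper name for illustration, not part of the diff.

use std::path::PathBuf;

use regex::Regex;

// Same rough URL check as in the diff above.
fn is_valid_url(url: &str) -> bool {
    let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap();
    url_regex.is_match(url)
}

// Where a URL-sourced tokenizer would be cached:
// ~/.cache/avante/<last path segment of the URL>.
fn cached_tokenizer_path(url: &str) -> PathBuf {
    let cache_dir = dirs::home_dir()
        .map(|h| h.join(".cache").join("avante"))
        .expect("home directory should exist");
    let filename = url.split('/').last().unwrap();
    cache_dir.join(filename)
}

fn main() {
    let url = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json";
    assert!(is_valid_url(url));
    assert!(!is_valid_url("gpt2")); // plain model names fall through to hf_hub
    println!("{}", cached_tokenizer_path(url).display());
    // e.g. /home/<user>/.cache/avante/command-r-08-2024.json
}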
@@ -61,7 +94,7 @@ impl State {
     }
 }
 
-fn encode(state: &State, text: &str) -> LuaResult<(Vec<usize>, usize, usize)> {
+fn encode(state: &State, text: &str) -> LuaResult<(Vec<u32>, usize, usize)> {
     let tokenizer = state.tokenizer.lock().unwrap();
     match tokenizer.as_ref() {
         Some(TokenizerType::Tiktoken(tokenizer)) => Ok(tokenizer.encode(text)),
@@ -100,3 +133,62 @@ fn avante_tokenizers(lua: &Lua) -> LuaResult<LuaTable> {
     )?;
     Ok(exports)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_tiktoken() {
+        let model = "gpt-4o";
+        let source = "Hello, world!";
+        let tokenizer = Tiktoken::new(model);
+        let (tokens, num_tokens, num_chars) = tokenizer.encode(source);
+        assert_eq!(tokens, vec![13225, 11, 2375, 0]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+
+    #[test]
+    fn test_hf() {
+        let model = "gpt2";
+        let source = "Hello, world!";
+        let tokenizer = HuggingFaceTokenizer::new(model);
+        let (tokens, num_tokens, num_chars) = tokenizer.encode(source);
+        assert_eq!(tokens, vec![15496, 11, 995, 0]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+
+    #[test]
+    fn test_roundtrip() {
+        let state = State::new();
+        let source = "Hello, world!";
+        let model = "gpt2";
+
+        from_pretrained(&state, model);
+        let (tokens, num_tokens, num_chars) = encode(&state, "Hello, world!").unwrap();
+        assert_eq!(tokens, vec![15496, 11, 995, 0]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+
+    // For example: https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json
+    // Disable testing on GitHub Actions to avoid rate limiting and file size limits
+    #[test]
+    fn test_public_url() {
+        if std::env::var("GITHUB_ACTIONS").is_ok() {
+            return;
+        }
+        let state = State::new();
+        let source = "Hello, world!";
+        let model =
+            "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json";
+
+        from_pretrained(&state, model);
+        let (tokens, num_tokens, num_chars) = encode(&state, "Hello, world!").unwrap();
+        assert_eq!(tokens, vec![28339, 19, 3845, 8]);
+        assert_eq!(num_tokens, 4);
+        assert_eq!(num_chars, source.chars().count());
+    }
+}
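One design note on the new get_cached_tokenizer: response.into_string() collects the whole body into memory, and ureq documents a default body-size limit for into_string() (about 10 MB in ureq 2.x), which the test comment about file size limits already hints at for large tokenizer.json files. A hedged sketch of how the download step could stream straight to disk instead, assuming ureq 2.x; download_to is a hypothetical helper, not part of this commit.

use std::path::Path;

// Stream the HTTP body to a file instead of collecting it into a String,
// so large tokenizer.json files are neither size-limited nor held in memory.
fn download_to(url: &str, dest: &Path) -> Result<u64, Box<dyn std::error::Error>> {
    let response = ureq::get(url).call()?;
    let mut reader = response.into_reader();
    let mut file = std::fs::File::create(dest)?;
    Ok(std::io::copy(&mut reader, &mut file)?)
}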