Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
@@ -12,17 +12,9 @@ license = { workspace = true }
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
-mlua = { version = "0.10.0-beta.1", features = [
-"module",
-"serialize",
-], git = "https://github.com/mlua-rs/mlua.git", branch = "main" }
-tiktoken-rs = "0.5.9"
-tokenizers = { version = "0.20.0", features = [
-"esaxx_fast",
-"http",
-"unstable_wasm",
-"onig",
-], default-features = false }
+mlua = { workspace = true }
+tiktoken-rs = { workspace = true }
+tokenizers = { workspace = true }
|
||||
|
||||
[features]
|
||||
lua51 = ["mlua/lua51"]
|
||||
|
||||
@@ -68,13 +68,12 @@ fn encode(state: &State, text: String) -> LuaResult<(Vec<usize>, usize, usize)>
|
||||
}
|
||||
}
|
||||
|
||||
-fn from_pretrained(state: &State, model: String) -> LuaResult<()> {
+fn from_pretrained(state: &State, model: String) {
|
||||
let mut tokenizer_mutex = state.tokenizer.lock().unwrap();
|
||||
*tokenizer_mutex = Some(match model.as_str() {
|
||||
"gpt-4o" => TokenizerType::Tiktoken(Tiktoken::new(model)),
|
||||
_ => TokenizerType::HuggingFace(HuggingFaceTokenizer::new(model)),
|
||||
});
|
||||
-Ok(())
|
||||
}
|
||||
|
||||
#[mlua::lua_module]
|
||||
@@ -86,7 +85,10 @@ fn avante_tokenizers(lua: &Lua) -> LuaResult<LuaTable> {
|
||||
let exports = lua.create_table()?;
|
||||
exports.set(
|
||||
"from_pretrained",
|
||||
-lua.create_function(move |_, model: String| from_pretrained(&state, model))?,
+lua.create_function(move |_, model: String| {
+from_pretrained(&state, model);
+Ok(())
+})?,
|
||||
)?;
|
||||
exports.set(
|
||||
"encode",
|
||||
|
||||
Reference in New Issue
Block a user