chore(rust): fix current clippy lint (#504)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
@@ -8,13 +8,13 @@ struct Tiktoken {
 }
 
 impl Tiktoken {
-    fn new(model: String) -> Self {
-        let bpe = get_bpe_from_model(&model).unwrap();
+    fn new(model: &str) -> Self {
+        let bpe = get_bpe_from_model(model).unwrap();
         Tiktoken { bpe }
     }
 
-    fn encode(&self, text: String) -> (Vec<usize>, usize, usize) {
-        let tokens = self.bpe.encode_with_special_tokens(&text);
+    fn encode(&self, text: &str) -> (Vec<usize>, usize, usize) {
+        let tokens = self.bpe.encode_with_special_tokens(text);
         let num_tokens = tokens.len();
         let num_chars = text.chars().count();
         (tokens, num_tokens, num_chars)
@@ -26,13 +26,17 @@ struct HuggingFaceTokenizer {
 }
 
 impl HuggingFaceTokenizer {
-    fn new(model: String) -> Self {
+    fn new(model: &str) -> Self {
         let tokenizer = Tokenizer::from_pretrained(model, None).unwrap();
         HuggingFaceTokenizer { tokenizer }
     }
 
-    fn encode(&self, text: String) -> (Vec<usize>, usize, usize) {
-        let encoding = self.tokenizer.encode(text, false).unwrap();
+    fn encode(&self, text: &str) -> (Vec<usize>, usize, usize) {
+        let encoding = self
+            .tokenizer
+            .encode(text, false)
+            .map_err(LuaError::external)
+            .unwrap();
         let tokens: Vec<usize> = encoding.get_ids().iter().map(|x| *x as usize).collect();
         let num_tokens = tokens.len();
         let num_chars = encoding.get_offsets().last().unwrap().1;
@@ -57,7 +61,7 @@ impl State {
     }
 }
 
-fn encode(state: &State, text: String) -> LuaResult<(Vec<usize>, usize, usize)> {
+fn encode(state: &State, text: &str) -> LuaResult<(Vec<usize>, usize, usize)> {
     let tokenizer = state.tokenizer.lock().unwrap();
     match tokenizer.as_ref() {
         Some(TokenizerType::Tiktoken(tokenizer)) => Ok(tokenizer.encode(text)),
@@ -68,9 +72,9 @@ fn encode(state: &State, text: String) -> LuaResult<(Vec<usize>, usize, usize)>
     }
 }
 
-fn from_pretrained(state: &State, model: String) {
+fn from_pretrained(state: &State, model: &str) {
    let mut tokenizer_mutex = state.tokenizer.lock().unwrap();
-    *tokenizer_mutex = Some(match model.as_str() {
+    *tokenizer_mutex = Some(match model {
         "gpt-4o" => TokenizerType::Tiktoken(Tiktoken::new(model)),
         _ => TokenizerType::HuggingFace(HuggingFaceTokenizer::new(model)),
     });
@@ -86,13 +90,13 @@ fn avante_tokenizers(lua: &Lua) -> LuaResult<LuaTable> {
     exports.set(
         "from_pretrained",
         lua.create_function(move |_, model: String| {
-            from_pretrained(&state, model);
+            from_pretrained(&state, model.as_str());
             Ok(())
         })?,
     )?;
     exports.set(
         "encode",
-        lua.create_function(move |_, text: String| encode(&state_clone, text))?,
+        lua.create_function(move |_, text: String| encode(&state_clone, text.as_str()))?,
     )?;
     Ok(exports)
 }
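The change is the same clippy-driven pattern throughout: functions that only read a string now borrow &str instead of taking an owned String, while the Lua-facing closures keep accepting String from mlua and borrow it at the call site via as_str(). Below is a minimal sketch of that pattern, assuming the lint in question is along the lines of clippy::needless_pass_by_value (the exact lint name is not stated in the commit); the function names are illustrative only, not part of this crate.

// Before: callers must hand over (or clone into) an owned String,
// even though the function only reads the text.
fn num_chars_owned(text: String) -> usize {
    text.chars().count()
}

// After: borrowing &str accepts both String (via as_str() or &s)
// and string literals, with no allocation or move.
fn num_chars(text: &str) -> usize {
    text.chars().count()
}

fn main() {
    let owned = String::from("tokenizer");
    assert_eq!(num_chars(owned.as_str()), 9);
    assert_eq!(num_chars("tokenizer"), 9);
    // The owned-parameter version consumes `owned`, so it must come last.
    assert_eq!(num_chars_owned(owned), 9);
}

The Lua entry points in the diff still declare model: String and text: String, presumably because mlua converts Lua strings into owned Rust values there; the borrow then happens once, at the internal call site, via as_str().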