feat(templates): avanterules filetype support (closes #254) (#466)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham
2024-09-03 04:09:13 -04:00
committed by GitHub
parent 054695cc63
commit 4ad913435c
31 changed files with 962 additions and 265 deletions

View File

@@ -0,0 +1,24 @@
[lib]
crate-type = ["cdylib"]
[package]
name = "avante-templates"
edition.workspace = true
rust-version.workspace = true
license.workspace = true
version.workspace = true
[dependencies]
mlua = { workspace = true }
minijinja = { workspace = true }
serde = { workspace = true, features = ["derive"] }
[lints]
workspace = true
[features]
lua51 = ["mlua/lua51"]
lua52 = ["mlua/lua52"]
lua53 = ["mlua/lua53"]
lua54 = ["mlua/lua54"]
luajit = ["mlua/luajit"]

View File

@@ -0,0 +1,91 @@
use minijinja::{context, path_loader, Environment};
use mlua::prelude::*;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
struct State<'a> {
environment: Mutex<Option<Environment<'a>>>,
}
impl<'a> State<'a> {
fn new() -> Self {
State {
environment: Mutex::new(None),
}
}
}
#[derive(Serialize, Deserialize)]
struct TemplateContext {
use_xml_format: bool,
ask: bool,
question: String,
code_lang: String,
file_content: String,
selected_code: Option<String>,
project_context: Option<String>,
memory_context: Option<String>,
}
// Given the file name registered after add, the context table in Lua, resulted in a formatted
// Lua string
fn render(state: &State, template: String, context: TemplateContext) -> LuaResult<String> {
let environment = state.environment.lock().unwrap();
match environment.as_ref() {
Some(environment) => {
let template = environment
.get_template(&template)
.map_err(LuaError::external)
.unwrap();
Ok(template
.render(context! {
use_xml_format => context.use_xml_format,
ask => context.ask,
question => context.question,
code_lang => context.code_lang,
file_content => context.file_content,
selected_code => context.selected_code,
project_context => context.project_context,
memory_context => context.memory_context,
})
.map_err(LuaError::external)
.unwrap())
}
None => Err(LuaError::RuntimeError(
"Environment not initialized".to_string(),
)),
}
}
fn initialize(state: &State, directory: String) {
let mut environment_mutex = state.environment.lock().unwrap();
// add directory as a base path for base directory template path
let mut env = Environment::new();
env.set_loader(path_loader(directory));
*environment_mutex = Some(env);
}
#[mlua::lua_module]
fn avante_templates(lua: &Lua) -> LuaResult<LuaTable> {
let core = State::new();
let state = Arc::new(core);
let state_clone = Arc::clone(&state);
let exports = lua.create_table()?;
exports.set(
"initialize",
lua.create_function(move |_, model: String| {
initialize(&state, model);
Ok(())
})?,
)?;
exports.set(
"render",
lua.create_function_mut(move |lua, (template, context): (String, LuaValue)| {
let ctx = lua.from_value(context)?;
render(&state_clone, template, ctx)
})?,
)?;
Ok(exports)
}

View File

@@ -12,17 +12,9 @@ license = { workspace = true }
workspace = true
[dependencies]
mlua = { version = "0.10.0-beta.1", features = [
"module",
"serialize",
], git = "https://github.com/mlua-rs/mlua.git", branch = "main" }
tiktoken-rs = "0.5.9"
tokenizers = { version = "0.20.0", features = [
"esaxx_fast",
"http",
"unstable_wasm",
"onig",
], default-features = false }
mlua = { workspace = true }
tiktoken-rs = { workspace = true }
tokenizers = { workspace = true }
[features]
lua51 = ["mlua/lua51"]

View File

@@ -68,13 +68,12 @@ fn encode(state: &State, text: String) -> LuaResult<(Vec<usize>, usize, usize)>
}
}
fn from_pretrained(state: &State, model: String) -> LuaResult<()> {
fn from_pretrained(state: &State, model: String) {
let mut tokenizer_mutex = state.tokenizer.lock().unwrap();
*tokenizer_mutex = Some(match model.as_str() {
"gpt-4o" => TokenizerType::Tiktoken(Tiktoken::new(model)),
_ => TokenizerType::HuggingFace(HuggingFaceTokenizer::new(model)),
});
Ok(())
}
#[mlua::lua_module]
@@ -86,7 +85,10 @@ fn avante_tokenizers(lua: &Lua) -> LuaResult<LuaTable> {
let exports = lua.create_table()?;
exports.set(
"from_pretrained",
lua.create_function(move |_, model: String| from_pretrained(&state, model))?,
lua.create_function(move |_, model: String| {
from_pretrained(&state, model);
Ok(())
})?,
)?;
exports.set(
"encode",