From 579ef12f76db285ba104585ea76ef87af744443c Mon Sep 17 00:00:00 2001 From: Vasil Markoukin Date: Mon, 4 Nov 2024 03:35:27 +0300 Subject: [PATCH] feat(repo_map): add scala support (#788) Co-authored-by: Aaron Pham --- Cargo.lock | 11 ++++ crates/avante-repo-map/Cargo.toml | 1 + .../queries/tree-sitter-scala-defs.scm | 51 ++++++++++++++++ crates/avante-repo-map/src/lib.rs | 58 ++++++++++++++++++- 4 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 crates/avante-repo-map/queries/tree-sitter-scala-defs.scm diff --git a/Cargo.lock b/Cargo.lock index 49fb531..c748bff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,7 @@ dependencies = [ "tree-sitter-python", "tree-sitter-ruby", "tree-sitter-rust", + "tree-sitter-scala", "tree-sitter-typescript", "tree-sitter-zig", ] @@ -1405,6 +1406,16 @@ dependencies = [ "tree-sitter-language", ] +[[package]] +name = "tree-sitter-scala" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7394987e126e3b36dc94a89e48544bea8542db66a62532f6d19930685cc1230" +dependencies = [ + "cc", + "tree-sitter-language", +] + [[package]] name = "tree-sitter-typescript" version = "0.23.0" diff --git a/crates/avante-repo-map/Cargo.toml b/crates/avante-repo-map/Cargo.toml index 2e2e298..49a36f1 100644 --- a/crates/avante-repo-map/Cargo.toml +++ b/crates/avante-repo-map/Cargo.toml @@ -27,6 +27,7 @@ tree-sitter-cpp = "0.23" tree-sitter-lua = "0.2" tree-sitter-ruby = "0.23" tree-sitter-zig = "1.0.2" +tree-sitter-scala = "0.23" [lints] workspace = true diff --git a/crates/avante-repo-map/queries/tree-sitter-scala-defs.scm b/crates/avante-repo-map/queries/tree-sitter-scala-defs.scm new file mode 100644 index 0000000..f28c855 --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-scala-defs.scm @@ -0,0 +1,51 @@ +(class_definition + name: (identifier) @class) + +(object_definition + name: (identifier) @class) + +(trait_definition + name: (identifier) @class) + +(simple_enum_case + name: (identifier) @enum_item) + +(full_enum_case + name: (identifier) @enum_item) + +(template_body + (function_definition) @method +) + +(template_body + (function_declaration) @method +) + +(template_body + (val_definition) @class_variable +) + +(template_body + (val_declaration) @class_variable +) + + +(template_body + (var_definition) @class_variable +) + +(template_body + (var_declaration) @class_variable +) + +(compilation_unit + (function_definition) @function +) + +(compilation_unit + (val_definition) @variable +) + +(compilation_unit + (var_definition) @variable +) diff --git a/crates/avante-repo-map/src/lib.rs b/crates/avante-repo-map/src/lib.rs index 3027d61..69e5cc4 100644 --- a/crates/avante-repo-map/src/lib.rs +++ b/crates/avante-repo-map/src/lib.rs @@ -60,6 +60,7 @@ fn get_ts_language(language: &str) -> Option { "lua" => Some(tree_sitter_lua::LANGUAGE), "ruby" => Some(tree_sitter_ruby::LANGUAGE), "zig" => Some(tree_sitter_zig::LANGUAGE), + "scala" => Some(tree_sitter_scala::LANGUAGE), _ => None, } } @@ -74,6 +75,7 @@ const RUST_QUERY: &str = include_str!("../queries/tree-sitter-rust-defs.scm"); const ZIG_QUERY: &str = include_str!("../queries/tree-sitter-zig-defs.scm"); const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm"); const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm"); +const SCALA_QUERY: &str = include_str!("../queries/tree-sitter-scala-defs.scm"); fn get_definitions_query(language: &str) -> Result { let ts_language = get_ts_language(language); @@ -92,10 +94,11 @@ fn get_definitions_query(language: &str) -> Result { "zig" => ZIG_QUERY, "typescript" => TYPESCRIPT_QUERY, "ruby" => RUBY_QUERY, + "scala" => SCALA_QUERY, _ => return Err(format!("Unsupported language: {language}")), }; let query = Query::new(&ts_language.into(), contents) - .unwrap_or_else(|_| panic!("Failed to parse query for {language}")); + .unwrap_or_else(|e| panic!("Failed to parse query for {language}: {e}")); Ok(query) } @@ -315,6 +318,12 @@ fn extract_definitions(language: &str, source: &str) -> Result, } } } + "scala" => node + .child_by_field_name("name") + .or_else(|| node.child_by_field_name("pattern")) + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(node_text) + .to_string(), _ => node .child_by_field_name("name") .map(|n| n.utf8_text(source.as_bytes()).unwrap()) @@ -363,6 +372,14 @@ fn extract_definitions(language: &str, source: &str) -> Result, zig_find_parent_variable_declaration_name(&node, source.as_bytes()) .unwrap_or_default(); } + if language == "scala" { + if let Some(enum_node) = find_ancestor_by_type(&node, "enum_definition") { + if let Some(name_node) = enum_node.child_by_field_name("name") { + enum_name = + name_node.utf8_text(source.as_bytes()).unwrap().to_string(); + } + } + } if !enum_name.is_empty() && language == "go" && !is_first_letter_uppercase(&enum_name) @@ -1372,6 +1389,45 @@ mod tests { assert_eq!(stringified, expected); } + #[test] + fn test_scala() { + let source = r#" + object Main { + def main(args: Array[String]): Unit = { + println("Hello, World!") + } + } + + class TestClass { + val testVal: String = "test" + var testVar = 42 + + def testMethod(a: Int, b: Int): Int = { + a + b + } + } + + // braceless syntax is also supported + trait TestTrait: + def abstractMethod(x: Int): Int + def concreteMethod(y: Int): Int = y * 2 + + case class TestCaseClass(name: String, age: Int) + + enum TestEnum { + case First, Second, Third + } + + val foo: TestClass = ??? + "#; + + let definitions = extract_definitions("scala", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "var foo:TestClass;class Main{func main(args: Array[String]) -> Unit;};class TestCaseClass{};class TestClass{func testMethod(a: Int, b: Int) -> Int;var testVal:String;var testVar;};class TestTrait{func abstractMethod(x: Int) -> Int;func concreteMethod(y: Int) -> Int;};enum TestEnum{First;Second;Third;};"; + assert_eq!(stringified, expected); + } + #[test] fn test_unsupported_language() { let source = "print('Hello, world!')";