From 9de95f9e021c0f1c5f200f8fcb46030f38deacc8 Mon Sep 17 00:00:00 2001 From: Maddison Hellstrom Date: Mon, 21 Oct 2024 02:20:18 -0700 Subject: [PATCH] feat(repo-map): C++ improvements (#734) --- .../queries/tree-sitter-cpp-defs.scm | 40 +++- crates/avante-repo-map/src/lib.rs | 223 ++++++++++++++++-- 2 files changed, 227 insertions(+), 36 deletions(-) diff --git a/crates/avante-repo-map/queries/tree-sitter-cpp-defs.scm b/crates/avante-repo-map/queries/tree-sitter-cpp-defs.scm index 0e1a485..4a73740 100644 --- a/crates/avante-repo-map/queries/tree-sitter-cpp-defs.scm +++ b/crates/avante-repo-map/queries/tree-sitter-cpp-defs.scm @@ -1,11 +1,31 @@ -;; Capture extern functions, variables, public classes, and methods -(function_definition - (storage_class_specifier) @extern -) @function +;; Capture functions, variables, nammespaces, classes, methods, and enums +(namespace_definition) @namespace +(function_definition) @function +(class_specifier) @class (class_specifier - (public) @class - (function_definition) @method -) @class -(declaration - (storage_class_specifier) @extern -) @variable + body: (field_declaration_list + (declaration + declarator: (function_declarator))? @method + (field_declaration + declarator: (function_declarator))? @method + (function_definition)? @method + (function_declarator)? @method + (field_declaration + declarator: (field_identifier))? @class_variable + ) +) +(struct_specifier) @struct +(struct_specifier + body: (field_declaration_list + (declaration + declarator: (function_declarator))? @method + (field_declaration + declarator: (function_declarator))? @method + (function_definition)? @method + (function_declarator)? @method + (field_declaration + declarator: (field_identifier))? @class_variable + ) +) +((declaration type: (_))) @variable +(enumerator_list ((enumerator) @enum_item)) diff --git a/crates/avante-repo-map/src/lib.rs b/crates/avante-repo-map/src/lib.rs index 5ad8973..3027d61 100644 --- a/crates/avante-repo-map/src/lib.rs +++ b/crates/avante-repo-map/src/lib.rs @@ -1,6 +1,6 @@ use mlua::prelude::*; use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::BTreeMap; use tree_sitter::{Node, Parser, Query, QueryCursor}; use tree_sitter_language::LanguageFn; @@ -45,6 +45,7 @@ pub enum Definition { Enum(Enum), Variable(Variable), Union(Union), + // TODO: Namespace support } fn get_ts_language(language: &str) -> Option { @@ -227,11 +228,11 @@ fn extract_definitions(language: &str, source: &str) -> Result, let mut query_cursor = QueryCursor::new(); let captures = query_cursor.captures(&query, root_node, source.as_bytes()); - let mut class_def_map: HashMap> = HashMap::new(); - let mut enum_def_map: HashMap> = HashMap::new(); - let mut union_def_map: HashMap> = HashMap::new(); + let mut class_def_map: BTreeMap> = BTreeMap::new(); + let mut enum_def_map: BTreeMap> = BTreeMap::new(); + let mut union_def_map: BTreeMap> = BTreeMap::new(); - let ensure_class_def = |name: &str, class_def_map: &mut HashMap>| { + let ensure_class_def = |name: &str, class_def_map: &mut BTreeMap>| { class_def_map.entry(name.to_string()).or_insert_with(|| { RefCell::new(Class { name: name.to_string(), @@ -242,7 +243,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, }); }; - let ensure_enum_def = |name: &str, enum_def_map: &mut HashMap>| { + let ensure_enum_def = |name: &str, enum_def_map: &mut BTreeMap>| { enum_def_map.entry(name.to_string()).or_insert_with(|| { RefCell::new(Enum { name: name.to_string(), @@ -251,7 +252,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, }); }; - let ensure_union_def = |name: &str, union_def_map: &mut HashMap>| { + let ensure_union_def = |name: &str, union_def_map: &mut BTreeMap>| { union_def_map.entry(name.to_string()).or_insert_with(|| { RefCell::new(Union { name: name.to_string(), @@ -260,30 +261,80 @@ fn extract_definitions(language: &str, source: &str) -> Result, }); }; + // Sometimes, multiple queries capture the same node with the same capture name. + // We need to ensure that we only add the node to the definition map once. + let mut captured_nodes: BTreeMap> = BTreeMap::new(); + for (m, _) in captures { for capture in m.captures { let capture_name = &query.capture_names()[capture.index as usize]; let node = capture.node; let node_text = node.utf8_text(source.as_bytes()).unwrap(); - let name_node = node.child_by_field_name("name"); - let name = name_node - .map(|n| n.utf8_text(source.as_bytes()).unwrap()) - .unwrap_or(node_text); + let node_id = node.id(); + if captured_nodes + .get(*capture_name) + .map_or(false, |v| v.contains(&node_id)) + { + continue; + } + captured_nodes + .entry(String::from(*capture_name)) + .or_default() + .push(node_id); + + let name = match language { + "cpp" => { + if *capture_name == "class" { + node.child_by_field_name("name") + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(node_text) + .to_string() + } else { + let ident = find_descendant_by_type(&node, "field_identifier") + .or_else(|| find_descendant_by_type(&node, "operator_name")) + .or_else(|| find_descendant_by_type(&node, "identifier")) + .map(|n| n.utf8_text(source.as_bytes()).unwrap()); + if let Some(ident) = ident { + let scope = node + .child_by_field_name("declarator") + .and_then(|n| n.child_by_field_name("declarator")) + .and_then(|n| n.child_by_field_name("scope")); + + if let Some(scope_node) = scope { + format!( + "{}::{}", + scope_node.utf8_text(source.as_bytes()).unwrap(), + ident + ) + } else { + ident.to_string() + } + } else { + node_text.to_string() + } + } + } + _ => node + .child_by_field_name("name") + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(node_text) + .to_string(), + }; match *capture_name { "class" => { if !name.is_empty() { - if language == "go" && !is_first_letter_uppercase(name) { + if language == "go" && !is_first_letter_uppercase(&name) { continue; } - ensure_class_def(name, &mut class_def_map); + ensure_class_def(&name, &mut class_def_map); let visibility_modifier_node = find_child_by_type(&node, "visibility_modifier"); let visibility_modifier = visibility_modifier_node .map(|n| n.utf8_text(source.as_bytes()).unwrap()) .unwrap_or(""); - let class_def = class_def_map.get_mut(name).unwrap(); + let class_def = class_def_map.get_mut(&name).unwrap(); class_def.borrow_mut().visibility_modifier = if visibility_modifier.is_empty() { None @@ -353,6 +404,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, union_def.borrow_mut().items.push(variable); } "method" => { + // TODO: C++: Skip private/protected class/struct methods let visibility_modifier_node = find_descendant_by_type(&node, "visibility_modifier"); let visibility_modifier = visibility_modifier_node @@ -367,11 +419,18 @@ fn extract_definitions(language: &str, source: &str) -> Result, { continue; } - - if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { + if language == "cpp" + && find_descendant_by_type(&node, "destructor_name").is_some() + { continue; } - let mut params_node = node.child_by_field_name("parameters"); + + if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) { + continue; + } + let mut params_node = node + .child_by_field_name("parameters") + .or_else(|| find_descendant_by_type(&node, "parameter_list")); let function_node = find_ancestor_by_type(&node, "function_declaration"); if language == "zig" { @@ -383,7 +442,23 @@ fn extract_definitions(language: &str, source: &str) -> Result, let params = params_node .map(|n| n.utf8_text(source.as_bytes()).unwrap()) .unwrap_or("()"); - let mut return_type_node = node.child_by_field_name("return_type"); + let mut return_type_node = match language { + "cpp" => node.child_by_field_name("type"), + _ => node.child_by_field_name("return_type"), + }; + if language == "cpp" { + let class_specifier_node = find_ancestor_by_type(&node, "class_specifier"); + let type_identifier_node = + class_specifier_node.and_then(|n| n.child_by_field_name("name")); + + if let Some(type_identifier_node) = type_identifier_node { + let type_identifier_text = + type_identifier_node.utf8_text(source.as_bytes()).unwrap(); + if name == type_identifier_text { + return_type_node = Some(type_identifier_node); + } + } + } if return_type_node.is_none() { return_type_node = node.child_by_field_name("result"); } @@ -404,6 +479,13 @@ fn extract_definitions(language: &str, source: &str) -> Result, let class_name = if language == "zig" { zig_find_parent_variable_declaration_name(&node, source.as_bytes()) .unwrap_or_default() + } else if language == "cpp" { + find_ancestor_by_type(&node, "class_specifier") + .or_else(|| find_ancestor_by_type(&node, "struct_specifier")) + .and_then(|n| n.child_by_field_name("name")) + .and_then(|n| n.utf8_text(source.as_bytes()).ok()) + .unwrap_or("") + .to_string() } else if let Some(impl_item) = impl_item_node { let impl_type_node = impl_item.child_by_field_name("type"); impl_type_node @@ -479,6 +561,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, class_def.borrow_mut().properties.push(variable); } "class_variable" => { + // TODO: C++: Skip private/protected class/struct variables let visibility_modifier_node = find_descendant_by_type(&node, "visibility_modifier"); let visibility_modifier = visibility_modifier_node @@ -498,6 +581,14 @@ fn extract_definitions(language: &str, source: &str) -> Result, } let mut class_name = get_closest_ancestor_name(&node, source); + if language == "cpp" { + class_name = find_ancestor_by_type(&node, "class_specifier") + .or_else(|| find_ancestor_by_type(&node, "struct_specifier")) + .and_then(|n| n.child_by_field_name("name")) + .and_then(|n| n.utf8_text(source.as_bytes()).ok()) + .unwrap_or("") + .to_string(); + } if language == "zig" { class_name = zig_find_parent_variable_declaration_name(&node, source.as_bytes()) @@ -512,7 +603,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, if class_name.is_empty() { continue; } - if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { + if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) { continue; } ensure_class_def(&class_name, &mut class_def_map); @@ -542,27 +633,40 @@ fn extract_definitions(language: &str, source: &str) -> Result, } } - if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { + if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) { continue; } let impl_item_node = find_ancestor_by_type(&node, "impl_item"); if impl_item_node.is_some() { continue; } + let class_specifier_node = find_ancestor_by_type(&node, "class_specifier"); + if class_specifier_node.is_some() { + continue; + } + let struct_specifier_node = find_ancestor_by_type(&node, "struct_specifier"); + if struct_specifier_node.is_some() { + continue; + } let function_node = find_ancestor_by_type(&node, "function_declaration") .or_else(|| find_ancestor_by_type(&node, "function_definition")); if function_node.is_some() { continue; } - let params_node = node.child_by_field_name("parameters"); + let params_node = node + .child_by_field_name("parameters") + .or_else(|| find_descendant_by_type(&node, "parameter_list")); let params = params_node .map(|n| n.utf8_text(source.as_bytes()).unwrap()) .unwrap_or("()"); - let mut return_type_node = node.child_by_field_name("return_type"); - if return_type_node.is_none() { - return_type_node = node.child_by_field_name("result"); - } + let mut return_type = "void".to_string(); + let return_type_node = match language { + "cpp" => node.child_by_field_name("type"), + _ => node + .child_by_field_name("return_type") + .or_else(|| node.child_by_field_name("result")), + }; if return_type_node.is_some() { return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes()); if return_type.is_empty() { @@ -689,7 +793,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, continue; }; } - if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { + if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) { continue; } let variable = Variable { @@ -836,6 +940,7 @@ fn avante_repo_map(lua: &Lua) -> LuaResult { mod tests { use super::*; + #[test] fn test_rust() { let source = r#" // This is a test comment @@ -1201,6 +1306,72 @@ mod tests { assert_eq!(stringified, expected); } + #[test] + fn test_cpp() { + let source = r#" + // This is a test comment + #include + + namespace { + constexpr int TEST_CONSTEXPR = 1; + const int TEST_CONST = 1; + }; // namespace + + int test_var = 2; + + int TestFunc(bool b) { return b ? 42 : -1; } + + template class TestClass { + public: + TestClass(); + TestClass(T a, T b); + ~TestClass(); + bool operator==(const TestClass &other); + T testMethod(T x, T y) { return x + y; } + T c; + + private: + void privateMethod(); + T a = 0; + T b; + }; + + struct TestStruct { + public: + TestStruct(int a, int b); + ~TestStruct(); + bool operator==(const TestStruct &other); + int testMethod(int x, int y) { return x + y; } + static int c; + + private: + int a = 0; + int b; + }; + + bool TestStruct::operator==(const TestStruct &other) { return true; } + + int TestStruct::c = 0; + + int testFunction(int a, int b) { return a + b; } + + namespace TestNamespace { + class InnerClass { + public: + bool innerMethod(int a) const; + }; + bool InnerClass::innerMethod(int a) const { return doSomething(a * 2); } + } // namespace TestNamespace + + enum TestEnum { ENUM_VALUE_1, ENUM_VALUE_2 }; + "#; + let definitions = extract_definitions("cpp", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{}", stringified); + let expected = "var TEST_CONSTEXPR:int;var TEST_CONST:int;var test_var:int;func TestFunc(bool b) -> int;func TestStruct::operator==(const TestStruct &other) -> bool;var TestStruct::c:int;func testFunction(int a, int b) -> int;func InnerClass::innerMethod(int a) -> bool;class InnerClass{func innerMethod(int a) -> bool;};class TestClass{func TestClass() -> TestClass;func operator==(const TestClass &other) -> bool;func testMethod(T x, T y) -> T;func privateMethod() -> void;func TestClass(T a, T b) -> TestClass;var c:T;var a:T;var b:T;};class TestStruct{func TestStruct(int a, int b) -> void;func operator==(const TestStruct &other) -> bool;func testMethod(int x, int y) -> int;var c:int;var a:int;var b:int;};enum TestEnum{ENUM_VALUE_1;ENUM_VALUE_2;};"; + assert_eq!(stringified, expected); + } + #[test] fn test_unsupported_language() { let source = "print('Hello, world!')";