diff --git a/crates/avante-repo-map/queries/tree-sitter-ruby-defs.scm b/crates/avante-repo-map/queries/tree-sitter-ruby-defs.scm index 8f0117a..57a7c31 100644 --- a/crates/avante-repo-map/queries/tree-sitter-ruby-defs.scm +++ b/crates/avante-repo-map/queries/tree-sitter-ruby-defs.scm @@ -1,16 +1,26 @@ ;; Capture top-level methods, class definitions, and methods within classes -(program - (class - (body_statement - (call) @class_call - (assignment) @class_assignment - (method) @method - ) - ) @class -) + +(class + (body_statement + (call)? @class_call + (assignment)? @class_assignment + (method)? @method + ) +) @class + (program (method) @function ) (program (assignment) @assignment ) + +(module) @module + +(module + (body_statement + (call)? @class_call + (assignment)? @class_assignment + (method)? @method + ) +) diff --git a/crates/avante-repo-map/src/lib.rs b/crates/avante-repo-map/src/lib.rs index fd0faeb..cef8e2d 100644 --- a/crates/avante-repo-map/src/lib.rs +++ b/crates/avante-repo-map/src/lib.rs @@ -43,6 +43,7 @@ pub struct Variable { pub enum Definition { Func(Func), Class(Class), + Module(Class), Enum(Enum), Variable(Variable), Union(Union), @@ -158,6 +159,24 @@ fn find_descendant_by_type<'a>(node: &'a Node, child_type: &str) -> Option(node: &'a Node, source: &'a [u8]) -> bool { + let mut prev_sibling = node.prev_sibling(); + while let Some(prev_sibling_node) = prev_sibling { + if prev_sibling_node.kind() == "identifier" { + let text = prev_sibling_node.utf8_text(source).unwrap_or_default(); + if text == "private" { + return true; + } else if text == "public" || text == "protected" { + return false; + } + } else if prev_sibling_node.kind() == "class" || prev_sibling_node.kind() == "module" { + return false; + } + prev_sibling = prev_sibling_node.prev_sibling(); + } + false +} + fn find_child_by_type<'a>(node: &'a Node, child_type: &str) -> Option> { node.children(&mut node.walk()) .find(|child| child.kind() == child_type) @@ -234,6 +253,30 @@ fn ex_find_parent_module_declaration_name<'a>(node: &'a Node, source: &'a [u8]) None } +fn ruby_find_parent_module_declaration_name<'a>( + node: &'a Node, + source: &'a [u8], +) -> Option { + let mut path_parts = Vec::new(); + let mut current = Some(*node); + + while let Some(current_node) = current { + if current_node.kind() == "module" || current_node.kind() == "class" { + if let Some(name_node) = current_node.child_by_field_name("name") { + path_parts.push(get_node_text(&name_node, source)); + } + } + current = current_node.parent(); + } + + if path_parts.is_empty() { + None + } else { + path_parts.reverse(); + Some(path_parts.join("::")) + } +} + fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String { node.utf8_text(source).unwrap_or_default().to_string() } @@ -301,6 +344,18 @@ fn extract_definitions(language: &str, source: &str) -> Result, }); }; + let ensure_module_def = |name: &str, class_def_map: &mut BTreeMap>| { + class_def_map.entry(name.to_string()).or_insert_with(|| { + RefCell::new(Class { + name: name.to_string(), + type_name: "module".to_string(), + methods: vec![], + properties: vec![], + visibility_modifier: None, + }) + }); + }; + let ensure_enum_def = |name: &str, enum_def_map: &mut BTreeMap>| { enum_def_map.entry(name.to_string()).or_insert_with(|| { RefCell::new(Enum { @@ -395,6 +450,19 @@ fn extract_definitions(language: &str, source: &str) -> Result, .unwrap_or(node_text) .to_string() } + "ruby" => { + let name = node + .child_by_field_name("name") + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(node_text) + .to_string(); + if *capture_name == "class" || *capture_name == "module" { + ruby_find_parent_module_declaration_name(&node, source.as_bytes()) + .unwrap_or(name) + } else { + name + } + } _ => node .child_by_field_name("name") .map(|n| n.utf8_text(source.as_bytes()).unwrap()) @@ -423,6 +491,11 @@ fn extract_definitions(language: &str, source: &str) -> Result, }; } } + "module" => { + if !name.is_empty() { + ensure_module_def(&name, &mut class_def_map); + } + } "enum_item" => { let visibility_modifier_node = find_descendant_by_type(&node, "visibility_modifier"); @@ -623,6 +696,9 @@ fn extract_definitions(language: &str, source: &str) -> Result, .and_then(|n| n.utf8_text(source.as_bytes()).ok()) .unwrap_or("") .to_string() + } else if language == "ruby" { + ruby_find_parent_module_declaration_name(&node, source.as_bytes()) + .unwrap_or_default() } else if let Some(impl_item) = impl_item_node { let impl_type_node = impl_item.child_by_field_name("type"); impl_type_node @@ -649,9 +725,17 @@ fn extract_definitions(language: &str, source: &str) -> Result, let accessibility_modifier_node = find_descendant_by_type(&node, "accessibility_modifier"); - let accessibility_modifier = accessibility_modifier_node - .map(|n| n.utf8_text(source.as_bytes()).unwrap()) - .unwrap_or(""); + let accessibility_modifier = if language == "ruby" { + if ruby_method_is_private(&node, source.as_bytes()) { + "private" + } else { + "" + } + } else { + accessibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or("") + }; let func = Func { name: name.to_string(), @@ -679,12 +763,17 @@ fn extract_definitions(language: &str, source: &str) -> Result, .map(|n| n.utf8_text(source.as_bytes()).unwrap()) .unwrap_or(""); let value_type = get_node_type(&node, source.as_bytes()); - let class_name = get_closest_ancestor_name(&node, source); - if !class_name.is_empty() - && language == "go" - && !is_first_letter_uppercase(&class_name) - { - continue; + let mut class_name = get_closest_ancestor_name(&node, source); + if !class_name.is_empty() { + if language == "ruby" { + if let Some(namespaced_name) = + ruby_find_parent_module_declaration_name(&node, source.as_bytes()) + { + class_name = namespaced_name; + } + } else if language == "go" && !is_first_letter_uppercase(&class_name) { + continue; + } } if class_name.is_empty() { continue; @@ -1057,6 +1146,7 @@ fn stringify_definitions(definitions: &Vec) -> String { for definition in definitions { match definition { Definition::Class(class) => res = format!("{res}{}", stringify_class(class)), + Definition::Module(module) => res = format!("{res}{}", stringify_class(module)), Definition::Enum(enum_def) => res = format!("{res}{}", stringify_enum(enum_def)), Definition::Union(union_def) => res = format!("{res}{}", stringify_union(union_def)), Definition::Func(func) => res = format!("{res}{}", stringify_function(func)), @@ -1434,7 +1524,62 @@ mod tests { let stringified = stringify_definitions(&definitions); println!("{stringified}"); // FIXME: - let expected = "var test_var;func test_func(a, b) -> void;"; + let expected = "var test_var;func test_func(a, b) -> void;class InnerClassInFunc{func initialize(a, b) -> void;func test_method(a, b) -> void;};class TestClass{func initialize(a, b) -> void;func test_method(a, b) -> void;};"; + assert_eq!(stringified, expected); + } + + #[test] + fn test_ruby2() { + let source = r#" + # frozen_string_literal: true + + require('jwt') + + top_level_var = 1 + + def top_level_func + inner_var_in_func = 2 + end + + module A + module B + @module_var = :foo + + def module_method + @module_var + end + + class C < Base + TEST_CONST = 1 + @class_var = :bar + attr_accessor :a, :b + + def initialize(a, b) + @a = a + @b = b + super + end + + def bar + inner_var_in_method = 1 + true + end + + private + + def baz(request, params) + auth_header = request.headers['Authorization'] + parts = auth_header.try(:split, /\s+/) + JWT.decode(parts.last) + end + end + end + end + "#; + let definitions = extract_definitions("ruby", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "var top_level_var;func top_level_func() -> void;module A{};module A::B{func module_method() -> void;var @module_var;};class A::B::C{func initialize(a, b) -> void;func bar() -> void;private func baz(request, params) -> void;var TEST_CONST;var @class_var;};"; assert_eq!(stringified, expected); }