ragfs_chunker/
code.rs

1//! Code-aware chunking strategy.
2//!
3//! Chunks code at function/class/method boundaries using pattern matching.
4//! Supports Rust, Python, JavaScript, TypeScript, Go, Java, and C/C++.
5
6use async_trait::async_trait;
7use ragfs_core::{
8    ChunkConfig, ChunkError, ChunkOutput, ChunkOutputMetadata, Chunker, ContentType,
9    ExtractedContent,
10};
11use tracing::debug;
12
/// Code-aware chunker that splits source text at function/class boundaries.
///
/// The type is a stateless unit struct, so copying it is free.
#[derive(Debug, Clone, Copy)]
pub struct CodeChunker;
15
16impl CodeChunker {
17    /// Create a new code chunker.
18    #[must_use]
19    pub fn new() -> Self {
20        Self
21    }
22
23    /// Detect language from metadata or infer from content.
24    fn detect_language(content: &ExtractedContent) -> Option<Language> {
25        content
26            .metadata
27            .language
28            .as_ref()
29            .and_then(|l| Language::from_extension(l))
30    }
31}
32
33impl Default for CodeChunker {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39#[async_trait]
40impl Chunker for CodeChunker {
41    fn name(&self) -> &'static str {
42        "code"
43    }
44
45    fn content_types(&self) -> &[&str] {
46        &["code"]
47    }
48
49    fn can_chunk(&self, content_type: &ContentType) -> bool {
50        matches!(content_type, ContentType::Code { .. })
51    }
52
53    async fn chunk(
54        &self,
55        content: &ExtractedContent,
56        config: &ChunkConfig,
57    ) -> Result<Vec<ChunkOutput>, ChunkError> {
58        let text = &content.text;
59        if text.is_empty() {
60            return Ok(vec![]);
61        }
62
63        let language = Self::detect_language(content);
64        debug!("Code chunking with language: {:?}", language);
65
66        let lines: Vec<&str> = text.lines().collect();
67        let boundaries = find_code_boundaries(&lines, language.as_ref());
68
69        // If no boundaries found, fall back to line-based chunking
70        if boundaries.is_empty() {
71            return chunk_by_lines(text, &lines, config, &content.metadata.language);
72        }
73
74        create_chunks_from_boundaries(
75            text,
76            &lines,
77            &boundaries,
78            config,
79            &content.metadata.language,
80        )
81    }
82}
83
/// Supported programming languages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Language {
    Rust,
    Python,
    JavaScript,
    TypeScript,
    Go,
    Java,
    C,
    Cpp,
}

impl Language {
    /// Map a file extension or a common language name to a [`Language`].
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace and a
    /// leading dot, so `"rs"`, `".RS"` and `"Rust"` all resolve to
    /// [`Language::Rust`]. Language names are accepted in addition to
    /// extensions because the value passed in comes from a free-form
    /// `language` metadata field. Returns `None` for anything unrecognised.
    fn from_extension(ext: &str) -> Option<Self> {
        let key = ext.trim().trim_start_matches('.').to_lowercase();
        match key.as_str() {
            "rs" | "rust" => Some(Language::Rust),
            "py" | "pyi" | "pyw" | "python" => Some(Language::Python),
            "js" | "jsx" | "mjs" | "cjs" | "javascript" => Some(Language::JavaScript),
            "ts" | "tsx" | "mts" | "cts" | "typescript" => Some(Language::TypeScript),
            "go" | "golang" => Some(Language::Go),
            "java" => Some(Language::Java),
            "c" | "h" => Some(Language::C),
            "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" | "c++" | "h++" => Some(Language::Cpp),
            _ => None,
        }
    }
}
112
113/// A code boundary (start of a function/class/method).
114#[derive(Debug)]
115struct CodeBoundary {
116    line: usize,
117    kind: BoundaryKind,
118    name: Option<String>,
119}
120
/// Category of definition that opens a chunk.
#[derive(Debug, Clone, Copy)]
enum BoundaryKind {
    Function,
    Method,
    Class,
    Struct,
    Enum,
    Impl,
    Module,
}

impl BoundaryKind {
    /// Lower-case label for this kind, used as the chunk's symbol type.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Function => "function",
            Self::Method => "method",
            Self::Class => "class",
            Self::Struct => "struct",
            Self::Enum => "enum",
            Self::Impl => "impl",
            Self::Module => "module",
        }
    }
}
145
146/// Find code boundaries (functions, classes, etc.) in source code.
147fn find_code_boundaries(lines: &[&str], language: Option<&Language>) -> Vec<CodeBoundary> {
148    let mut boundaries = Vec::new();
149
150    for (i, line) in lines.iter().enumerate() {
151        let trimmed = line.trim();
152
153        // Skip empty lines and comments
154        if trimmed.is_empty() || trimmed.starts_with("//") || trimmed.starts_with('#') {
155            continue;
156        }
157
158        if let Some(boundary) = detect_boundary(trimmed, language) {
159            boundaries.push(CodeBoundary {
160                line: i,
161                kind: boundary.0,
162                name: boundary.1,
163            });
164        }
165    }
166
167    boundaries
168}
169
170/// Detect if a line starts a code boundary.
171fn detect_boundary(
172    line: &str,
173    language: Option<&Language>,
174) -> Option<(BoundaryKind, Option<String>)> {
175    // Language-specific patterns
176    match language {
177        Some(Language::Rust) => detect_rust_boundary(line),
178        Some(Language::Python) => detect_python_boundary(line),
179        Some(Language::JavaScript | Language::TypeScript) => detect_js_boundary(line),
180        Some(Language::Go) => detect_go_boundary(line),
181        Some(Language::Java) => detect_java_boundary(line),
182        Some(Language::C | Language::Cpp) => detect_c_boundary(line),
183        None => {
184            // Try all patterns
185            detect_rust_boundary(line)
186                .or_else(|| detect_python_boundary(line))
187                .or_else(|| detect_js_boundary(line))
188                .or_else(|| detect_go_boundary(line))
189                .or_else(|| detect_java_boundary(line))
190                .or_else(|| detect_c_boundary(line))
191        }
192    }
193}
194
195/// Detect Rust code boundaries.
196fn detect_rust_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
197    // fn name(
198    if line.starts_with("pub fn ")
199        || line.starts_with("fn ")
200        || line.starts_with("pub async fn ")
201        || line.starts_with("async fn ")
202        || line.starts_with("pub(crate) fn ")
203        || line.starts_with("pub(super) fn ")
204    {
205        let name = extract_rust_fn_name(line);
206        return Some((BoundaryKind::Function, name));
207    }
208
209    // impl block
210    if line.starts_with("impl ") || line.starts_with("impl<") {
211        let name = extract_after_keyword(line, "impl");
212        return Some((BoundaryKind::Impl, name));
213    }
214
215    // struct
216    if line.starts_with("pub struct ")
217        || line.starts_with("struct ")
218        || line.starts_with("pub(crate) struct ")
219    {
220        let name = extract_after_keyword(line, "struct");
221        return Some((BoundaryKind::Struct, name));
222    }
223
224    // enum
225    if line.starts_with("pub enum ")
226        || line.starts_with("enum ")
227        || line.starts_with("pub(crate) enum ")
228    {
229        let name = extract_after_keyword(line, "enum");
230        return Some((BoundaryKind::Enum, name));
231    }
232
233    // mod
234    if line.starts_with("pub mod ") || line.starts_with("mod ") {
235        let name = extract_after_keyword(line, "mod");
236        return Some((BoundaryKind::Module, name));
237    }
238
239    None
240}
241
242/// Detect Python code boundaries.
243fn detect_python_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
244    // def name(
245    if line.starts_with("def ") || line.starts_with("async def ") {
246        let name = extract_python_fn_name(line);
247        return Some((BoundaryKind::Function, name));
248    }
249
250    // class Name
251    if line.starts_with("class ") {
252        let name = extract_after_keyword(line, "class");
253        return Some((BoundaryKind::Class, name));
254    }
255
256    None
257}
258
259/// Detect JavaScript/TypeScript code boundaries.
260fn detect_js_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
261    // function name(
262    if line.starts_with("function ")
263        || line.starts_with("async function ")
264        || line.starts_with("export function ")
265        || line.starts_with("export async function ")
266    {
267        let name = extract_js_fn_name(line);
268        return Some((BoundaryKind::Function, name));
269    }
270
271    // const name = (
272    if (line.starts_with("const ") || line.starts_with("export const "))
273        && (line.contains(" = (") || line.contains(" = async ("))
274    {
275        let name = extract_const_fn_name(line);
276        return Some((BoundaryKind::Function, name));
277    }
278
279    // class Name
280    if line.starts_with("class ")
281        || line.starts_with("export class ")
282        || line.starts_with("export default class ")
283    {
284        let name = extract_after_keyword(line, "class");
285        return Some((BoundaryKind::Class, name));
286    }
287
288    // interface Name (TypeScript)
289    if line.starts_with("interface ") || line.starts_with("export interface ") {
290        let name = extract_after_keyword(line, "interface");
291        return Some((BoundaryKind::Struct, name));
292    }
293
294    // type Name (TypeScript)
295    if line.starts_with("type ") || line.starts_with("export type ") {
296        let name = extract_after_keyword(line, "type");
297        return Some((BoundaryKind::Struct, name));
298    }
299
300    None
301}
302
303/// Detect Go code boundaries.
304fn detect_go_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
305    // func name(
306    if line.starts_with("func ") {
307        // Check if it's a method (has receiver)
308        if line.contains(") ") && line.find('(') < line.find(')') {
309            let name = extract_go_method_name(line);
310            return Some((BoundaryKind::Method, name));
311        }
312        let name = extract_go_fn_name(line);
313        return Some((BoundaryKind::Function, name));
314    }
315
316    // type Name struct
317    if line.starts_with("type ") && line.contains(" struct") {
318        let name = extract_after_keyword(line, "type");
319        return Some((BoundaryKind::Struct, name));
320    }
321
322    // type Name interface
323    if line.starts_with("type ") && line.contains(" interface") {
324        let name = extract_after_keyword(line, "type");
325        return Some((BoundaryKind::Struct, name));
326    }
327
328    None
329}
330
331/// Detect Java code boundaries.
332fn detect_java_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
333    // public/private/protected class Name
334    if line.contains(" class ") || line.starts_with("class ") {
335        let name = extract_after_keyword(line, "class");
336        return Some((BoundaryKind::Class, name));
337    }
338
339    // interface
340    if line.contains(" interface ") || line.starts_with("interface ") {
341        let name = extract_after_keyword(line, "interface");
342        return Some((BoundaryKind::Struct, name));
343    }
344
345    // enum
346    if line.contains(" enum ") || line.starts_with("enum ") {
347        let name = extract_after_keyword(line, "enum");
348        return Some((BoundaryKind::Enum, name));
349    }
350
351    // Method detection (simplified - after class/interface)
352    if (line.contains("public ") || line.contains("private ") || line.contains("protected "))
353        && line.contains('(')
354        && !line.contains(" class ")
355        && !line.contains(" interface ")
356    {
357        let name = extract_java_method_name(line);
358        return Some((BoundaryKind::Method, name));
359    }
360
361    None
362}
363
364/// Detect C/C++ code boundaries.
365fn detect_c_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
366    // class Name (C++)
367    if line.starts_with("class ") || line.contains(" class ") {
368        let name = extract_after_keyword(line, "class");
369        return Some((BoundaryKind::Class, name));
370    }
371
372    // struct Name
373    if line.starts_with("struct ") || line.contains(" struct ") {
374        let name = extract_after_keyword(line, "struct");
375        return Some((BoundaryKind::Struct, name));
376    }
377
378    // enum
379    if line.starts_with("enum ") || line.contains(" enum ") {
380        let name = extract_after_keyword(line, "enum");
381        return Some((BoundaryKind::Enum, name));
382    }
383
384    // Function (simplified - type name( pattern)
385    if line.contains('(') && !line.starts_with('#') && !line.starts_with("//") {
386        // Very basic function detection
387        let name = extract_c_fn_name(line);
388        if name.is_some() {
389            return Some((BoundaryKind::Function, name));
390        }
391    }
392
393    None
394}
395
396// Helper functions to extract names
397
/// Extract the function name from a Rust `fn` signature line.
///
/// The name is whatever follows the first whole-word `fn` keyword, up to the
/// parameter list `(` or the generic list `<`. Locating the keyword rather
/// than stripping a fixed set of prefixes means any visibility or qualifier
/// in front of it (`pub(in path)`, `const`, `unsafe`, `extern "C"`, ...) is
/// skipped. If the line has no `fn` keyword, the whole line is treated as
/// the signature.
///
/// Returns `None` when no non-empty name is found.
fn extract_rust_fn_name(line: &str) -> Option<String> {
    let trimmed = line.trim_start();

    // First "fn " that starts a word (line start or preceded by whitespace).
    let signature = trimmed
        .match_indices("fn ")
        .find(|(idx, _)| *idx == 0 || trimmed[..*idx].ends_with(char::is_whitespace))
        .map_or(trimmed, |(idx, _)| &trimmed[idx + "fn ".len()..]);

    let name = signature
        .split(|c: char| c == '(' || c == '<')
        .next()
        .unwrap_or("")
        .trim();

    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
410
/// Extract the function name from a Python `def` / `async def` line.
///
/// Returns the text between the keyword(s) and the first `(`, trimmed, or
/// `None` when that text is empty.
fn extract_python_fn_name(line: &str) -> Option<String> {
    let signature = line.trim_start_matches("async ").trim_start_matches("def ");
    let head = match signature.find('(') {
        Some(paren) => &signature[..paren],
        None => signature,
    };
    let name = head.trim();
    if name.is_empty() {
        None
    } else {
        Some(name.to_owned())
    }
}
418
/// Extract the function name from a JavaScript/TypeScript `function` line.
///
/// Leading `export`, `async` and `function` keywords are removed. The name
/// ends at the parameter list `(` or at a TypeScript type-parameter list
/// `<`, so `function identity<T>(x: T)` yields `identity`.
///
/// Returns `None` when no non-empty name is found.
fn extract_js_fn_name(line: &str) -> Option<String> {
    let signature = line
        .trim_start_matches("export ")
        .trim_start_matches("async ")
        .trim_start_matches("function ");
    signature
        .split(|c: char| c == '(' || c == '<')
        .next()
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
}
429
/// Extract the binding name from a `const name = (...) => ...` line.
///
/// Leading `export` and `const` keywords are removed. The name ends at the
/// first ` =` or at a TypeScript type annotation `:`, so
/// `const handler: Handler = (e) => {}` yields `handler`.
///
/// Returns `None` when no non-empty name is found.
fn extract_const_fn_name(line: &str) -> Option<String> {
    let binding = line
        .trim_start_matches("export ")
        .trim_start_matches("const ");
    binding
        .split(" =")
        .next()
        .and_then(|target| target.split(':').next())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
}
439
/// Extract the function name from a Go `func` line without a receiver.
///
/// The name ends at the parameter list `(` or at a type-parameter list `[`,
/// so `func Map[T any](xs []T)` yields `Map`.
///
/// Returns `None` when no non-empty name is found.
fn extract_go_fn_name(line: &str) -> Option<String> {
    let signature = line.trim_start_matches("func ");
    signature
        .split(|c: char| c == '(' || c == '[')
        .next()
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
}
447
/// Extract the method name from a Go line of the form
/// `func (r *Receiver) MethodName(`.
///
/// The name is the text between the first `") "` and the following `(`.
/// Returns `None` when there is no `") "` or the name is empty.
fn extract_go_method_name(line: &str) -> Option<String> {
    let signature = line.trim_start_matches("func ");
    let receiver_end = signature.find(") ")?;
    let after_receiver = &signature[receiver_end + 2..];
    let name = match after_receiver.find('(') {
        Some(paren) => &after_receiver[..paren],
        None => after_receiver,
    }
    .trim();

    if name.is_empty() {
        None
    } else {
        Some(name.to_owned())
    }
}
461
/// Extract a Java method name: the last whitespace-separated word before the
/// first `(`.
///
/// Returns `None` when the line has no `(` or nothing precedes it.
fn extract_java_method_name(line: &str) -> Option<String> {
    let paren_idx = line.find('(')?;
    line[..paren_idx]
        .split_whitespace()
        .next_back()
        .map(str::to_owned)
}
475
/// Extract a C/C++ function name from a line that looks like a definition or
/// prototype (`type name(`).
///
/// This is a heuristic. The candidate name is the last word before the first
/// `(`, with leading `*` / `&` removed, and it must consist only of
/// alphanumerics and `_`. Lines that are calls or control flow rather than
/// declarations are rejected:
///
/// * the last word before `(` is a control-flow or operator keyword
///   (compared case-insensitively);
/// * an earlier word is a statement keyword such as `return` or `else`;
/// * there is an `=` before `(` (`int x = compute(1);`);
/// * there is a single word before `(` and the line ends with `;`
///   (a bare call statement such as `printf("hi");`).
///
/// Returns `None` when the line is rejected or no valid name is found.
fn extract_c_fn_name(line: &str) -> Option<String> {
    const CALL_SITE_KEYWORDS: [&str; 7] =
        ["if", "while", "for", "switch", "return", "sizeof", "typeof"];
    const STATEMENT_KEYWORDS: [&str; 8] = [
        "return", "else", "do", "case", "throw", "new", "delete", "goto",
    ];

    let paren_idx = line.find('(')?;
    let before_paren = line[..paren_idx].trim();

    // An assignment or initialiser before the parenthesis is a call expression.
    if before_paren.contains('=') {
        return None;
    }

    let mut words = before_paren.split_whitespace();
    let last_word = words.next_back()?;
    if CALL_SITE_KEYWORDS.contains(&last_word.to_lowercase().as_str()) {
        return None;
    }

    // `words` now holds everything before the candidate name.
    let mut leading_word_count = 0usize;
    for word in words {
        if STATEMENT_KEYWORDS.contains(&word) {
            return None;
        }
        leading_word_count += 1;
    }

    // `name(...);` with nothing in front of the name is a call statement.
    if leading_word_count == 0 && line.trim_end().ends_with(';') {
        return None;
    }

    // Skip pointer/ref symbols
    let name = last_word.trim_start_matches('*').trim_start_matches('&');
    if !name.is_empty() && name.chars().all(|c| c.is_alphanumeric() || c == '_') {
        Some(name.to_string())
    } else {
        None
    }
}
495
/// Extract the identifier that follows the first occurrence of `keyword` in
/// `line`.
///
/// A generic parameter list directly after the keyword (as in `impl<T> Foo`)
/// is skipped with bracket matching, so nested lists such as
/// `impl<T: Into<String>> Foo<T>` still yield `Foo`. The identifier ends at
/// whitespace or at any of `{ ( < : ; = , )`, which keeps terminators out of
/// the name for declarations like `struct Foo;` or `mod tests;`.
///
/// Returns `None` when `keyword` is absent or no identifier follows it.
fn extract_after_keyword(line: &str, keyword: &str) -> Option<String> {
    let idx = line.find(keyword)?;
    let after = line[idx + keyword.len()..].trim_start();

    // Handle generics like impl<T>
    let candidate = if after.starts_with('<') {
        skip_generic_params(after)
    } else {
        after
    };

    candidate
        .split(|c: char| {
            c.is_whitespace() || matches!(c, '{' | '(' | '<' | ':' | ';' | '=' | ',' | ')')
        })
        .next()
        .map(std::string::ToString::to_string)
        .filter(|s| !s.is_empty())
}

/// Return what follows the balanced `<...>` group at the start of `text`,
/// or an empty string when the group is never closed.
fn skip_generic_params(text: &str) -> &str {
    let mut depth = 0usize;
    let mut previous = '\0';
    for (idx, ch) in text.char_indices() {
        match ch {
            '<' => depth += 1,
            // The `>` of a `->` return arrow does not close a bracket.
            '>' if previous != '-' => {
                depth = depth.saturating_sub(1);
                if depth == 0 {
                    return text[idx + 1..].trim_start();
                }
            }
            _ => {}
        }
        previous = ch;
    }
    ""
}
516
517/// Create chunks from detected boundaries.
518fn create_chunks_from_boundaries(
519    text: &str,
520    lines: &[&str],
521    boundaries: &[CodeBoundary],
522    config: &ChunkConfig,
523    language: &Option<String>,
524) -> Result<Vec<ChunkOutput>, ChunkError> {
525    let mut chunks = Vec::new();
526    let chars_per_token = 4;
527    let max_chars = config.max_size * chars_per_token;
528
529    for (i, boundary) in boundaries.iter().enumerate() {
530        let start_line = boundary.line;
531        let end_line = if i + 1 < boundaries.len() {
532            boundaries[i + 1].line
533        } else {
534            lines.len()
535        };
536
537        // Skip if this section is too small
538        if end_line <= start_line {
539            continue;
540        }
541
542        // Get content for this boundary
543        let chunk_lines = &lines[start_line..end_line];
544        let content: String = chunk_lines.join("\n");
545
546        // If content is too large, split it
547        if content.len() > max_chars {
548            let sub_chunks = split_large_chunk(
549                &content,
550                start_line,
551                config,
552                language,
553                &boundary.kind,
554                &boundary.name,
555            )?;
556            chunks.extend(sub_chunks);
557        } else {
558            let (byte_start, byte_end) = calculate_byte_range(text, lines, start_line, end_line);
559
560            chunks.push(ChunkOutput {
561                content,
562                byte_range: byte_start..byte_end,
563                line_range: Some(start_line as u32..end_line as u32),
564                parent_index: None,
565                depth: 0,
566                metadata: ChunkOutputMetadata {
567                    symbol_type: Some(boundary.kind.as_str().to_string()),
568                    symbol_name: boundary.name.clone(),
569                    language: language.clone(),
570                },
571            });
572        }
573    }
574
575    // If no chunks created, fall back to line-based chunking
576    if chunks.is_empty() {
577        return chunk_by_lines(text, lines, config, language);
578    }
579
580    Ok(chunks)
581}
582
583/// Split a chunk that's too large.
584fn split_large_chunk(
585    content: &str,
586    base_line: usize,
587    config: &ChunkConfig,
588    language: &Option<String>,
589    kind: &BoundaryKind,
590    name: &Option<String>,
591) -> Result<Vec<ChunkOutput>, ChunkError> {
592    let mut chunks = Vec::new();
593    let chars_per_token = 4;
594    let target_chars = config.target_size * chars_per_token;
595    let overlap_chars = config.overlap * chars_per_token;
596    let step = target_chars.saturating_sub(overlap_chars).max(1);
597
598    let lines: Vec<&str> = content.lines().collect();
599    let mut start = 0;
600
601    while start < lines.len() {
602        let mut char_count = 0;
603        let mut end = start;
604
605        // Find end based on character count
606        while end < lines.len() && char_count < target_chars {
607            char_count += lines[end].len() + 1; // +1 for newline
608            end += 1;
609        }
610
611        let chunk_content: String = lines[start..end].join("\n");
612        let chunk_lines = end - start;
613
614        // Calculate byte range within content
615        let byte_start = lines[..start].iter().map(|l| l.len() + 1).sum::<usize>() as u64;
616        let byte_end = byte_start + chunk_content.len() as u64;
617
618        chunks.push(ChunkOutput {
619            content: chunk_content,
620            byte_range: byte_start..byte_end,
621            line_range: Some((base_line + start) as u32..(base_line + end) as u32),
622            parent_index: None,
623            depth: 0,
624            metadata: ChunkOutputMetadata {
625                symbol_type: Some(kind.as_str().to_string()),
626                symbol_name: name.clone(),
627                language: language.clone(),
628            },
629        });
630
631        // Move start forward
632        let line_step = (step / (char_count / chunk_lines.max(1))).max(1);
633        start += line_step;
634
635        if end >= lines.len() {
636            break;
637        }
638    }
639
640    Ok(chunks)
641}
642
643/// Chunk by lines when no code boundaries are found.
644fn chunk_by_lines(
645    text: &str,
646    lines: &[&str],
647    config: &ChunkConfig,
648    language: &Option<String>,
649) -> Result<Vec<ChunkOutput>, ChunkError> {
650    let mut chunks = Vec::new();
651    let chars_per_token = 4;
652    let target_chars = config.target_size * chars_per_token;
653    let overlap_lines = (config.overlap * chars_per_token) / 80; // Assuming ~80 chars per line
654
655    let mut start = 0;
656    while start < lines.len() {
657        let mut char_count = 0;
658        let mut end = start;
659
660        while end < lines.len() && char_count < target_chars {
661            char_count += lines[end].len() + 1;
662            end += 1;
663        }
664
665        let chunk_content: String = lines[start..end].join("\n");
666        let (byte_start, byte_end) = calculate_byte_range(text, lines, start, end);
667
668        chunks.push(ChunkOutput {
669            content: chunk_content,
670            byte_range: byte_start..byte_end,
671            line_range: Some(start as u32..end as u32),
672            parent_index: None,
673            depth: 0,
674            metadata: ChunkOutputMetadata {
675                language: language.clone(),
676                ..Default::default()
677            },
678        });
679
680        start = (end).saturating_sub(overlap_lines).max(start + 1);
681    }
682
683    Ok(chunks)
684}
685
/// Calculate the byte range in `text` covered by `lines[start_line..end_line]`.
///
/// The range starts at the first byte of `lines[start_line]` and ends just
/// before `lines[end_line]`, so it includes the line terminator of the last
/// covered line. Both ends are clamped to `text.len()`.
///
/// When the entries of `lines` are sub-slices of `text` (as produced by
/// `text.lines()`), offsets are taken from their actual position, which keeps
/// them correct for `\r\n` line endings as well as `\n`. For lines that do
/// not point into `text`, offsets fall back to assuming one terminator byte
/// per line. Indices past the end of `lines` do not panic.
fn calculate_byte_range(
    text: &str,
    lines: &[&str],
    start_line: usize,
    end_line: usize,
) -> (u64, u64) {
    let limit = text.len();
    let byte_start = line_boundary_offset(text, lines, start_line).min(limit);
    let byte_end = line_boundary_offset(text, lines, end_line).min(limit);
    (byte_start as u64, byte_end as u64)
}

/// Byte offset in `text` of the boundary just before `lines[index]`.
fn line_boundary_offset(text: &str, lines: &[&str], index: usize) -> usize {
    // Preferred: where the line itself starts.
    if let Some(offset) = lines.get(index).and_then(|line| subslice_offset(text, line)) {
        return offset;
    }

    // Past the last line (or the line is not part of `text`): use the end of
    // the previous line plus whatever terminator follows it.
    if let Some(previous) = index.checked_sub(1).and_then(|i| lines.get(i)) {
        if let Some(offset) = subslice_offset(text, previous) {
            let end = offset + previous.len();
            let rest = &text.as_bytes()[end..];
            let terminator = if rest.starts_with(b"\r\n") {
                2
            } else if rest.starts_with(b"\n") {
                1
            } else {
                0
            };
            return end + terminator;
        }
    }

    // Fallback for detached lines: assume a single terminator byte per line.
    lines.iter().take(index).map(|line| line.len() + 1).sum()
}

/// Offset of `part` inside `text` when `part` is a sub-slice of it.
fn subslice_offset(text: &str, part: &str) -> Option<usize> {
    let offset = (part.as_ptr() as usize).checked_sub(text.as_ptr() as usize)?;
    let end = offset.checked_add(part.len())?;
    // `get` rejects out-of-range and non-boundary indices; comparing the
    // content guards against an unrelated pointer that happens to be in range.
    if text.get(offset..end) == Some(part) {
        Some(offset)
    } else {
        None
    }
}
697
#[cfg(test)]
mod tests {
    use super::*;

    /// Signature shared by all per-language boundary detectors.
    type Detector = fn(&str) -> Option<(BoundaryKind, Option<String>)>;

    /// Assert that `detect` fires on every line in `hits` and on none in `misses`.
    fn check_detector(detect: Detector, hits: &[&str], misses: &[&str]) {
        for line in hits {
            assert!(detect(line).is_some(), "expected a boundary for {:?}", line);
        }
        for line in misses {
            assert!(detect(line).is_none(), "expected no boundary for {:?}", line);
        }
    }

    #[test]
    fn test_detect_rust_boundary() {
        check_detector(
            detect_rust_boundary,
            &[
                "fn main() {",
                "pub fn new() -> Self {",
                "pub async fn process() {",
                "impl Foo {",
                "pub struct Bar {",
                "enum Baz {",
            ],
            &["let x = 5;"],
        );
    }

    #[test]
    fn test_detect_python_boundary() {
        check_detector(
            detect_python_boundary,
            &["def hello():", "async def world():", "class MyClass:"],
            &["    x = 5"],
        );
    }

    #[test]
    fn test_detect_js_boundary() {
        check_detector(
            detect_js_boundary,
            &[
                "function foo() {",
                "async function bar() {",
                "export function baz() {",
                "const fn = () => {",
                "class Component {",
                "export interface Props {",
            ],
            &[],
        );
    }

    #[test]
    fn test_extract_rust_fn_name() {
        let cases = [
            ("fn main() {", "main"),
            ("pub fn new() -> Self {", "new"),
            ("pub async fn process<T>() {", "process"),
        ];
        for (line, expected) in cases.iter() {
            assert_eq!(extract_rust_fn_name(line), Some((*expected).to_string()));
        }
    }

    #[test]
    fn test_language_detection() {
        assert!(matches!(
            Language::from_extension("rs"),
            Some(Language::Rust)
        ));
        assert!(matches!(
            Language::from_extension("py"),
            Some(Language::Python)
        ));
        assert!(matches!(
            Language::from_extension("js"),
            Some(Language::JavaScript)
        ));
        assert!(Language::from_extension("unknown").is_none());
    }
}