1use async_trait::async_trait;
7use ragfs_core::{
8 ChunkConfig, ChunkError, ChunkOutput, ChunkOutputMetadata, Chunker, ContentType,
9 ExtractedContent,
10};
11use tracing::debug;
12
/// Chunker for source code that splits text at definition boundaries
/// (functions, methods, classes, structs, enums, impls, modules).
///
/// Stateless, so it is cheap to copy and share.
#[derive(Debug, Clone, Copy)]
pub struct CodeChunker;
15
16impl CodeChunker {
17 #[must_use]
19 pub fn new() -> Self {
20 Self
21 }
22
23 fn detect_language(content: &ExtractedContent) -> Option<Language> {
25 content
26 .metadata
27 .language
28 .as_ref()
29 .and_then(|l| Language::from_extension(l))
30 }
31}
32
33impl Default for CodeChunker {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39#[async_trait]
40impl Chunker for CodeChunker {
41 fn name(&self) -> &'static str {
42 "code"
43 }
44
45 fn content_types(&self) -> &[&str] {
46 &["code"]
47 }
48
49 fn can_chunk(&self, content_type: &ContentType) -> bool {
50 matches!(content_type, ContentType::Code { .. })
51 }
52
53 async fn chunk(
54 &self,
55 content: &ExtractedContent,
56 config: &ChunkConfig,
57 ) -> Result<Vec<ChunkOutput>, ChunkError> {
58 let text = &content.text;
59 if text.is_empty() {
60 return Ok(vec![]);
61 }
62
63 let language = Self::detect_language(content);
64 debug!("Code chunking with language: {:?}", language);
65
66 let lines: Vec<&str> = text.lines().collect();
67 let boundaries = find_code_boundaries(&lines, language.as_ref());
68
69 if boundaries.is_empty() {
71 return chunk_by_lines(text, &lines, config, &content.metadata.language);
72 }
73
74 create_chunks_from_boundaries(
75 text,
76 &lines,
77 &boundaries,
78 config,
79 &content.metadata.language,
80 )
81 }
82}
83
/// Programming languages that have dedicated boundary heuristics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Language {
    Rust,
    Python,
    JavaScript,
    TypeScript,
    Go,
    Java,
    C,
    Cpp,
}

impl Language {
    /// Resolves a language from a file extension or a language name.
    ///
    /// Matching is case-insensitive and tolerates surrounding whitespace and a
    /// leading `.`. For example, `"rs"`, `".RS"` and `"rust"` all map to
    /// [`Language::Rust`]. Returns `None` for anything unrecognised.
    fn from_extension(ext: &str) -> Option<Self> {
        let key = ext.trim().trim_start_matches('.').to_lowercase();
        let language = match key.as_str() {
            "rs" | "rust" => Language::Rust,
            "py" | "pyi" | "pyw" | "python" => Language::Python,
            "js" | "jsx" | "mjs" | "cjs" | "javascript" => Language::JavaScript,
            "ts" | "tsx" | "mts" | "cts" | "typescript" => Language::TypeScript,
            "go" | "golang" => Language::Go,
            "java" => Language::Java,
            "c" | "h" => Language::C,
            "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" | "c++" => Language::Cpp,
            _ => return None,
        };
        Some(language)
    }
}
112
/// A line that opens a definition, as found by `find_code_boundaries`.
#[derive(Debug)]
struct CodeBoundary {
    /// Zero-based index of the line in the source text.
    line: usize,
    /// What kind of definition starts here.
    kind: BoundaryKind,
    /// Best-effort symbol name, if one could be extracted.
    name: Option<String>,
}

/// Category of a detected definition.
///
/// Interface-like constructs from several languages are reported as `Struct`.
#[derive(Debug, Clone, Copy)]
enum BoundaryKind {
    Function,
    Method,
    Class,
    Struct,
    Enum,
    Impl,
    Module,
}

impl BoundaryKind {
    /// Lower-case label stored as the chunk's `symbol_type`.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Function => "function",
            Self::Method => "method",
            Self::Class => "class",
            Self::Struct => "struct",
            Self::Enum => "enum",
            Self::Impl => "impl",
            Self::Module => "module",
        }
    }
}
145
146fn find_code_boundaries(lines: &[&str], language: Option<&Language>) -> Vec<CodeBoundary> {
148 let mut boundaries = Vec::new();
149
150 for (i, line) in lines.iter().enumerate() {
151 let trimmed = line.trim();
152
153 if trimmed.is_empty() || trimmed.starts_with("//") || trimmed.starts_with('#') {
155 continue;
156 }
157
158 if let Some(boundary) = detect_boundary(trimmed, language) {
159 boundaries.push(CodeBoundary {
160 line: i,
161 kind: boundary.0,
162 name: boundary.1,
163 });
164 }
165 }
166
167 boundaries
168}
169
170fn detect_boundary(
172 line: &str,
173 language: Option<&Language>,
174) -> Option<(BoundaryKind, Option<String>)> {
175 match language {
177 Some(Language::Rust) => detect_rust_boundary(line),
178 Some(Language::Python) => detect_python_boundary(line),
179 Some(Language::JavaScript | Language::TypeScript) => detect_js_boundary(line),
180 Some(Language::Go) => detect_go_boundary(line),
181 Some(Language::Java) => detect_java_boundary(line),
182 Some(Language::C | Language::Cpp) => detect_c_boundary(line),
183 None => {
184 detect_rust_boundary(line)
186 .or_else(|| detect_python_boundary(line))
187 .or_else(|| detect_js_boundary(line))
188 .or_else(|| detect_go_boundary(line))
189 .or_else(|| detect_java_boundary(line))
190 .or_else(|| detect_c_boundary(line))
191 }
192 }
193}
194
195fn detect_rust_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
197 if line.starts_with("pub fn ")
199 || line.starts_with("fn ")
200 || line.starts_with("pub async fn ")
201 || line.starts_with("async fn ")
202 || line.starts_with("pub(crate) fn ")
203 || line.starts_with("pub(super) fn ")
204 {
205 let name = extract_rust_fn_name(line);
206 return Some((BoundaryKind::Function, name));
207 }
208
209 if line.starts_with("impl ") || line.starts_with("impl<") {
211 let name = extract_after_keyword(line, "impl");
212 return Some((BoundaryKind::Impl, name));
213 }
214
215 if line.starts_with("pub struct ")
217 || line.starts_with("struct ")
218 || line.starts_with("pub(crate) struct ")
219 {
220 let name = extract_after_keyword(line, "struct");
221 return Some((BoundaryKind::Struct, name));
222 }
223
224 if line.starts_with("pub enum ")
226 || line.starts_with("enum ")
227 || line.starts_with("pub(crate) enum ")
228 {
229 let name = extract_after_keyword(line, "enum");
230 return Some((BoundaryKind::Enum, name));
231 }
232
233 if line.starts_with("pub mod ") || line.starts_with("mod ") {
235 let name = extract_after_keyword(line, "mod");
236 return Some((BoundaryKind::Module, name));
237 }
238
239 None
240}
241
242fn detect_python_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
244 if line.starts_with("def ") || line.starts_with("async def ") {
246 let name = extract_python_fn_name(line);
247 return Some((BoundaryKind::Function, name));
248 }
249
250 if line.starts_with("class ") {
252 let name = extract_after_keyword(line, "class");
253 return Some((BoundaryKind::Class, name));
254 }
255
256 None
257}
258
259fn detect_js_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
261 if line.starts_with("function ")
263 || line.starts_with("async function ")
264 || line.starts_with("export function ")
265 || line.starts_with("export async function ")
266 {
267 let name = extract_js_fn_name(line);
268 return Some((BoundaryKind::Function, name));
269 }
270
271 if (line.starts_with("const ") || line.starts_with("export const "))
273 && (line.contains(" = (") || line.contains(" = async ("))
274 {
275 let name = extract_const_fn_name(line);
276 return Some((BoundaryKind::Function, name));
277 }
278
279 if line.starts_with("class ")
281 || line.starts_with("export class ")
282 || line.starts_with("export default class ")
283 {
284 let name = extract_after_keyword(line, "class");
285 return Some((BoundaryKind::Class, name));
286 }
287
288 if line.starts_with("interface ") || line.starts_with("export interface ") {
290 let name = extract_after_keyword(line, "interface");
291 return Some((BoundaryKind::Struct, name));
292 }
293
294 if line.starts_with("type ") || line.starts_with("export type ") {
296 let name = extract_after_keyword(line, "type");
297 return Some((BoundaryKind::Struct, name));
298 }
299
300 None
301}
302
303fn detect_go_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
305 if line.starts_with("func ") {
307 if line.contains(") ") && line.find('(') < line.find(')') {
309 let name = extract_go_method_name(line);
310 return Some((BoundaryKind::Method, name));
311 }
312 let name = extract_go_fn_name(line);
313 return Some((BoundaryKind::Function, name));
314 }
315
316 if line.starts_with("type ") && line.contains(" struct") {
318 let name = extract_after_keyword(line, "type");
319 return Some((BoundaryKind::Struct, name));
320 }
321
322 if line.starts_with("type ") && line.contains(" interface") {
324 let name = extract_after_keyword(line, "type");
325 return Some((BoundaryKind::Struct, name));
326 }
327
328 None
329}
330
331fn detect_java_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
333 if line.contains(" class ") || line.starts_with("class ") {
335 let name = extract_after_keyword(line, "class");
336 return Some((BoundaryKind::Class, name));
337 }
338
339 if line.contains(" interface ") || line.starts_with("interface ") {
341 let name = extract_after_keyword(line, "interface");
342 return Some((BoundaryKind::Struct, name));
343 }
344
345 if line.contains(" enum ") || line.starts_with("enum ") {
347 let name = extract_after_keyword(line, "enum");
348 return Some((BoundaryKind::Enum, name));
349 }
350
351 if (line.contains("public ") || line.contains("private ") || line.contains("protected "))
353 && line.contains('(')
354 && !line.contains(" class ")
355 && !line.contains(" interface ")
356 {
357 let name = extract_java_method_name(line);
358 return Some((BoundaryKind::Method, name));
359 }
360
361 None
362}
363
364fn detect_c_boundary(line: &str) -> Option<(BoundaryKind, Option<String>)> {
366 if line.starts_with("class ") || line.contains(" class ") {
368 let name = extract_after_keyword(line, "class");
369 return Some((BoundaryKind::Class, name));
370 }
371
372 if line.starts_with("struct ") || line.contains(" struct ") {
374 let name = extract_after_keyword(line, "struct");
375 return Some((BoundaryKind::Struct, name));
376 }
377
378 if line.starts_with("enum ") || line.contains(" enum ") {
380 let name = extract_after_keyword(line, "enum");
381 return Some((BoundaryKind::Enum, name));
382 }
383
384 if line.contains('(') && !line.starts_with('#') && !line.starts_with("//") {
386 let name = extract_c_fn_name(line);
388 if name.is_some() {
389 return Some((BoundaryKind::Function, name));
390 }
391 }
392
393 None
394}
395
/// Extracts the function name from a Rust function signature line.
///
/// Locates the `fn` keyword as a whole word, so any mix of visibility and
/// qualifiers before it works, including `pub(crate) const fn` and
/// `unsafe extern "C" fn`. It then returns the identifier that follows, which
/// drops generics and parameters. Raw identifiers (`r#try`) are kept.
/// Returns `None` if there is no `fn` keyword or no name follows it.
fn extract_rust_fn_name(line: &str) -> Option<String> {
    let after_fn = match line.strip_prefix("fn ") {
        Some(rest) => rest,
        None => &line[line.find(" fn ")? + " fn ".len()..],
    };
    let name: String = after_fn
        .trim_start()
        .chars()
        .take_while(|&c| c.is_alphanumeric() || c == '_' || c == '#')
        .collect();
    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}
410
/// Extracts the function name from a Python `def` / `async def` line.
///
/// Returns the trimmed text before the first `(`, or the whole remainder if
/// there is no `(`. Returns `None` when that text is empty.
fn extract_python_fn_name(line: &str) -> Option<String> {
    let signature = line.trim_start_matches("async ").trim_start_matches("def ");
    let head = match signature.find('(') {
        Some(paren) => &signature[..paren],
        None => signature,
    };
    let name = head.trim();
    if name.is_empty() {
        None
    } else {
        Some(name.to_owned())
    }
}
418
/// Extracts the name from a JS/TS `function` declaration line.
///
/// Leading `export`, `async` and `function` keywords are skipped. The name
/// ends at the parameter list or at a TypeScript type-parameter list, so
/// `function identity<T>(x: T)` yields `identity`. Returns `None` for
/// anonymous functions.
fn extract_js_fn_name(line: &str) -> Option<String> {
    let rest = line
        .trim_start_matches("export ")
        .trim_start_matches("async ")
        .trim_start_matches("function ");
    let end = rest
        .find(|c: char| c == '(' || c == '<')
        .unwrap_or(rest.len());
    let name = rest[..end].trim();
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
429
/// Extracts the binding name from a `const name = ...` line.
///
/// Leading `export` / `const` keywords are skipped. The name ends at ` =` or
/// at a TypeScript type annotation, so `const handler: Handler = ...` yields
/// `handler`. Returns `None` if nothing remains.
fn extract_const_fn_name(line: &str) -> Option<String> {
    let declaration = line
        .trim_start_matches("export ")
        .trim_start_matches("const ");
    let binding = declaration.split(" =").next().unwrap_or(declaration);
    let name = binding.split(':').next().unwrap_or(binding).trim();
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
439
/// Extracts the name from a Go `func Name(...)` line (no receiver).
///
/// The name ends at the parameter list or at a generic type-parameter list,
/// so `func Map[T any](...)` yields `Map`. Returns `None` if no name is present.
fn extract_go_fn_name(line: &str) -> Option<String> {
    let signature = line.trim_start_matches("func ");
    let end = signature
        .find(|c: char| c == '(' || c == '[')
        .unwrap_or(signature.len());
    let name = signature[..end].trim();
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
447
/// Extracts the method name from a Go method line such as
/// `func (s *Server) Start(ctx context.Context) error {`.
///
/// The name is the text between the first `) ` (the end of the receiver) and
/// the next `(`. Returns `None` if there is no `) ` or the name is empty.
fn extract_go_method_name(line: &str) -> Option<String> {
    let signature = line.trim_start_matches("func ");
    let receiver_end = signature.find(") ")?;
    let after_receiver = &signature[receiver_end + 2..];
    let head = match after_receiver.find('(') {
        Some(paren) => &after_receiver[..paren],
        None => after_receiver,
    };
    let name = head.trim();
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
461
/// Extracts a Java method name: the last whitespace-separated token before the
/// first `(`. Returns `None` if there is no `(` or nothing precedes it.
fn extract_java_method_name(line: &str) -> Option<String> {
    let paren_idx = line.find('(')?;
    line[..paren_idx]
        .split_whitespace()
        .next_back()
        .map(String::from)
}
475
/// Returns the function name if `line` looks like the first line of a C/C++
/// function *definition*, otherwise `None`.
///
/// Rejected:
/// - lines ending in `;` (calls, prototypes, macro invocations);
/// - lines with `=` before the `(` (assignments and initialisers);
/// - control-flow or operator keywords before the `(`;
/// - names that are not plain identifiers (e.g. `Foo::bar`);
/// - a lone word before `(`, unless the line ends with `)` or `{`. This keeps
///   GNU-style `foo (int x)` but drops the first line of a multi-line call
///   such as `printf("%d\n",`.
fn extract_c_fn_name(line: &str) -> Option<String> {
    const NON_DECLARATION_KEYWORDS: [&str; 14] = [
        "if", "while", "for", "switch", "return", "sizeof", "typeof", "else", "case", "do",
        "goto", "new", "delete", "throw",
    ];

    let trimmed = line.trim();
    if trimmed.ends_with(';') {
        return None;
    }

    let paren_idx = trimmed.find('(')?;
    let before_paren = &trimmed[..paren_idx];
    if before_paren.contains('=') {
        return None;
    }

    let tokens: Vec<&str> = before_paren.split_whitespace().collect();
    let (last_word, preceding) = tokens.split_last()?;

    let has_keyword = tokens.iter().any(|token| {
        NON_DECLARATION_KEYWORDS
            .iter()
            .any(|keyword| keyword.eq_ignore_ascii_case(token))
    });
    if has_keyword {
        return None;
    }

    // A definition normally has a return type before its name; allow a lone
    // name only when the line closes like a signature.
    let ends_like_signature = trimmed.ends_with(')') || trimmed.ends_with('{');
    if preceding.is_empty() && !ends_like_signature {
        return None;
    }

    let name = last_word.trim_start_matches('*').trim_start_matches('&');
    if !name.is_empty() && name.chars().all(|c| c.is_alphanumeric() || c == '_') {
        Some(name.to_string())
    } else {
        None
    }
}
495
/// Extracts the identifier that follows `keyword` in `line`.
///
/// The keyword must appear as a whole word: it must not be preceded or
/// followed by an identifier character. For example, `class` inside
/// `subclass` is skipped. A leading generic parameter group is skipped with
/// proper `<...>` nesting, and `->` inside it is ignored, so
/// `impl<T: AsRef<str>> Wrapper<T>` yields `Wrapper`. The name ends at
/// whitespace, `{`, `(`, `<` or `:`. Returns `None` if the keyword is absent
/// or no name follows it.
fn extract_after_keyword(line: &str, keyword: &str) -> Option<String> {
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    let idx = line
        .match_indices(keyword)
        .map(|(i, _)| i)
        .find(|&i| {
            let clear_before = line[..i].chars().next_back().map_or(true, |c| !is_ident_char(c));
            let clear_after = line[i + keyword.len()..]
                .chars()
                .next()
                .map_or(true, |c| !is_ident_char(c));
            clear_before && clear_after
        })?;

    let mut rest = line[idx + keyword.len()..].trim_start();
    if rest.starts_with('<') {
        rest = skip_angle_group(rest).trim_start();
    }

    let name = rest
        .split(|c: char| c.is_whitespace() || matches!(c, '{' | '(' | '<' | ':'))
        .next()
        .unwrap_or("");
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}

/// Returns the text after the balanced `<...>` group that `s` starts with,
/// or `""` if the group is never closed. A `>` that is part of `->` does not
/// close a group.
fn skip_angle_group(s: &str) -> &str {
    let mut depth = 0usize;
    let mut prev = '\0';
    for (i, c) in s.char_indices() {
        match c {
            '<' => depth += 1,
            '>' if prev != '-' => {
                depth -= 1;
                if depth == 0 {
                    return &s[i + 1..];
                }
            }
            _ => {}
        }
        prev = c;
    }
    ""
}
516
517fn create_chunks_from_boundaries(
519 text: &str,
520 lines: &[&str],
521 boundaries: &[CodeBoundary],
522 config: &ChunkConfig,
523 language: &Option<String>,
524) -> Result<Vec<ChunkOutput>, ChunkError> {
525 let mut chunks = Vec::new();
526 let chars_per_token = 4;
527 let max_chars = config.max_size * chars_per_token;
528
529 for (i, boundary) in boundaries.iter().enumerate() {
530 let start_line = boundary.line;
531 let end_line = if i + 1 < boundaries.len() {
532 boundaries[i + 1].line
533 } else {
534 lines.len()
535 };
536
537 if end_line <= start_line {
539 continue;
540 }
541
542 let chunk_lines = &lines[start_line..end_line];
544 let content: String = chunk_lines.join("\n");
545
546 if content.len() > max_chars {
548 let sub_chunks = split_large_chunk(
549 &content,
550 start_line,
551 config,
552 language,
553 &boundary.kind,
554 &boundary.name,
555 )?;
556 chunks.extend(sub_chunks);
557 } else {
558 let (byte_start, byte_end) = calculate_byte_range(text, lines, start_line, end_line);
559
560 chunks.push(ChunkOutput {
561 content,
562 byte_range: byte_start..byte_end,
563 line_range: Some(start_line as u32..end_line as u32),
564 parent_index: None,
565 depth: 0,
566 metadata: ChunkOutputMetadata {
567 symbol_type: Some(boundary.kind.as_str().to_string()),
568 symbol_name: boundary.name.clone(),
569 language: language.clone(),
570 },
571 });
572 }
573 }
574
575 if chunks.is_empty() {
577 return chunk_by_lines(text, lines, config, language);
578 }
579
580 Ok(chunks)
581}
582
583fn split_large_chunk(
585 content: &str,
586 base_line: usize,
587 config: &ChunkConfig,
588 language: &Option<String>,
589 kind: &BoundaryKind,
590 name: &Option<String>,
591) -> Result<Vec<ChunkOutput>, ChunkError> {
592 let mut chunks = Vec::new();
593 let chars_per_token = 4;
594 let target_chars = config.target_size * chars_per_token;
595 let overlap_chars = config.overlap * chars_per_token;
596 let step = target_chars.saturating_sub(overlap_chars).max(1);
597
598 let lines: Vec<&str> = content.lines().collect();
599 let mut start = 0;
600
601 while start < lines.len() {
602 let mut char_count = 0;
603 let mut end = start;
604
605 while end < lines.len() && char_count < target_chars {
607 char_count += lines[end].len() + 1; end += 1;
609 }
610
611 let chunk_content: String = lines[start..end].join("\n");
612 let chunk_lines = end - start;
613
614 let byte_start = lines[..start].iter().map(|l| l.len() + 1).sum::<usize>() as u64;
616 let byte_end = byte_start + chunk_content.len() as u64;
617
618 chunks.push(ChunkOutput {
619 content: chunk_content,
620 byte_range: byte_start..byte_end,
621 line_range: Some((base_line + start) as u32..(base_line + end) as u32),
622 parent_index: None,
623 depth: 0,
624 metadata: ChunkOutputMetadata {
625 symbol_type: Some(kind.as_str().to_string()),
626 symbol_name: name.clone(),
627 language: language.clone(),
628 },
629 });
630
631 let line_step = (step / (char_count / chunk_lines.max(1))).max(1);
633 start += line_step;
634
635 if end >= lines.len() {
636 break;
637 }
638 }
639
640 Ok(chunks)
641}
642
643fn chunk_by_lines(
645 text: &str,
646 lines: &[&str],
647 config: &ChunkConfig,
648 language: &Option<String>,
649) -> Result<Vec<ChunkOutput>, ChunkError> {
650 let mut chunks = Vec::new();
651 let chars_per_token = 4;
652 let target_chars = config.target_size * chars_per_token;
653 let overlap_lines = (config.overlap * chars_per_token) / 80; let mut start = 0;
656 while start < lines.len() {
657 let mut char_count = 0;
658 let mut end = start;
659
660 while end < lines.len() && char_count < target_chars {
661 char_count += lines[end].len() + 1;
662 end += 1;
663 }
664
665 let chunk_content: String = lines[start..end].join("\n");
666 let (byte_start, byte_end) = calculate_byte_range(text, lines, start, end);
667
668 chunks.push(ChunkOutput {
669 content: chunk_content,
670 byte_range: byte_start..byte_end,
671 line_range: Some(start as u32..end as u32),
672 parent_index: None,
673 depth: 0,
674 metadata: ChunkOutputMetadata {
675 language: language.clone(),
676 ..Default::default()
677 },
678 });
679
680 start = (end).saturating_sub(overlap_lines).max(start + 1);
681 }
682
683 Ok(chunks)
684}
685
/// Returns the byte range in `text` covered by lines `[start_line, end_line)`.
///
/// The end offset includes the terminator of the last covered line (`\n` or
/// `\r\n`), so it equals the start of line `end_line`, or `text.len()` at end
/// of input.
///
/// `lines` must be `text.lines()`. Offsets are computed from `text` itself via
/// `split_inclusive('\n')`, which yields exactly one piece per `lines()` entry
/// but keeps its terminator. That keeps offsets correct for CRLF files, where
/// assuming a 1-byte terminator per line would drift.
fn calculate_byte_range(
    text: &str,
    lines: &[&str],
    start_line: usize,
    end_line: usize,
) -> (u64, u64) {
    debug_assert!(
        start_line <= end_line && end_line <= lines.len(),
        "line range out of bounds"
    );
    let offset_of = |line_idx: usize| -> u64 {
        text.split_inclusive('\n')
            .take(line_idx)
            .map(str::len)
            .sum::<usize>() as u64
    };
    (offset_of(start_line), offset_of(end_line))
}
697
#[cfg(test)]
mod tests {
    use super::*;

    /// Rust items and functions are boundaries; plain statements are not.
    #[test]
    fn test_detect_rust_boundary() {
        let boundary_lines = [
            "fn main() {",
            "pub fn new() -> Self {",
            "pub async fn process() {",
            "impl Foo {",
            "pub struct Bar {",
            "enum Baz {",
        ];
        for line in &boundary_lines {
            assert!(
                detect_rust_boundary(line).is_some(),
                "expected a boundary for {:?}",
                line
            );
        }
        assert!(detect_rust_boundary("let x = 5;").is_none());
    }

    /// Python `def`, `async def` and `class` are boundaries; assignments are not.
    #[test]
    fn test_detect_python_boundary() {
        for line in &["def hello():", "async def world():", "class MyClass:"] {
            assert!(
                detect_python_boundary(line).is_some(),
                "expected a boundary for {:?}",
                line
            );
        }
        assert!(detect_python_boundary(" x = 5").is_none());
    }

    /// JS/TS functions, arrow-function consts, classes and interfaces are boundaries.
    #[test]
    fn test_detect_js_boundary() {
        let boundary_lines = [
            "function foo() {",
            "async function bar() {",
            "export function baz() {",
            "const fn = () => {",
            "class Component {",
            "export interface Props {",
        ];
        for line in &boundary_lines {
            assert!(
                detect_js_boundary(line).is_some(),
                "expected a boundary for {:?}",
                line
            );
        }
    }

    /// Function names are extracted without visibility, qualifiers or generics.
    #[test]
    fn test_extract_rust_fn_name() {
        let cases = [
            ("fn main() {", "main"),
            ("pub fn new() -> Self {", "new"),
            ("pub async fn process<T>() {", "process"),
        ];
        for (line, expected) in &cases {
            assert_eq!(extract_rust_fn_name(line), Some((*expected).to_string()));
        }
    }

    /// Known extensions map to their language; unknown ones map to `None`.
    #[test]
    fn test_language_detection() {
        assert!(matches!(Language::from_extension("rs"), Some(Language::Rust)));
        assert!(matches!(Language::from_extension("py"), Some(Language::Python)));
        assert!(matches!(
            Language::from_extension("js"),
            Some(Language::JavaScript)
        ));
        assert!(Language::from_extension("unknown").is_none());
    }
}