1use ragfs_core::{
4 DistanceMetric, Embedder, EmbeddingConfig, SearchQuery, SearchResult, VectorStore,
5};
6use std::sync::Arc;
7use tracing::debug;
8
9use crate::parser::{ParsedQuery, QueryParser};
10
11pub struct QueryExecutor {
13 store: Arc<dyn VectorStore>,
15 embedder: Arc<dyn Embedder>,
17 parser: QueryParser,
19 hybrid: bool,
21}
22
23impl QueryExecutor {
24 pub fn new(
26 store: Arc<dyn VectorStore>,
27 embedder: Arc<dyn Embedder>,
28 default_limit: usize,
29 hybrid: bool,
30 ) -> Self {
31 Self {
32 store,
33 embedder,
34 parser: QueryParser::new(default_limit),
35 hybrid,
36 }
37 }
38
39 pub async fn execute(&self, query_str: &str) -> Result<Vec<SearchResult>, ragfs_core::Error> {
41 debug!("Executing query: {}", query_str);
42
43 let parsed = self.parser.parse(query_str);
45
46 let config = EmbeddingConfig::default();
48 let embedding = self
49 .embedder
50 .embed_query(&parsed.text, &config)
51 .await
52 .map_err(ragfs_core::Error::Embedding)?;
53
54 let search_query = SearchQuery {
56 embedding: embedding.embedding,
57 text: if self.hybrid {
58 Some(parsed.text.clone())
59 } else {
60 None
61 },
62 limit: parsed.limit,
63 filters: parsed.filters,
64 metric: DistanceMetric::Cosine,
65 };
66
67 let results = if self.hybrid {
69 self.store.hybrid_search(search_query).await
70 } else {
71 self.store.search(search_query).await
72 }
73 .map_err(ragfs_core::Error::Store)?;
74
75 debug!("Found {} results", results.len());
76 Ok(results)
77 }
78
79 pub async fn execute_parsed(
81 &self,
82 parsed: ParsedQuery,
83 ) -> Result<Vec<SearchResult>, ragfs_core::Error> {
84 let config = EmbeddingConfig::default();
85 let embedding = self
86 .embedder
87 .embed_query(&parsed.text, &config)
88 .await
89 .map_err(ragfs_core::Error::Embedding)?;
90
91 let search_query = SearchQuery {
92 embedding: embedding.embedding,
93 text: if self.hybrid {
94 Some(parsed.text.clone())
95 } else {
96 None
97 },
98 limit: parsed.limit,
99 filters: parsed.filters,
100 metric: DistanceMetric::Cosine,
101 };
102
103 let results = if self.hybrid {
104 self.store.hybrid_search(search_query).await
105 } else {
106 self.store.search(search_query).await
107 }
108 .map_err(ragfs_core::Error::Store)?;
109
110 Ok(results)
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use ragfs_core::{
        Chunk, EmbedError, EmbeddingOutput, FileRecord, Modality, StoreError, StoreStats,
    };
    use std::collections::HashMap;
    use std::path::{Path, PathBuf};
    use tokio::sync::RwLock;
    use uuid::Uuid;

    /// Embedding dimension used by the mock embedder in every test.
    const TEST_DIM: usize = 384;

    /// Search results shared behind an async lock, as stored by [`MockStore`].
    type SharedResults = Arc<RwLock<Vec<SearchResult>>>;

    /// Embedder stub that ignores its input and returns a constant vector.
    struct MockEmbedder {
        dimension: usize,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            Self { dimension }
        }

        /// The fixed output returned for every text or query.
        fn canned_output(&self) -> EmbeddingOutput {
            EmbeddingOutput {
                embedding: vec![0.1; self.dimension],
                token_count: 10,
            }
        }
    }

    #[async_trait]
    impl Embedder for MockEmbedder {
        fn model_name(&self) -> &'static str {
            "mock-embedder"
        }

        fn dimension(&self) -> usize {
            self.dimension
        }

        fn max_tokens(&self) -> usize {
            512
        }

        fn modalities(&self) -> &[Modality] {
            &[Modality::Text]
        }

        /// Returns one canned output per input text.
        async fn embed_text(
            &self,
            texts: &[&str],
            _config: &EmbeddingConfig,
        ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
            let outputs = texts.iter().map(|_| self.canned_output()).collect();
            Ok(outputs)
        }

        /// Returns the canned output regardless of the query.
        async fn embed_query(
            &self,
            _query: &str,
            _config: &EmbeddingConfig,
        ) -> Result<EmbeddingOutput, EmbedError> {
            Ok(self.canned_output())
        }
    }

    /// Store stub that serves preset results.
    ///
    /// `search` and `hybrid_search` return separate result sets so tests can
    /// tell which of the two the executor called.
    struct MockStore {
        results: SharedResults,
        hybrid_results: SharedResults,
    }

    impl MockStore {
        /// Store with no results for either search kind.
        fn new() -> Self {
            Self::with_hybrid_results(Vec::new(), Vec::new())
        }

        /// Store with the given vector-search results and no hybrid results.
        fn with_results(results: Vec<SearchResult>) -> Self {
            Self::with_hybrid_results(results, Vec::new())
        }

        /// Store with distinct vector-search and hybrid-search results.
        fn with_hybrid_results(results: Vec<SearchResult>, hybrid: Vec<SearchResult>) -> Self {
            Self {
                results: Arc::new(RwLock::new(results)),
                hybrid_results: Arc::new(RwLock::new(hybrid)),
            }
        }
    }

    #[async_trait]
    impl VectorStore for MockStore {
        async fn init(&self) -> Result<(), StoreError> {
            Ok(())
        }

        async fn upsert_chunks(&self, _chunks: &[Chunk]) -> Result<(), StoreError> {
            Ok(())
        }

        /// Returns a copy of the preset vector-search results; the query is ignored.
        async fn search(&self, _query: SearchQuery) -> Result<Vec<SearchResult>, StoreError> {
            Ok(self.results.read().await.clone())
        }

        /// Returns a copy of the preset hybrid-search results; the query is ignored.
        async fn hybrid_search(
            &self,
            _query: SearchQuery,
        ) -> Result<Vec<SearchResult>, StoreError> {
            Ok(self.hybrid_results.read().await.clone())
        }

        async fn delete_by_file_path(&self, _path: &Path) -> Result<u64, StoreError> {
            Ok(0)
        }

        async fn get_file(&self, _path: &Path) -> Result<Option<FileRecord>, StoreError> {
            Ok(None)
        }

        async fn upsert_file(&self, _record: &FileRecord) -> Result<(), StoreError> {
            Ok(())
        }

        async fn stats(&self) -> Result<StoreStats, StoreError> {
            let empty = StoreStats {
                total_chunks: 0,
                total_files: 0,
                index_size_bytes: 0,
                last_updated: None,
            };
            Ok(empty)
        }

        async fn update_file_path(&self, _from: &Path, _to: &Path) -> Result<u64, StoreError> {
            Ok(0)
        }

        async fn get_chunks_for_file(&self, _path: &Path) -> Result<Vec<Chunk>, StoreError> {
            Ok(Vec::new())
        }

        async fn get_all_chunks(&self) -> Result<Vec<Chunk>, StoreError> {
            Ok(Vec::new())
        }

        async fn get_all_files(&self) -> Result<Vec<FileRecord>, StoreError> {
            Ok(Vec::new())
        }
    }

    /// Builds a search result with a fresh random id, a byte range covering
    /// the whole of `content`, a line range of `0..1`, and empty metadata.
    fn create_test_result(path: &str, content: &str, score: f32) -> SearchResult {
        let content_len = content.len() as u64;
        SearchResult {
            chunk_id: Uuid::new_v4(),
            file_path: PathBuf::from(path),
            content: content.to_owned(),
            score,
            byte_range: 0..content_len,
            line_range: Some(0..1),
            metadata: HashMap::new(),
        }
    }

    /// Wraps `store` and a mock embedder in an executor with a default limit of 10.
    fn build_executor(store: MockStore, hybrid: bool) -> QueryExecutor {
        QueryExecutor::new(
            Arc::new(store),
            Arc::new(MockEmbedder::new(TEST_DIM)),
            10,
            hybrid,
        )
    }

    /// A plain query returns the store's vector-search results in order.
    #[tokio::test]
    async fn test_execute_simple_query() {
        let store = MockStore::with_results(vec![
            create_test_result("/test/file1.txt", "Authentication module", 0.9),
            create_test_result("/test/file2.txt", "Auth config", 0.8),
        ]);
        let executor = build_executor(store, false);

        let hits = executor.execute("authentication").await.unwrap();

        let contents: Vec<&str> = hits.iter().map(|hit| hit.content.as_str()).collect();
        assert_eq!(contents, ["Authentication module", "Auth config"]);
    }

    /// With hybrid mode on, results come from `hybrid_search`.
    #[tokio::test]
    async fn test_execute_with_hybrid_search() {
        let vector_only = vec![create_test_result("/test/vector.txt", "Vector result", 0.8)];
        let hybrid = vec![
            create_test_result("/test/hybrid1.txt", "Hybrid result 1", 0.95),
            create_test_result("/test/hybrid2.txt", "Hybrid result 2", 0.85),
        ];
        let executor = build_executor(MockStore::with_hybrid_results(vector_only, hybrid), true);

        let hits = executor.execute("search query").await.unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].content, "Hybrid result 1");
    }

    /// With hybrid mode off, results come from `search` even if hybrid results exist.
    #[tokio::test]
    async fn test_execute_vector_only() {
        let vector_only = vec![create_test_result(
            "/test/vector.txt",
            "Vector only result",
            0.9,
        )];
        let hybrid = vec![create_test_result("/test/hybrid.txt", "Hybrid result", 0.95)];
        let executor = build_executor(MockStore::with_hybrid_results(vector_only, hybrid), false);

        let hits = executor.execute("search query").await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "Vector only result");
    }

    /// An empty store yields an empty result list, not an error.
    #[tokio::test]
    async fn test_execute_empty_results() {
        let executor = build_executor(MockStore::new(), false);

        let hits = executor.execute("no results query").await.unwrap();

        assert!(hits.is_empty());
    }

    /// A query carrying a `limit:` clause still executes successfully.
    #[tokio::test]
    async fn test_execute_with_limit_in_query() {
        let store = MockStore::with_results(vec![
            create_test_result("/test/file1.txt", "Result 1", 0.9),
            create_test_result("/test/file2.txt", "Result 2", 0.8),
            create_test_result("/test/file3.txt", "Result 3", 0.7),
        ]);
        let executor = build_executor(store, false);

        let hits = executor.execute("search query limit:2").await.unwrap();

        // The mock store ignores the query it is given, so only
        // non-emptiness can be checked here.
        assert!(!hits.is_empty());
    }

    /// `execute_parsed` accepts a hand-built `ParsedQuery` and returns the store's results.
    #[tokio::test]
    async fn test_execute_parsed_query() {
        use crate::parser::ParsedQuery;

        let store =
            MockStore::with_results(vec![create_test_result("/test/file.txt", "Parsed result", 0.9)]);
        let executor = build_executor(store, false);

        let pre_parsed = ParsedQuery {
            text: "pre-parsed query".to_string(),
            limit: 5,
            filters: vec![],
        };
        let hits = executor.execute_parsed(pre_parsed).await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "Parsed result");
    }

    /// The constructor stores the `hybrid` flag it is given.
    #[test]
    fn test_query_executor_creation() {
        let store: Arc<dyn VectorStore> = Arc::new(MockStore::new());
        let embedder: Arc<dyn Embedder> = Arc::new(MockEmbedder::new(TEST_DIM));

        let vector_executor =
            QueryExecutor::new(Arc::clone(&store), Arc::clone(&embedder), 10, false);
        let hybrid_executor = QueryExecutor::new(store, embedder, 20, true);

        assert!(!vector_executor.hybrid);
        assert!(hybrid_executor.hybrid);
    }
}