1use async_trait::async_trait;
10use chrono::Utc;
11use ragfs_core::{
12 Chunk, FileRecord, SearchQuery, SearchResult, StoreError, StoreStats, VectorStore,
13};
14use std::collections::HashMap;
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use tokio::sync::RwLock;
18use tracing::debug;
19use uuid::Uuid;
20
/// An in-memory vector store backed by `HashMap`s behind async `RwLock`s.
///
/// All data lives on the heap and is lost when the process exits, so this
/// is suited to tests and small, ephemeral workloads. Each field is
/// independently `Arc`-wrapped, so clones of the handles share state
/// across tasks.
pub struct MemoryStore {
    // Expected embedding dimension. NOTE(review): stored and logged in
    // `init`, but not enforced on upsert or search in this file — verify
    // callers supply matching vectors.
    dimension: usize,
    // Chunk storage keyed by chunk id.
    chunks: Arc<RwLock<HashMap<Uuid, Chunk>>>,
    // File records keyed by file path.
    files: Arc<RwLock<HashMap<PathBuf, FileRecord>>>,
    // Set to true by `init`; not read by any other method in this file.
    initialized: Arc<RwLock<bool>>,
}
49
50impl MemoryStore {
51 #[must_use]
53 pub fn new(dimension: usize) -> Self {
54 Self {
55 dimension,
56 chunks: Arc::new(RwLock::new(HashMap::new())),
57 files: Arc::new(RwLock::new(HashMap::new())),
58 initialized: Arc::new(RwLock::new(false)),
59 }
60 }
61
62 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
64 if a.len() != b.len() {
65 return 0.0;
66 }
67
68 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
69 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
70 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
71
72 if norm_a == 0.0 || norm_b == 0.0 {
73 return 0.0;
74 }
75
76 dot / (norm_a * norm_b)
77 }
78}
79
/// Defaults to a 384-dimensional store.
///
/// NOTE(review): 384 presumably matches the project's default embedding
/// model output size — confirm against the embedder configuration.
impl Default for MemoryStore {
    fn default() -> Self {
        Self::new(384)
    }
}
85
86#[async_trait]
87impl VectorStore for MemoryStore {
88 async fn init(&self) -> Result<(), StoreError> {
89 let mut initialized = self.initialized.write().await;
90 *initialized = true;
91 debug!("MemoryStore initialized (dimension: {})", self.dimension);
92 Ok(())
93 }
94
95 async fn upsert_chunks(&self, chunks: &[Chunk]) -> Result<(), StoreError> {
96 let mut store = self.chunks.write().await;
97 for chunk in chunks {
98 store.insert(chunk.id, chunk.clone());
99 }
100 debug!("Upserted {} chunks", chunks.len());
101 Ok(())
102 }
103
104 async fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>, StoreError> {
105 let chunks = self.chunks.read().await;
106 let mut results: Vec<(f32, &Chunk)> = Vec::new();
107
108 for chunk in chunks.values() {
110 if let Some(embedding) = &chunk.embedding {
111 let score = Self::cosine_similarity(&query.embedding, embedding);
112 results.push((score, chunk));
113 }
114 }
115
116 results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
118
119 let top_k = results
121 .into_iter()
122 .take(query.limit)
123 .map(|(score, chunk)| SearchResult {
124 chunk_id: chunk.id,
125 file_path: chunk.file_path.clone(),
126 content: chunk.content.clone(),
127 score,
128 byte_range: chunk.byte_range.clone(),
129 line_range: chunk.line_range.clone(),
130 metadata: chunk.metadata.extra.clone(),
131 })
132 .collect();
133
134 Ok(top_k)
135 }
136
137 async fn hybrid_search(&self, query: SearchQuery) -> Result<Vec<SearchResult>, StoreError> {
138 self.search(query).await
141 }
142
143 async fn delete_by_file_path(&self, path: &Path) -> Result<u64, StoreError> {
144 let mut chunks = self.chunks.write().await;
145 let mut files = self.files.write().await;
146
147 let before = chunks.len();
148 chunks.retain(|_, chunk| chunk.file_path != path);
149 let deleted = (before - chunks.len()) as u64;
150
151 files.remove(path);
152
153 debug!("Deleted {} chunks for {:?}", deleted, path);
154 Ok(deleted)
155 }
156
157 async fn update_file_path(&self, from: &Path, to: &Path) -> Result<u64, StoreError> {
158 let mut chunks = self.chunks.write().await;
159 let mut files = self.files.write().await;
160 let mut updated = 0u64;
161
162 for chunk in chunks.values_mut() {
164 if chunk.file_path == from {
165 chunk.file_path = to.to_path_buf();
166 updated += 1;
167 }
168 }
169
170 if let Some(mut record) = files.remove(from) {
172 record.path = to.to_path_buf();
173 files.insert(to.to_path_buf(), record);
174 }
175
176 debug!("Updated {} chunks from {:?} to {:?}", updated, from, to);
177 Ok(updated)
178 }
179
180 async fn get_chunks_for_file(&self, path: &Path) -> Result<Vec<Chunk>, StoreError> {
181 let chunks = self.chunks.read().await;
182 let file_chunks: Vec<Chunk> = chunks
183 .values()
184 .filter(|chunk| chunk.file_path == path)
185 .cloned()
186 .collect();
187 Ok(file_chunks)
188 }
189
190 async fn get_file(&self, path: &Path) -> Result<Option<FileRecord>, StoreError> {
191 let files = self.files.read().await;
192 Ok(files.get(path).cloned())
193 }
194
195 async fn upsert_file(&self, record: &FileRecord) -> Result<(), StoreError> {
196 let mut files = self.files.write().await;
197 files.insert(record.path.clone(), record.clone());
198 debug!("Upserted file record for {:?}", record.path);
199 Ok(())
200 }
201
202 async fn stats(&self) -> Result<StoreStats, StoreError> {
203 let chunks = self.chunks.read().await;
204 let files = self.files.read().await;
205
206 Ok(StoreStats {
207 total_chunks: chunks.len() as u64,
208 total_files: files.len() as u64,
209 index_size_bytes: 0, last_updated: Some(Utc::now()),
211 })
212 }
213
214 async fn get_all_chunks(&self) -> Result<Vec<Chunk>, StoreError> {
215 let chunks = self.chunks.read().await;
216 Ok(chunks.values().cloned().collect())
217 }
218
219 async fn get_all_files(&self) -> Result<Vec<FileRecord>, StoreError> {
220 let files = self.files.read().await;
221 Ok(files.values().cloned().collect())
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use ragfs_core::{ChunkMetadata, ContentType};

    /// Builds a minimal `Chunk` with the given ids, path, and embedding;
    /// every other field is a fixed test default.
    fn create_test_chunk(id: Uuid, file_id: Uuid, path: &str, embedding: Vec<f32>) -> Chunk {
        Chunk {
            id,
            file_id,
            file_path: PathBuf::from(path),
            content: "test content".to_string(),
            content_type: ContentType::Text,
            mime_type: Some("text/plain".to_string()),
            chunk_index: 0,
            byte_range: 0..12,
            line_range: Some(0..1),
            parent_chunk_id: None,
            depth: 0,
            embedding: Some(embedding),
            metadata: ChunkMetadata::default(),
        }
    }

    // Constructor records the requested dimension.
    #[tokio::test]
    async fn test_memory_store_new() {
        let store = MemoryStore::new(384);
        assert_eq!(store.dimension, 384);
    }

    // init is infallible for the memory store.
    #[tokio::test]
    async fn test_memory_store_init() {
        let store = MemoryStore::new(384);
        let result = store.init().await;
        assert!(result.is_ok());
    }

    // Two upserted chunks show up in `stats().total_chunks`.
    #[tokio::test]
    async fn test_memory_store_upsert_and_stats() {
        let store = MemoryStore::new(3);
        store.init().await.unwrap();

        let file_id = Uuid::new_v4();
        let chunks = vec![
            create_test_chunk(
                Uuid::new_v4(),
                file_id,
                "/test/file.txt",
                vec![1.0, 0.0, 0.0],
            ),
            create_test_chunk(
                Uuid::new_v4(),
                file_id,
                "/test/file.txt",
                vec![0.0, 1.0, 0.0],
            ),
        ];

        store.upsert_chunks(&chunks).await.unwrap();

        let stats = store.stats().await.unwrap();
        assert_eq!(stats.total_chunks, 2);
    }

    // With three orthogonal unit vectors, a query equal to the first must
    // rank that chunk first with similarity ~1.0, and `limit: 2` truncates.
    #[tokio::test]
    async fn test_memory_store_search() {
        let store = MemoryStore::new(3);
        store.init().await.unwrap();

        let file_id = Uuid::new_v4();
        let chunk1_id = Uuid::new_v4();
        let chunks = vec![
            create_test_chunk(chunk1_id, file_id, "/test/file.txt", vec![1.0, 0.0, 0.0]),
            create_test_chunk(
                Uuid::new_v4(),
                file_id,
                "/test/file.txt",
                vec![0.0, 1.0, 0.0],
            ),
            create_test_chunk(
                Uuid::new_v4(),
                file_id,
                "/test/file.txt",
                vec![0.0, 0.0, 1.0],
            ),
        ];

        store.upsert_chunks(&chunks).await.unwrap();

        let query = SearchQuery {
            embedding: vec![1.0, 0.0, 0.0],
            text: None,
            limit: 2,
            filters: vec![],
            metric: Default::default(),
        };

        let results = store.search(query).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].chunk_id, chunk1_id);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    // Deleting by path removes only that file's chunks and reports the count.
    #[tokio::test]
    async fn test_memory_store_delete_by_file_path() {
        let store = MemoryStore::new(3);
        store.init().await.unwrap();

        let chunks = vec![
            create_test_chunk(
                Uuid::new_v4(),
                Uuid::new_v4(),
                "/test/file1.txt",
                vec![1.0, 0.0, 0.0],
            ),
            create_test_chunk(
                Uuid::new_v4(),
                Uuid::new_v4(),
                "/test/file2.txt",
                vec![0.0, 1.0, 0.0],
            ),
        ];

        store.upsert_chunks(&chunks).await.unwrap();

        let deleted = store
            .delete_by_file_path(Path::new("/test/file1.txt"))
            .await
            .unwrap();
        assert_eq!(deleted, 1);

        let stats = store.stats().await.unwrap();
        assert_eq!(stats.total_chunks, 1);
    }

    // get_all_chunks returns every stored chunk.
    #[tokio::test]
    async fn test_memory_store_get_all_chunks() {
        let store = MemoryStore::new(3);
        store.init().await.unwrap();

        let file_id = Uuid::new_v4();
        let chunks = vec![
            create_test_chunk(
                Uuid::new_v4(),
                file_id,
                "/test/file.txt",
                vec![1.0, 0.0, 0.0],
            ),
            create_test_chunk(
                Uuid::new_v4(),
                file_id,
                "/test/file.txt",
                vec![0.0, 1.0, 0.0],
            ),
        ];

        store.upsert_chunks(&chunks).await.unwrap();

        let all_chunks = store.get_all_chunks().await.unwrap();
        assert_eq!(all_chunks.len(), 2);
    }

    // Identical vectors -> 1.0, orthogonal -> 0.0, opposite -> -1.0.
    #[test]
    fn test_cosine_similarity() {
        let sim = MemoryStore::cosine_similarity(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!((sim - 1.0).abs() < 0.001);

        let sim = MemoryStore::cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!(sim.abs() < 0.001);

        let sim = MemoryStore::cosine_similarity(&[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0]);
        assert!((sim - (-1.0)).abs() < 0.001);
    }
}