ragfs_embed/
cache.rs

//! Embedding cache for avoiding redundant computations.
//!
//! This module provides a simple LRU cache for embeddings based on content hashes.

5use ragfs_core::{EmbedError, Embedder, EmbeddingConfig, EmbeddingOutput, Modality};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::debug;
10
/// Maximum number of entries in the cache (default capacity used by `EmbeddingCache::new`).
const DEFAULT_CACHE_SIZE: usize = 10_000;
13
/// A cached embedding entry.
#[derive(Clone)]
struct CacheEntry {
    /// The embedding output stored for reuse.
    output: EmbeddingOutput,
    /// Access counter for LRU eviction; entries with the lowest
    /// stamp are evicted first by `maybe_evict`.
    access_count: u64,
}
22
/// Embedding cache with LRU eviction.
///
/// Wraps an [`Embedder`] and memoizes `embed_text` results keyed by the
/// blake3 hash of each input text. Queries (`embed_query`) always bypass
/// the cache.
pub struct EmbeddingCache {
    /// The underlying embedder that computes embeddings on cache misses.
    embedder: Arc<dyn Embedder>,
    /// Cache map: blake3 content hash (hex string) -> cached embedding.
    cache: RwLock<HashMap<String, CacheEntry>>,
    /// Maximum number of cached entries before eviction kicks in.
    max_size: usize,
    /// Global monotonically increasing access counter (source of LRU stamps).
    access_counter: RwLock<u64>,
    /// Cache statistics (hits / misses / evictions).
    stats: RwLock<CacheStats>,
}
36
/// Cache statistics.
///
/// A snapshot is returned by [`EmbeddingCache::stats`]; counters only
/// ever increase (except when a fresh cache starts at zero).
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of cache hits
    pub hits: u64,
    /// Number of cache misses
    pub misses: u64,
    /// Number of entries evicted
    pub evictions: u64,
}
47
48impl EmbeddingCache {
49    /// Create a new embedding cache with default size.
50    pub fn new(embedder: Arc<dyn Embedder>) -> Self {
51        Self::with_capacity(embedder, DEFAULT_CACHE_SIZE)
52    }
53
54    /// Create a new embedding cache with specified capacity.
55    pub fn with_capacity(embedder: Arc<dyn Embedder>, max_size: usize) -> Self {
56        Self {
57            embedder,
58            cache: RwLock::new(HashMap::new()),
59            max_size,
60            access_counter: RwLock::new(0),
61            stats: RwLock::new(CacheStats::default()),
62        }
63    }
64
65    /// Compute hash for a text.
66    fn hash_text(text: &str) -> String {
67        let hash = blake3::hash(text.as_bytes());
68        hash.to_hex().to_string()
69    }
70
71    /// Get the next access count.
72    async fn next_access(&self) -> u64 {
73        let mut counter = self.access_counter.write().await;
74        *counter += 1;
75        *counter
76    }
77
78    /// Evict oldest entries if cache is full.
79    async fn maybe_evict(&self) {
80        let mut cache = self.cache.write().await;
81
82        if cache.len() < self.max_size {
83            return;
84        }
85
86        // Find entries to evict (oldest 10%)
87        let evict_count = (self.max_size / 10).max(1);
88        let mut entries: Vec<_> = cache
89            .iter()
90            .map(|(k, v)| (k.clone(), v.access_count))
91            .collect();
92        entries.sort_by_key(|(_, count)| *count);
93
94        let mut stats = self.stats.write().await;
95        for (key, _) in entries.into_iter().take(evict_count) {
96            cache.remove(&key);
97            stats.evictions += 1;
98        }
99    }
100
101    /// Embed texts with caching.
102    pub async fn embed_text(
103        &self,
104        texts: &[&str],
105        config: &EmbeddingConfig,
106    ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
107        let mut results = Vec::with_capacity(texts.len());
108        let mut uncached_texts = Vec::new();
109        let mut uncached_indices = Vec::new();
110
111        // Check cache for each text
112        {
113            let cache = self.cache.read().await;
114            let mut stats = self.stats.write().await;
115
116            for (i, text) in texts.iter().enumerate() {
117                let hash = Self::hash_text(text);
118                if let Some(entry) = cache.get(&hash) {
119                    stats.hits += 1;
120                    results.push(Some(entry.output.clone()));
121                } else {
122                    stats.misses += 1;
123                    uncached_texts.push(*text);
124                    uncached_indices.push(i);
125                    results.push(None);
126                }
127            }
128        }
129
130        // Embed uncached texts
131        if !uncached_texts.is_empty() {
132            debug!("Cache miss for {} texts, embedding", uncached_texts.len());
133
134            let new_embeddings = self.embedder.embed_text(&uncached_texts, config).await?;
135
136            // Update cache
137            self.maybe_evict().await;
138
139            let mut cache = self.cache.write().await;
140            for (text, output) in uncached_texts.iter().zip(new_embeddings.iter()) {
141                let hash = Self::hash_text(text);
142                let access = self.next_access().await;
143                cache.insert(
144                    hash,
145                    CacheEntry {
146                        output: output.clone(),
147                        access_count: access,
148                    },
149                );
150            }
151
152            // Fill in results
153            for (idx, output) in uncached_indices.into_iter().zip(new_embeddings) {
154                results[idx] = Some(output);
155            }
156        }
157
158        // Collect all results (all should be Some now)
159        Ok(results.into_iter().flatten().collect())
160    }
161
162    /// Embed a single query (always bypasses cache for queries).
163    pub async fn embed_query(
164        &self,
165        query: &str,
166        config: &EmbeddingConfig,
167    ) -> Result<EmbeddingOutput, EmbedError> {
168        // Queries typically shouldn't be cached as they're one-off
169        self.embedder.embed_query(query, config).await
170    }
171
172    /// Get the underlying embedder.
173    pub fn embedder(&self) -> Arc<dyn Embedder> {
174        Arc::clone(&self.embedder)
175    }
176
177    /// Get cache statistics.
178    pub async fn stats(&self) -> CacheStats {
179        self.stats.read().await.clone()
180    }
181
182    /// Get cache size.
183    pub async fn size(&self) -> usize {
184        self.cache.read().await.len()
185    }
186
187    /// Clear the cache.
188    pub async fn clear(&self) {
189        let mut cache = self.cache.write().await;
190        cache.clear();
191    }
192
193    /// Get the embedding dimension.
194    pub fn dimension(&self) -> usize {
195        self.embedder.dimension()
196    }
197
198    /// Get the model name.
199    pub fn model_name(&self) -> &str {
200        self.embedder.model_name()
201    }
202
203    /// Get supported modalities.
204    pub fn modalities(&self) -> &[Modality] {
205        self.embedder.modalities()
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    // Embedding dimension used by all mock embedders in this module.
    const TEST_DIM: usize = 384;

    /// Deterministic mock embedder that derives each embedding from the
    /// blake3 hash of the input text and counts how many times
    /// `embed_text` is invoked (used to detect cache hits vs misses).
    struct MockEmbedder {
        dimension: usize,
        call_count: RwLock<usize>,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            Self {
                dimension,
                call_count: RwLock::new(0),
            }
        }

        // Number of times `embed_text` has been called on the mock.
        async fn get_call_count(&self) -> usize {
            *self.call_count.read().await
        }
    }

    #[async_trait]
    impl Embedder for MockEmbedder {
        fn model_name(&self) -> &'static str {
            "mock-embedder"
        }

        fn dimension(&self) -> usize {
            self.dimension
        }

        fn max_tokens(&self) -> usize {
            512
        }

        fn modalities(&self) -> &[Modality] {
            &[Modality::Text]
        }

        async fn embed_text(
            &self,
            texts: &[&str],
            _config: &EmbeddingConfig,
        ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
            // One increment per batch call (not per text) — the tests
            // assert on batch-call counts.
            let mut count = self.call_count.write().await;
            *count += 1;

            Ok(texts
                .iter()
                .map(|text| {
                    // Deterministic pseudo-embedding: cycle through the 32
                    // hash bytes, normalized into [0, 1].
                    let hash = blake3::hash(text.as_bytes());
                    let bytes = hash.as_bytes();
                    let embedding: Vec<f32> = (0..self.dimension)
                        .map(|i| f32::from(bytes[i % 32]) / 255.0)
                        .collect();
                    EmbeddingOutput {
                        embedding,
                        token_count: text.split_whitespace().count(),
                    }
                })
                .collect())
        }

        async fn embed_query(
            &self,
            query: &str,
            config: &EmbeddingConfig,
        ) -> Result<EmbeddingOutput, EmbedError> {
            let results = self.embed_text(&[query], config).await?;
            Ok(results.into_iter().next().unwrap())
        }
    }

    // Repeating the same text must hit the cache: no extra embedder call,
    // identical embedding, and stats reflecting one hit + one miss.
    #[tokio::test]
    async fn test_cache_hit() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::new(Arc::clone(&embedder) as Arc<dyn Embedder>);
        let config = EmbeddingConfig::default();

        // First call - cache miss
        let result1 = cache.embed_text(&["hello world"], &config).await.unwrap();
        assert_eq!(embedder.get_call_count().await, 1);

        // Second call - cache hit
        let result2 = cache.embed_text(&["hello world"], &config).await.unwrap();
        assert_eq!(embedder.get_call_count().await, 1); // No additional call

        // Results should be identical
        assert_eq!(result1[0].embedding, result2[0].embedding);

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    // Three distinct texts embedded separately: every call is a miss and
    // reaches the underlying embedder.
    #[tokio::test]
    async fn test_cache_miss() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::new(Arc::clone(&embedder) as Arc<dyn Embedder>);
        let config = EmbeddingConfig::default();

        // Different texts - all misses
        cache.embed_text(&["text one"], &config).await.unwrap();
        cache.embed_text(&["text two"], &config).await.unwrap();
        cache.embed_text(&["text three"], &config).await.unwrap();

        assert_eq!(embedder.get_call_count().await, 3);

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 3);
    }

    // A batch mixing cached and uncached texts should only send the
    // uncached subset to the embedder, preserving input order.
    #[tokio::test]
    async fn test_batch_with_mixed_cache() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::new(Arc::clone(&embedder) as Arc<dyn Embedder>);
        let config = EmbeddingConfig::default();

        // Prime cache with one text
        cache.embed_text(&["cached text"], &config).await.unwrap();
        assert_eq!(embedder.get_call_count().await, 1);

        // Batch with mix of cached and uncached
        let results = cache
            .embed_text(&["cached text", "new text", "cached text"], &config)
            .await
            .unwrap();

        // Only one new embedding call for "new text"
        assert_eq!(embedder.get_call_count().await, 2);
        assert_eq!(results.len(), 3);

        // First and third results should be identical
        assert_eq!(results[0].embedding, results[2].embedding);
    }

    // clear() empties the cache so previously cached texts miss again.
    #[tokio::test]
    async fn test_cache_clear() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::new(Arc::clone(&embedder) as Arc<dyn Embedder>);
        let config = EmbeddingConfig::default();

        cache.embed_text(&["test"], &config).await.unwrap();
        assert_eq!(cache.size().await, 1);

        cache.clear().await;
        assert_eq!(cache.size().await, 0);

        // Should be a miss now
        cache.embed_text(&["test"], &config).await.unwrap();
        assert_eq!(embedder.get_call_count().await, 2);
    }

    // Overfilling a capacity-10 cache with 15 entries must trigger
    // eviction (asserts only that eviction happened, not which entries).
    #[tokio::test]
    async fn test_cache_eviction() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::with_capacity(Arc::clone(&embedder) as Arc<dyn Embedder>, 10);
        let config = EmbeddingConfig::default();

        // Fill cache beyond capacity
        for i in 0..15 {
            let text = format!("text number {i}");
            cache.embed_text(&[&text], &config).await.unwrap();
        }

        // Cache should have evicted some entries
        assert!(cache.size().await < 15);

        let stats = cache.stats().await;
        assert!(stats.evictions > 0);
    }

    // The cache forwards dimension/model_name/modalities to the wrapped
    // embedder unchanged.
    #[tokio::test]
    async fn test_embedder_properties() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::new(Arc::clone(&embedder) as Arc<dyn Embedder>);

        assert_eq!(cache.dimension(), TEST_DIM);
        assert_eq!(cache.model_name(), "mock-embedder");
        assert_eq!(cache.modalities(), &[Modality::Text]);
    }

    // Stats start at zero and track hit/miss counts exactly.
    #[tokio::test]
    async fn test_cache_stats_accuracy() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::new(Arc::clone(&embedder) as Arc<dyn Embedder>);
        let config = EmbeddingConfig::default();

        // Initial stats should be zero
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);

        // First call: miss
        cache.embed_text(&["test"], &config).await.unwrap();
        let stats = cache.stats().await;
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        // Second call with same text: hit
        cache.embed_text(&["test"], &config).await.unwrap();
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    // A multi-text batch caches every element; a repeat batch is all hits.
    #[tokio::test]
    async fn test_cache_multiple_texts_batch() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::new(Arc::clone(&embedder) as Arc<dyn Embedder>);
        let config = EmbeddingConfig::default();

        // Embed multiple texts in one call
        let texts = vec!["text one", "text two", "text three"];
        let results = cache.embed_text(&texts, &config).await.unwrap();
        assert_eq!(results.len(), 3);

        // All should now be cached
        let stats = cache.stats().await;
        assert_eq!(stats.misses, 3);

        // Call again - should all be hits
        cache.embed_text(&texts, &config).await.unwrap();
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 3);
    }

    // An empty slice is a no-op: empty result, embedder never called,
    // stats untouched.
    #[tokio::test]
    async fn test_cache_empty_input() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::new(Arc::clone(&embedder) as Arc<dyn Embedder>);
        let config = EmbeddingConfig::default();

        // Empty input should return empty results
        let texts: Vec<&str> = vec![];
        let results = cache.embed_text(&texts, &config).await.unwrap();
        assert!(results.is_empty());

        // Stats should be unchanged
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }
}
457}