1use ragfs_core::{EmbedError, Embedder, EmbeddingConfig, EmbeddingOutput, Modality};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::debug;
10
/// Number of entries an [`EmbeddingCache`] built with `EmbeddingCache::new` is sized for.
const DEFAULT_CACHE_SIZE: usize = 10 * 1_000;
13
14#[derive(Clone)]
16struct CacheEntry {
17 output: EmbeddingOutput,
19 access_count: u64,
21}
22
23pub struct EmbeddingCache {
25 embedder: Arc<dyn Embedder>,
27 cache: RwLock<HashMap<String, CacheEntry>>,
29 max_size: usize,
31 access_counter: RwLock<u64>,
33 stats: RwLock<CacheStats>,
35}
36
/// Cumulative counters describing how an `EmbeddingCache` has been used.
///
/// Obtain a snapshot via `EmbeddingCache::stats`.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Input texts that were answered from the cache.
    pub hits: u64,
    /// Input texts that were not found in the cache at lookup time.
    pub misses: u64,
    /// Entries removed from the cache to make room for new ones.
    pub evictions: u64,
}
47
48impl EmbeddingCache {
49 pub fn new(embedder: Arc<dyn Embedder>) -> Self {
51 Self::with_capacity(embedder, DEFAULT_CACHE_SIZE)
52 }
53
54 pub fn with_capacity(embedder: Arc<dyn Embedder>, max_size: usize) -> Self {
56 Self {
57 embedder,
58 cache: RwLock::new(HashMap::new()),
59 max_size,
60 access_counter: RwLock::new(0),
61 stats: RwLock::new(CacheStats::default()),
62 }
63 }
64
65 fn hash_text(text: &str) -> String {
67 let hash = blake3::hash(text.as_bytes());
68 hash.to_hex().to_string()
69 }
70
71 async fn next_access(&self) -> u64 {
73 let mut counter = self.access_counter.write().await;
74 *counter += 1;
75 *counter
76 }
77
78 async fn maybe_evict(&self) {
80 let mut cache = self.cache.write().await;
81
82 if cache.len() < self.max_size {
83 return;
84 }
85
86 let evict_count = (self.max_size / 10).max(1);
88 let mut entries: Vec<_> = cache
89 .iter()
90 .map(|(k, v)| (k.clone(), v.access_count))
91 .collect();
92 entries.sort_by_key(|(_, count)| *count);
93
94 let mut stats = self.stats.write().await;
95 for (key, _) in entries.into_iter().take(evict_count) {
96 cache.remove(&key);
97 stats.evictions += 1;
98 }
99 }
100
101 pub async fn embed_text(
103 &self,
104 texts: &[&str],
105 config: &EmbeddingConfig,
106 ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
107 let mut results = Vec::with_capacity(texts.len());
108 let mut uncached_texts = Vec::new();
109 let mut uncached_indices = Vec::new();
110
111 {
113 let cache = self.cache.read().await;
114 let mut stats = self.stats.write().await;
115
116 for (i, text) in texts.iter().enumerate() {
117 let hash = Self::hash_text(text);
118 if let Some(entry) = cache.get(&hash) {
119 stats.hits += 1;
120 results.push(Some(entry.output.clone()));
121 } else {
122 stats.misses += 1;
123 uncached_texts.push(*text);
124 uncached_indices.push(i);
125 results.push(None);
126 }
127 }
128 }
129
130 if !uncached_texts.is_empty() {
132 debug!("Cache miss for {} texts, embedding", uncached_texts.len());
133
134 let new_embeddings = self.embedder.embed_text(&uncached_texts, config).await?;
135
136 self.maybe_evict().await;
138
139 let mut cache = self.cache.write().await;
140 for (text, output) in uncached_texts.iter().zip(new_embeddings.iter()) {
141 let hash = Self::hash_text(text);
142 let access = self.next_access().await;
143 cache.insert(
144 hash,
145 CacheEntry {
146 output: output.clone(),
147 access_count: access,
148 },
149 );
150 }
151
152 for (idx, output) in uncached_indices.into_iter().zip(new_embeddings) {
154 results[idx] = Some(output);
155 }
156 }
157
158 Ok(results.into_iter().flatten().collect())
160 }
161
162 pub async fn embed_query(
164 &self,
165 query: &str,
166 config: &EmbeddingConfig,
167 ) -> Result<EmbeddingOutput, EmbedError> {
168 self.embedder.embed_query(query, config).await
170 }
171
172 pub fn embedder(&self) -> Arc<dyn Embedder> {
174 Arc::clone(&self.embedder)
175 }
176
177 pub async fn stats(&self) -> CacheStats {
179 self.stats.read().await.clone()
180 }
181
182 pub async fn size(&self) -> usize {
184 self.cache.read().await.len()
185 }
186
187 pub async fn clear(&self) {
189 let mut cache = self.cache.write().await;
190 cache.clear();
191 }
192
193 pub fn dimension(&self) -> usize {
195 self.embedder.dimension()
196 }
197
198 pub fn model_name(&self) -> &str {
200 self.embedder.model_name()
201 }
202
203 pub fn modalities(&self) -> &[Modality] {
205 self.embedder.modalities()
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    const TEST_DIM: usize = 384;

    /// Deterministic test embedder. Each vector is derived from the BLAKE3 hash of its text,
    /// and every `embed_text` call is counted so tests can tell hits from misses.
    struct MockEmbedder {
        dimension: usize,
        call_count: RwLock<usize>,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            Self {
                dimension,
                call_count: RwLock::new(0),
            }
        }

        async fn get_call_count(&self) -> usize {
            *self.call_count.read().await
        }
    }

    #[async_trait]
    impl Embedder for MockEmbedder {
        fn model_name(&self) -> &'static str {
            "mock-embedder"
        }

        fn dimension(&self) -> usize {
            self.dimension
        }

        fn max_tokens(&self) -> usize {
            512
        }

        fn modalities(&self) -> &[Modality] {
            &[Modality::Text]
        }

        async fn embed_text(
            &self,
            texts: &[&str],
            _config: &EmbeddingConfig,
        ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
            *self.call_count.write().await += 1;

            Ok(texts
                .iter()
                .map(|text| {
                    let hash = blake3::hash(text.as_bytes());
                    let bytes = hash.as_bytes();
                    let embedding: Vec<f32> = (0..self.dimension)
                        .map(|i| f32::from(bytes[i % 32]) / 255.0)
                        .collect();
                    EmbeddingOutput {
                        embedding,
                        token_count: text.split_whitespace().count(),
                    }
                })
                .collect())
        }

        async fn embed_query(
            &self,
            query: &str,
            config: &EmbeddingConfig,
        ) -> Result<EmbeddingOutput, EmbedError> {
            let results = self.embed_text(&[query], config).await?;
            Ok(results
                .into_iter()
                .next()
                .expect("mock embed_text returns one output per input text"))
        }
    }

    /// Returns a fresh mock embedder and a default-capacity cache wrapping it.
    fn new_cache() -> (Arc<MockEmbedder>, EmbeddingCache) {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::new(Arc::clone(&embedder) as Arc<dyn Embedder>);
        (embedder, cache)
    }

    #[tokio::test]
    async fn test_cache_hit() {
        let (embedder, cache) = new_cache();
        let config = EmbeddingConfig::default();

        let result1 = cache.embed_text(&["hello world"], &config).await.unwrap();
        assert_eq!(embedder.get_call_count().await, 1);

        let result2 = cache.embed_text(&["hello world"], &config).await.unwrap();
        // Served from the cache: the embedder is not called again.
        assert_eq!(embedder.get_call_count().await, 1);
        assert_eq!(result1[0].embedding, result2[0].embedding);

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let (embedder, cache) = new_cache();
        let config = EmbeddingConfig::default();

        cache.embed_text(&["text one"], &config).await.unwrap();
        cache.embed_text(&["text two"], &config).await.unwrap();
        cache.embed_text(&["text three"], &config).await.unwrap();

        assert_eq!(embedder.get_call_count().await, 3);

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 3);
    }

    #[tokio::test]
    async fn test_batch_with_mixed_cache() {
        let (embedder, cache) = new_cache();
        let config = EmbeddingConfig::default();

        cache.embed_text(&["cached text"], &config).await.unwrap();
        assert_eq!(embedder.get_call_count().await, 1);

        let results = cache
            .embed_text(&["cached text", "new text", "cached text"], &config)
            .await
            .unwrap();

        // Only "new text" required a second embedder call.
        assert_eq!(embedder.get_call_count().await, 2);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].embedding, results[2].embedding);
    }

    #[tokio::test]
    async fn test_duplicate_texts_in_batch() {
        let (embedder, cache) = new_cache();
        let config = EmbeddingConfig::default();

        let results = cache.embed_text(&["same", "same"], &config).await.unwrap();

        assert_eq!(embedder.get_call_count().await, 1);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].embedding, results[1].embedding);

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 2);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let (embedder, cache) = new_cache();
        let config = EmbeddingConfig::default();

        cache.embed_text(&["test"], &config).await.unwrap();
        assert_eq!(cache.size().await, 1);

        cache.clear().await;
        assert_eq!(cache.size().await, 0);

        // After clearing, the same text must be embedded again.
        cache.embed_text(&["test"], &config).await.unwrap();
        assert_eq!(embedder.get_call_count().await, 2);
    }

    #[tokio::test]
    async fn test_cache_eviction() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let cache = EmbeddingCache::with_capacity(Arc::clone(&embedder) as Arc<dyn Embedder>, 10);
        let config = EmbeddingConfig::default();

        for i in 0..15 {
            let text = format!("text number {i}");
            cache.embed_text(&[&text], &config).await.unwrap();
        }

        // The cache must never grow beyond its configured capacity.
        assert!(cache.size().await <= 10);

        let stats = cache.stats().await;
        assert!(stats.evictions > 0);
    }

    #[tokio::test]
    async fn test_embedder_properties() {
        let (_embedder, cache) = new_cache();

        assert_eq!(cache.dimension(), TEST_DIM);
        assert_eq!(cache.model_name(), "mock-embedder");
        assert_eq!(cache.modalities(), &[Modality::Text]);
    }

    #[tokio::test]
    async fn test_cache_stats_accuracy() {
        let (_embedder, cache) = new_cache();
        let config = EmbeddingConfig::default();

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);

        cache.embed_text(&["test"], &config).await.unwrap();
        let stats = cache.stats().await;
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        cache.embed_text(&["test"], &config).await.unwrap();
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[tokio::test]
    async fn test_cache_multiple_texts_batch() {
        let (_embedder, cache) = new_cache();
        let config = EmbeddingConfig::default();

        let texts = vec!["text one", "text two", "text three"];
        let results = cache.embed_text(&texts, &config).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(cache.stats().await.misses, 3);

        cache.embed_text(&texts, &config).await.unwrap();
        assert_eq!(cache.stats().await.hits, 3);
    }

    #[tokio::test]
    async fn test_cache_empty_input() {
        let (_embedder, cache) = new_cache();
        let config = EmbeddingConfig::default();

        let texts: Vec<&str> = vec![];
        let results = cache.embed_text(&texts, &config).await.unwrap();
        assert!(results.is_empty());

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }
}