ragfs_embed/
pool.rs

1//! Embedder pool for concurrent embedding operations.
2
3use ragfs_core::{EmbedError, Embedder, EmbeddingConfig, EmbeddingOutput, Modality};
4use std::sync::Arc;
5use tokio::sync::Semaphore;
6
7/// Pool of embedders with concurrency control.
8pub struct EmbedderPool {
9    /// Primary embedder for documents
10    document_embedder: Arc<dyn Embedder>,
11    /// Semaphore to limit concurrent inference
12    semaphore: Semaphore,
13    /// Maximum concurrent operations
14    max_concurrent: usize,
15}
16
17impl EmbedderPool {
18    /// Create a new embedder pool.
19    pub fn new(embedder: Arc<dyn Embedder>, max_concurrent: usize) -> Self {
20        Self {
21            document_embedder: embedder,
22            semaphore: Semaphore::new(max_concurrent),
23            max_concurrent,
24        }
25    }
26
27    /// Get the embedding dimension.
28    pub fn dimension(&self) -> usize {
29        self.document_embedder.dimension()
30    }
31
32    /// Get the model name.
33    pub fn model_name(&self) -> &str {
34        self.document_embedder.model_name()
35    }
36
37    /// Get supported modalities.
38    pub fn modalities(&self) -> &[Modality] {
39        self.document_embedder.modalities()
40    }
41
42    /// Get the underlying embedder.
43    pub fn document_embedder(&self) -> Arc<dyn Embedder> {
44        Arc::clone(&self.document_embedder)
45    }
46
47    /// Embed a batch of texts.
48    pub async fn embed_batch(
49        &self,
50        texts: &[&str],
51        config: &EmbeddingConfig,
52    ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
53        let _permit = self
54            .semaphore
55            .acquire()
56            .await
57            .map_err(|e| EmbedError::Inference(format!("semaphore error: {e}")))?;
58
59        self.document_embedder.embed_text(texts, config).await
60    }
61
62    /// Embed a single query.
63    pub async fn embed_query(
64        &self,
65        query: &str,
66        config: &EmbeddingConfig,
67    ) -> Result<EmbeddingOutput, EmbedError> {
68        let _permit = self
69            .semaphore
70            .acquire()
71            .await
72            .map_err(|e| EmbedError::Inference(format!("semaphore error: {e}")))?;
73
74        self.document_embedder.embed_query(query, config).await
75    }
76
77    /// Get pool statistics.
78    pub fn available_permits(&self) -> usize {
79        self.semaphore.available_permits()
80    }
81
82    /// Get max concurrent operations.
83    pub fn max_concurrent(&self) -> usize {
84        self.max_concurrent
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    const TEST_DIM: usize = 384;

    /// Mock embedder returning deterministic embeddings.
    struct MockEmbedder {
        dimension: usize,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            Self { dimension }
        }
    }

    #[async_trait]
    impl Embedder for MockEmbedder {
        fn model_name(&self) -> &'static str {
            "mock-embedder"
        }

        fn dimension(&self) -> usize {
            self.dimension
        }

        fn max_tokens(&self) -> usize {
            512
        }

        fn modalities(&self) -> &[Modality] {
            &[Modality::Text]
        }

        async fn embed_text(
            &self,
            texts: &[&str],
            _config: &EmbeddingConfig,
        ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
            // Return deterministic embeddings based on text length
            Ok(texts
                .iter()
                .map(|text| {
                    let embedding: Vec<f32> = (0..self.dimension)
                        .map(|i| ((i + text.len()) as f32 * 0.001).sin())
                        .collect();
                    EmbeddingOutput {
                        embedding,
                        token_count: text.split_whitespace().count(),
                    }
                })
                .collect())
        }

        async fn embed_query(
            &self,
            query: &str,
            config: &EmbeddingConfig,
        ) -> Result<EmbeddingOutput, EmbedError> {
            let results = self.embed_text(&[query], config).await?;
            results
                .into_iter()
                .next()
                .ok_or_else(|| EmbedError::Inference("mock returned no embedding".to_string()))
        }
    }

    /// Embedder that records how many `embed_text` calls overlap, so tests can
    /// verify that the pool's semaphore actually bounds concurrency.
    struct TrackingEmbedder {
        inner: MockEmbedder,
        in_flight: AtomicUsize,
        peak: AtomicUsize,
    }

    impl TrackingEmbedder {
        fn new(dimension: usize) -> Self {
            Self {
                inner: MockEmbedder::new(dimension),
                in_flight: AtomicUsize::new(0),
                peak: AtomicUsize::new(0),
            }
        }

        /// Highest number of simultaneous `embed_text` calls observed.
        fn peak(&self) -> usize {
            self.peak.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Embedder for TrackingEmbedder {
        fn model_name(&self) -> &'static str {
            "tracking-embedder"
        }

        fn dimension(&self) -> usize {
            self.inner.dimension
        }

        fn max_tokens(&self) -> usize {
            512
        }

        fn modalities(&self) -> &[Modality] {
            &[Modality::Text]
        }

        async fn embed_text(
            &self,
            texts: &[&str],
            config: &EmbeddingConfig,
        ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
            let now = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak.fetch_max(now, Ordering::SeqCst);
            // Yield while "in flight" so other spawned tasks get scheduled and
            // would overlap with this call if the pool did not limit them.
            for _ in 0..4 {
                tokio::task::yield_now().await;
            }
            let result = self.inner.embed_text(texts, config).await;
            self.in_flight.fetch_sub(1, Ordering::SeqCst);
            result
        }

        async fn embed_query(
            &self,
            query: &str,
            config: &EmbeddingConfig,
        ) -> Result<EmbeddingOutput, EmbedError> {
            let results = self.embed_text(&[query], config).await?;
            results
                .into_iter()
                .next()
                .ok_or_else(|| EmbedError::Inference("mock returned no embedding".to_string()))
        }
    }

    /// Build a pool around a fresh `MockEmbedder`.
    fn mock_pool(max_concurrent: usize) -> EmbedderPool {
        EmbedderPool::new(Arc::new(MockEmbedder::new(TEST_DIM)), max_concurrent)
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = mock_pool(4);

        assert_eq!(pool.dimension(), TEST_DIM);
        assert_eq!(pool.model_name(), "mock-embedder");
        assert_eq!(pool.max_concurrent(), 4);
        assert_eq!(pool.available_permits(), 4);
    }

    #[tokio::test]
    async fn test_embed_batch() {
        let pool = mock_pool(4);
        let config = EmbeddingConfig::default();

        let texts = vec!["hello world", "test embedding"];
        let results = pool.embed_batch(&texts, &config).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].embedding.len(), TEST_DIM);
        assert_eq!(results[1].embedding.len(), TEST_DIM);
    }

    #[tokio::test]
    async fn test_embed_query() {
        let pool = mock_pool(4);
        let config = EmbeddingConfig::default();

        let result = pool.embed_query("search query", &config).await.unwrap();

        assert_eq!(result.embedding.len(), TEST_DIM);
        assert_eq!(result.token_count, 2); // "search" and "query"
    }

    #[tokio::test]
    async fn test_semaphore_limits_concurrency() {
        const LIMIT: usize = 2;
        const TASKS: usize = 6;

        let tracker = Arc::new(TrackingEmbedder::new(TEST_DIM));
        let pool = Arc::new(EmbedderPool::new(
            Arc::clone(&tracker) as Arc<dyn Embedder>,
            LIMIT,
        ));

        // Initially all permits available
        assert_eq!(pool.available_permits(), LIMIT);

        // Spawn more tasks than there are permits.
        let handles: Vec<_> = (0..TASKS)
            .map(|i| {
                let pool = Arc::clone(&pool);
                tokio::spawn(async move {
                    let text = format!("query {i}");
                    pool.embed_query(&text, &EmbeddingConfig::default())
                        .await
                        .expect("embedding failed")
                        .embedding
                        .len()
                })
            })
            .collect();

        // Propagate panics/errors from the tasks instead of discarding them.
        for handle in handles {
            let dim = handle.await.expect("embedding task panicked");
            assert_eq!(dim, TEST_DIM);
        }

        let peak = tracker.peak();
        assert!(peak >= 1, "embedder was never called");
        assert!(
            peak <= LIMIT,
            "observed {peak} concurrent embed calls, limit is {LIMIT}"
        );

        // All permits should be returned
        assert_eq!(pool.available_permits(), LIMIT);
    }

    #[tokio::test]
    async fn test_document_embedder_access() {
        let pool = mock_pool(4);

        let doc_embedder = pool.document_embedder();
        assert_eq!(doc_embedder.dimension(), TEST_DIM);
        assert_eq!(doc_embedder.model_name(), "mock-embedder");
    }

    #[tokio::test]
    async fn test_modalities() {
        let pool = mock_pool(4);

        let modalities = pool.modalities();
        assert_eq!(modalities.len(), 1);
        assert!(matches!(modalities[0], Modality::Text));
    }

    #[tokio::test]
    async fn test_empty_batch() {
        let pool = mock_pool(4);
        let config = EmbeddingConfig::default();

        let texts: Vec<&str> = vec![];
        let results = pool.embed_batch(&texts, &config).await.unwrap();

        assert!(results.is_empty());
    }
}