1use ragfs_core::{EmbedError, Embedder, EmbeddingConfig, EmbeddingOutput, Modality};
4use std::sync::Arc;
5use tokio::sync::Semaphore;
6
7pub struct EmbedderPool {
9 document_embedder: Arc<dyn Embedder>,
11 semaphore: Semaphore,
13 max_concurrent: usize,
15}
16
17impl EmbedderPool {
18 pub fn new(embedder: Arc<dyn Embedder>, max_concurrent: usize) -> Self {
20 Self {
21 document_embedder: embedder,
22 semaphore: Semaphore::new(max_concurrent),
23 max_concurrent,
24 }
25 }
26
27 pub fn dimension(&self) -> usize {
29 self.document_embedder.dimension()
30 }
31
32 pub fn model_name(&self) -> &str {
34 self.document_embedder.model_name()
35 }
36
37 pub fn modalities(&self) -> &[Modality] {
39 self.document_embedder.modalities()
40 }
41
42 pub fn document_embedder(&self) -> Arc<dyn Embedder> {
44 Arc::clone(&self.document_embedder)
45 }
46
47 pub async fn embed_batch(
49 &self,
50 texts: &[&str],
51 config: &EmbeddingConfig,
52 ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
53 let _permit = self
54 .semaphore
55 .acquire()
56 .await
57 .map_err(|e| EmbedError::Inference(format!("semaphore error: {e}")))?;
58
59 self.document_embedder.embed_text(texts, config).await
60 }
61
62 pub async fn embed_query(
64 &self,
65 query: &str,
66 config: &EmbeddingConfig,
67 ) -> Result<EmbeddingOutput, EmbedError> {
68 let _permit = self
69 .semaphore
70 .acquire()
71 .await
72 .map_err(|e| EmbedError::Inference(format!("semaphore error: {e}")))?;
73
74 self.document_embedder.embed_query(query, config).await
75 }
76
77 pub fn available_permits(&self) -> usize {
79 self.semaphore.available_permits()
80 }
81
82 pub fn max_concurrent(&self) -> usize {
84 self.max_concurrent
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Embedding width used by every test.
    const TEST_DIM: usize = 384;

    /// Deterministic in-memory embedder.
    ///
    /// Besides producing fake vectors it records how many `embed_text` calls
    /// overlap, so tests can check the concurrency bound enforced by the pool.
    struct MockEmbedder {
        dimension: usize,
        /// Number of `embed_text` calls currently between entry and exit.
        in_flight: AtomicUsize,
        /// Highest value `in_flight` has reached so far.
        peak_in_flight: AtomicUsize,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            Self {
                dimension,
                in_flight: AtomicUsize::new(0),
                peak_in_flight: AtomicUsize::new(0),
            }
        }

        /// Largest number of overlapping `embed_text` calls observed.
        fn peak_in_flight(&self) -> usize {
            self.peak_in_flight.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Embedder for MockEmbedder {
        fn model_name(&self) -> &'static str {
            "mock-embedder"
        }

        fn dimension(&self) -> usize {
            self.dimension
        }

        fn max_tokens(&self) -> usize {
            512
        }

        fn modalities(&self) -> &[Modality] {
            &[Modality::Text]
        }

        async fn embed_text(
            &self,
            texts: &[&str],
            _config: &EmbeddingConfig,
        ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
            let running = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak_in_flight.fetch_max(running, Ordering::SeqCst);

            // Give other tasks a chance to enter so overlapping calls are observable.
            tokio::task::yield_now().await;

            let outputs = texts
                .iter()
                .map(|text| {
                    let embedding: Vec<f32> = (0..self.dimension)
                        .map(|i| ((i + text.len()) as f32 * 0.001).sin())
                        .collect();
                    EmbeddingOutput {
                        embedding,
                        token_count: text.split_whitespace().count(),
                    }
                })
                .collect();

            self.in_flight.fetch_sub(1, Ordering::SeqCst);
            Ok(outputs)
        }

        async fn embed_query(
            &self,
            query: &str,
            config: &EmbeddingConfig,
        ) -> Result<EmbeddingOutput, EmbedError> {
            let results = self.embed_text(&[query], config).await?;
            Ok(results.into_iter().next().unwrap())
        }
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let pool = EmbedderPool::new(embedder, 4);

        assert_eq!(pool.dimension(), TEST_DIM);
        assert_eq!(pool.model_name(), "mock-embedder");
        assert_eq!(pool.max_concurrent(), 4);
        assert_eq!(pool.available_permits(), 4);
    }

    #[tokio::test]
    async fn test_embed_batch() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let pool = EmbedderPool::new(embedder, 4);
        let config = EmbeddingConfig::default();

        let texts = vec!["hello world", "test embedding"];
        let results = pool.embed_batch(&texts, &config).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].embedding.len(), TEST_DIM);
        assert_eq!(results[1].embedding.len(), TEST_DIM);
    }

    #[tokio::test]
    async fn test_embed_query() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let pool = EmbedderPool::new(embedder, 4);
        let config = EmbeddingConfig::default();

        let result = pool.embed_query("search query", &config).await.unwrap();

        assert_eq!(result.embedding.len(), TEST_DIM);
        assert_eq!(result.token_count, 2);
    }

    #[tokio::test]
    async fn test_semaphore_limits_concurrency() {
        const LIMIT: usize = 2;
        const TASKS: usize = 6;

        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        // Keep a concrete handle for inspecting the probe after the pool has
        // taken its own (type-erased) handle.
        let pool_embedder: Arc<MockEmbedder> = Arc::clone(&embedder);
        let pool = Arc::new(EmbedderPool::new(pool_embedder, LIMIT));

        assert_eq!(pool.available_permits(), LIMIT);

        // Spawn more tasks than there are permits so the limit is actually contended.
        let handles: Vec<_> = (0..TASKS)
            .map(|i| {
                let pool = Arc::clone(&pool);
                tokio::spawn(async move {
                    let query = format!("query{i}");
                    pool.embed_query(&query, &EmbeddingConfig::default())
                        .await
                        .expect("embed_query failed");
                })
            })
            .collect();

        // Surface panics from the tasks instead of silently discarding them.
        for handle in handles {
            handle.await.expect("embedding task panicked");
        }

        let peak = embedder.peak_in_flight();
        assert!(peak >= 1, "mock embedder was never invoked");
        assert!(
            peak <= LIMIT,
            "observed {peak} concurrent calls, limit is {LIMIT}"
        );

        // Every permit must be returned once all tasks have finished.
        assert_eq!(pool.available_permits(), LIMIT);
    }

    #[tokio::test]
    async fn test_document_embedder_access() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let pool = EmbedderPool::new(embedder, 4);

        let doc_embedder = pool.document_embedder();
        assert_eq!(doc_embedder.dimension(), TEST_DIM);
        assert_eq!(doc_embedder.model_name(), "mock-embedder");
    }

    #[tokio::test]
    async fn test_modalities() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let pool = EmbedderPool::new(embedder, 4);

        let modalities = pool.modalities();
        assert_eq!(modalities.len(), 1);
        assert!(matches!(modalities[0], Modality::Text));
    }

    #[tokio::test]
    async fn test_empty_batch() {
        let embedder = Arc::new(MockEmbedder::new(TEST_DIM));
        let pool = EmbedderPool::new(embedder, 4);
        let config = EmbeddingConfig::default();

        let texts: Vec<&str> = vec![];
        let results = pool.embed_batch(&texts, &config).await.unwrap();

        assert!(results.is_empty());
    }
}