1use async_trait::async_trait;
9use candle_core::{DType, Device, Tensor};
10use candle_nn::VarBuilder;
11use candle_transformers::models::bert::{BertModel, Config};
12use hf_hub::{Repo, RepoType, api::tokio::Api};
13use ragfs_core::{EmbedError, Embedder, EmbeddingConfig, EmbeddingOutput, Modality};
14use std::path::PathBuf;
15use std::sync::Arc;
16use tokenizers::Tokenizer;
17use tokio::sync::RwLock;
18use tracing::{debug, info};
19
/// Hugging Face Hub identifier of the embedding model downloaded in `init`.
const MODEL_ID: &str = "thenlper/gte-small";

/// Dimensionality of the embedding vectors this model produces; this is the
/// value reported by `Embedder::dimension`.
const EMBEDDING_DIM: usize = 384;

/// Maximum sequence length; inputs are truncated to this many tokens in
/// `encode_batch`, and `Embedder::max_tokens` reports it.
const MAX_TOKENS: usize = 512;
28
/// Local text embedder backed by a BERT model executed with Candle.
///
/// Model artifacts are downloaded from the Hugging Face Hub and loaded
/// lazily on the first call to `init` (invoked automatically by
/// `encode_batch`), so construction is cheap and infallible.
pub struct CandleEmbedder {
    /// Compute device chosen at construction (CUDA if available, else CPU).
    device: Device,
    /// Lazily-loaded BERT model; `None` until `init` completes.
    model: Arc<RwLock<Option<BertModel>>>,
    /// Lazily-loaded tokenizer; `None` until `init` completes.
    tokenizer: Arc<RwLock<Option<Tokenizer>>>,
    /// Parsed model configuration; `None` until `init` completes.
    config: Arc<RwLock<Option<Config>>>,
    /// Cache directory supplied by the caller; currently unused (hf_hub
    /// manages its own cache) — kept for future use, hence the allow.
    #[allow(dead_code)]
    cache_dir: PathBuf,
    /// Flipped to `true` once model, tokenizer, and config are all loaded.
    initialized: Arc<RwLock<bool>>,
}
45
46impl CandleEmbedder {
47 pub fn new(cache_dir: PathBuf) -> Self {
49 let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
51 info!("CandleEmbedder using device: {:?}", device);
52
53 Self {
54 device,
55 model: Arc::new(RwLock::new(None)),
56 tokenizer: Arc::new(RwLock::new(None)),
57 config: Arc::new(RwLock::new(None)),
58 cache_dir,
59 initialized: Arc::new(RwLock::new(false)),
60 }
61 }
62
63 pub fn with_device(cache_dir: PathBuf, device: Device) -> Self {
65 Self {
66 device,
67 model: Arc::new(RwLock::new(None)),
68 tokenizer: Arc::new(RwLock::new(None)),
69 config: Arc::new(RwLock::new(None)),
70 cache_dir,
71 initialized: Arc::new(RwLock::new(false)),
72 }
73 }
74
75 pub async fn init(&self) -> Result<(), EmbedError> {
77 {
78 let initialized = self.initialized.read().await;
79 if *initialized {
80 return Ok(());
81 }
82 }
83
84 info!("Initializing CandleEmbedder with model: {}", MODEL_ID);
85
86 let api = Api::new()
88 .map_err(|e| EmbedError::ModelLoad(format!("Failed to create HF API: {e}")))?;
89
90 let repo = api.repo(Repo::new(MODEL_ID.to_string(), RepoType::Model));
91
92 debug!("Downloading tokenizer...");
94 let tokenizer_path = repo
95 .get("tokenizer.json")
96 .await
97 .map_err(|e| EmbedError::ModelLoad(format!("Failed to download tokenizer: {e}")))?;
98
99 debug!("Downloading config...");
101 let config_path = repo
102 .get("config.json")
103 .await
104 .map_err(|e| EmbedError::ModelLoad(format!("Failed to download config: {e}")))?;
105
106 debug!("Downloading model weights...");
108 let weights_path = repo
109 .get("model.safetensors")
110 .await
111 .map_err(|e| EmbedError::ModelLoad(format!("Failed to download weights: {e}")))?;
112
113 debug!("Loading tokenizer...");
115 let tokenizer = Tokenizer::from_file(&tokenizer_path)
116 .map_err(|e| EmbedError::ModelLoad(format!("Failed to load tokenizer: {e}")))?;
117
118 debug!("Loading config...");
120 let config_str = std::fs::read_to_string(&config_path)
121 .map_err(|e| EmbedError::ModelLoad(format!("Failed to read config: {e}")))?;
122 let config: Config = serde_json::from_str(&config_str)
123 .map_err(|e| EmbedError::ModelLoad(format!("Failed to parse config: {e}")))?;
124
125 debug!("Loading model weights...");
127 #[allow(unsafe_code)]
130 let vb = unsafe {
131 VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &self.device)
132 .map_err(|e| EmbedError::ModelLoad(format!("Failed to load weights: {e}")))?
133 };
134
135 let model = BertModel::load(vb, &config)
136 .map_err(|e| EmbedError::ModelLoad(format!("Failed to create BERT model: {e}")))?;
137
138 {
140 let mut tok = self.tokenizer.write().await;
141 *tok = Some(tokenizer);
142 }
143 {
144 let mut cfg = self.config.write().await;
145 *cfg = Some(config);
146 }
147 {
148 let mut mdl = self.model.write().await;
149 *mdl = Some(model);
150 }
151 {
152 let mut init = self.initialized.write().await;
153 *init = true;
154 }
155
156 info!("CandleEmbedder initialized successfully");
157 Ok(())
158 }
159
160 fn mean_pooling(
162 &self,
163 token_embeddings: &Tensor,
164 attention_mask: &Tensor,
165 ) -> Result<Tensor, EmbedError> {
166 let mask = attention_mask
168 .unsqueeze(2)
169 .map_err(|e| EmbedError::Inference(format!("unsqueeze failed: {e}")))?
170 .broadcast_as(token_embeddings.shape())
171 .map_err(|e| EmbedError::Inference(format!("broadcast failed: {e}")))?
172 .to_dtype(DType::F32)
173 .map_err(|e| EmbedError::Inference(format!("dtype conversion failed: {e}")))?;
174
175 let masked = token_embeddings
177 .mul(&mask)
178 .map_err(|e| EmbedError::Inference(format!("mul failed: {e}")))?;
179
180 let sum = masked
181 .sum(1)
182 .map_err(|e| EmbedError::Inference(format!("sum failed: {e}")))?;
183
184 let mask_sum = mask
186 .sum(1)
187 .map_err(|e| EmbedError::Inference(format!("mask sum failed: {e}")))?
188 .clamp(1e-9, f64::MAX)
189 .map_err(|e| EmbedError::Inference(format!("clamp failed: {e}")))?;
190
191 let mean = sum
193 .div(&mask_sum)
194 .map_err(|e| EmbedError::Inference(format!("div failed: {e}")))?;
195
196 Ok(mean)
197 }
198
199 fn normalize(&self, embeddings: &Tensor) -> Result<Tensor, EmbedError> {
201 let norm = embeddings
202 .sqr()
203 .map_err(|e| EmbedError::Inference(format!("sqr failed: {e}")))?
204 .sum_keepdim(1)
205 .map_err(|e| EmbedError::Inference(format!("sum_keepdim failed: {e}")))?
206 .sqrt()
207 .map_err(|e| EmbedError::Inference(format!("sqrt failed: {e}")))?
208 .clamp(1e-12, f64::MAX)
209 .map_err(|e| EmbedError::Inference(format!("clamp failed: {e}")))?;
210
211 let normalized = embeddings
212 .broadcast_div(&norm)
213 .map_err(|e| EmbedError::Inference(format!("div failed: {e}")))?;
214
215 Ok(normalized)
216 }
217
218 async fn encode_batch(
220 &self,
221 texts: &[&str],
222 normalize: bool,
223 ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
224 self.init().await?;
226
227 let tokenizer = self.tokenizer.read().await;
228 let tokenizer = tokenizer
229 .as_ref()
230 .ok_or_else(|| EmbedError::Inference("Tokenizer not loaded".to_string()))?;
231
232 let model = self.model.read().await;
233 let model = model
234 .as_ref()
235 .ok_or_else(|| EmbedError::Inference("Model not loaded".to_string()))?;
236
237 let encodings = tokenizer
239 .encode_batch(texts.to_vec(), true)
240 .map_err(|e| EmbedError::Inference(format!("Tokenization failed: {e}")))?;
241
242 let max_len = encodings
244 .iter()
245 .map(tokenizers::Encoding::len)
246 .max()
247 .unwrap_or(0);
248 let max_len = max_len.min(MAX_TOKENS);
249
250 let mut input_ids_vec: Vec<u32> = Vec::new();
252 let mut attention_mask_vec: Vec<u32> = Vec::new();
253 let mut token_type_ids_vec: Vec<u32> = Vec::new();
254 let mut token_counts = Vec::new();
255
256 for encoding in &encodings {
257 let ids = encoding.get_ids();
258 let len = ids.len().min(max_len);
259 token_counts.push(len);
260
261 for i in 0..max_len {
263 if i < len {
264 input_ids_vec.push(ids[i]);
265 attention_mask_vec.push(1);
266 token_type_ids_vec.push(0);
267 } else {
268 input_ids_vec.push(0); attention_mask_vec.push(0);
270 token_type_ids_vec.push(0);
271 }
272 }
273 }
274
275 let batch_size = texts.len();
276
277 let input_ids = Tensor::from_vec(input_ids_vec, (batch_size, max_len), &self.device)
279 .map_err(|e| {
280 EmbedError::Inference(format!("Failed to create input_ids tensor: {e}"))
281 })?;
282
283 let attention_mask =
284 Tensor::from_vec(attention_mask_vec, (batch_size, max_len), &self.device).map_err(
285 |e| EmbedError::Inference(format!("Failed to create attention_mask tensor: {e}")),
286 )?;
287
288 let token_type_ids =
289 Tensor::from_vec(token_type_ids_vec, (batch_size, max_len), &self.device).map_err(
290 |e| EmbedError::Inference(format!("Failed to create token_type_ids tensor: {e}")),
291 )?;
292
293 let output = model
295 .forward(&input_ids, &token_type_ids, Some(&attention_mask))
296 .map_err(|e| EmbedError::Inference(format!("Model forward failed: {e}")))?;
297
298 let pooled = self.mean_pooling(&output, &attention_mask)?;
300
301 let final_embeddings = if normalize {
303 self.normalize(&pooled)?
304 } else {
305 pooled
306 };
307
308 let mut results = Vec::with_capacity(batch_size);
310
311 for i in 0..batch_size {
312 let embedding = final_embeddings
313 .get(i)
314 .map_err(|e| EmbedError::Inference(format!("Failed to get embedding {i}: {e}")))?
315 .to_vec1::<f32>()
316 .map_err(|e| EmbedError::Inference(format!("Failed to convert to vec: {e}")))?;
317
318 results.push(EmbeddingOutput {
319 embedding,
320 token_count: token_counts[i],
321 });
322 }
323
324 Ok(results)
325 }
326}
327
328#[async_trait]
329impl Embedder for CandleEmbedder {
330 fn model_name(&self) -> &str {
331 MODEL_ID
332 }
333
334 fn dimension(&self) -> usize {
335 EMBEDDING_DIM
336 }
337
338 fn max_tokens(&self) -> usize {
339 MAX_TOKENS
340 }
341
342 fn modalities(&self) -> &[Modality] {
343 &[Modality::Text]
344 }
345
346 async fn embed_text(
347 &self,
348 texts: &[&str],
349 config: &EmbeddingConfig,
350 ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
351 if texts.is_empty() {
352 return Ok(Vec::new());
353 }
354
355 debug!(
356 "Embedding {} texts with batch_size {}",
357 texts.len(),
358 config.batch_size
359 );
360
361 let mut all_results = Vec::with_capacity(texts.len());
363
364 for chunk in texts.chunks(config.batch_size) {
365 let batch_results = self.encode_batch(chunk, config.normalize).await?;
366 all_results.extend(batch_results);
367 }
368
369 Ok(all_results)
370 }
371
372 async fn embed_query(
373 &self,
374 query: &str,
375 config: &EmbeddingConfig,
376 ) -> Result<EmbeddingOutput, EmbedError> {
377 let results = self.embed_text(&[query], config).await?;
380 results
381 .into_iter()
382 .next()
383 .ok_or_else(|| EmbedError::Inference("Empty embedding result".to_string()))
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// End-to-end smoke test. Downloads the real model from the Hub, so it
    /// is `#[ignore]`d by default; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_candle_embedder() {
        let cache_dir = tempdir().unwrap();
        let embedder = CandleEmbedder::new(cache_dir.path().to_path_buf());

        embedder.init().await.unwrap();

        assert_eq!(embedder.dimension(), 384);
        assert_eq!(embedder.model_name(), "thenlper/gte-small");

        let config = EmbeddingConfig::default();
        let texts = &["Hello world", "This is a test"];
        let results = embedder.embed_text(texts, &config).await.unwrap();

        assert_eq!(results.len(), 2);
        for result in &results {
            assert_eq!(result.embedding.len(), 384);
        }

        // The test expects normalized output: the first vector should have
        // unit L2 norm.
        let norm = results[0]
            .embedding
            .iter()
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }
}