// ragfs_embed/candle.rs

1//! GTE-small embedder using Candle.
2//!
3//! Uses thenlper/gte-small model for text embeddings:
4//! - 384 dimensions
5//! - 512 max tokens
6//! - BERT architecture
7
8use async_trait::async_trait;
9use candle_core::{DType, Device, Tensor};
10use candle_nn::VarBuilder;
11use candle_transformers::models::bert::{BertModel, Config};
12use hf_hub::{Repo, RepoType, api::tokio::Api};
13use ragfs_core::{EmbedError, Embedder, EmbeddingConfig, EmbeddingOutput, Modality};
14use std::path::PathBuf;
15use std::sync::Arc;
16use tokenizers::Tokenizer;
17use tokio::sync::RwLock;
18use tracing::{debug, info};
19
/// Model identifier on `HuggingFace` Hub.
const MODEL_ID: &str = "thenlper/gte-small";

/// Embedding dimension produced by gte-small.
const EMBEDDING_DIM: usize = 384;

/// Maximum sequence length in tokens; longer inputs are truncated at encode time.
const MAX_TOKENS: usize = 512;
28
/// GTE-small embedder using Candle.
///
/// Model state is loaded lazily by `init`; until then the `Option` fields
/// are `None` and inference calls will trigger initialization.
pub struct CandleEmbedder {
    /// Device to run inference on (CPU or CUDA)
    device: Device,
    /// Loaded model (`None` until `init` completes)
    model: Arc<RwLock<Option<BertModel>>>,
    /// Tokenizer (`None` until `init` completes)
    tokenizer: Arc<RwLock<Option<Tokenizer>>>,
    /// Model configuration (`None` until `init` completes)
    config: Arc<RwLock<Option<Config>>>,
    /// Cache directory for models
    // NOTE(review): currently unused — downloads go through hf-hub's own
    // default cache; confirm whether this should be wired into the API.
    #[allow(dead_code)]
    cache_dir: PathBuf,
    /// Whether model is initialized
    initialized: Arc<RwLock<bool>>,
}
45
46impl CandleEmbedder {
47    /// Create a new `CandleEmbedder`.
48    pub fn new(cache_dir: PathBuf) -> Self {
49        // Try to use CUDA if available, fallback to CPU
50        let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
51        info!("CandleEmbedder using device: {:?}", device);
52
53        Self {
54            device,
55            model: Arc::new(RwLock::new(None)),
56            tokenizer: Arc::new(RwLock::new(None)),
57            config: Arc::new(RwLock::new(None)),
58            cache_dir,
59            initialized: Arc::new(RwLock::new(false)),
60        }
61    }
62
63    /// Create with specific device.
64    pub fn with_device(cache_dir: PathBuf, device: Device) -> Self {
65        Self {
66            device,
67            model: Arc::new(RwLock::new(None)),
68            tokenizer: Arc::new(RwLock::new(None)),
69            config: Arc::new(RwLock::new(None)),
70            cache_dir,
71            initialized: Arc::new(RwLock::new(false)),
72        }
73    }
74
75    /// Initialize the model (download if needed, load into memory).
76    pub async fn init(&self) -> Result<(), EmbedError> {
77        {
78            let initialized = self.initialized.read().await;
79            if *initialized {
80                return Ok(());
81            }
82        }
83
84        info!("Initializing CandleEmbedder with model: {}", MODEL_ID);
85
86        // Download model files from HuggingFace Hub
87        let api = Api::new()
88            .map_err(|e| EmbedError::ModelLoad(format!("Failed to create HF API: {e}")))?;
89
90        let repo = api.repo(Repo::new(MODEL_ID.to_string(), RepoType::Model));
91
92        // Download tokenizer
93        debug!("Downloading tokenizer...");
94        let tokenizer_path = repo
95            .get("tokenizer.json")
96            .await
97            .map_err(|e| EmbedError::ModelLoad(format!("Failed to download tokenizer: {e}")))?;
98
99        // Download model config
100        debug!("Downloading config...");
101        let config_path = repo
102            .get("config.json")
103            .await
104            .map_err(|e| EmbedError::ModelLoad(format!("Failed to download config: {e}")))?;
105
106        // Download model weights
107        debug!("Downloading model weights...");
108        let weights_path = repo
109            .get("model.safetensors")
110            .await
111            .map_err(|e| EmbedError::ModelLoad(format!("Failed to download weights: {e}")))?;
112
113        // Load tokenizer
114        debug!("Loading tokenizer...");
115        let tokenizer = Tokenizer::from_file(&tokenizer_path)
116            .map_err(|e| EmbedError::ModelLoad(format!("Failed to load tokenizer: {e}")))?;
117
118        // Load config
119        debug!("Loading config...");
120        let config_str = std::fs::read_to_string(&config_path)
121            .map_err(|e| EmbedError::ModelLoad(format!("Failed to read config: {e}")))?;
122        let config: Config = serde_json::from_str(&config_str)
123            .map_err(|e| EmbedError::ModelLoad(format!("Failed to parse config: {e}")))?;
124
125        // Load model weights
126        debug!("Loading model weights...");
127        // SAFETY: The safetensors file is downloaded from HuggingFace Hub and is trusted.
128        // Memory mapping is safe for read-only access to model weights.
129        #[allow(unsafe_code)]
130        let vb = unsafe {
131            VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &self.device)
132                .map_err(|e| EmbedError::ModelLoad(format!("Failed to load weights: {e}")))?
133        };
134
135        let model = BertModel::load(vb, &config)
136            .map_err(|e| EmbedError::ModelLoad(format!("Failed to create BERT model: {e}")))?;
137
138        // Store in instance
139        {
140            let mut tok = self.tokenizer.write().await;
141            *tok = Some(tokenizer);
142        }
143        {
144            let mut cfg = self.config.write().await;
145            *cfg = Some(config);
146        }
147        {
148            let mut mdl = self.model.write().await;
149            *mdl = Some(model);
150        }
151        {
152            let mut init = self.initialized.write().await;
153            *init = true;
154        }
155
156        info!("CandleEmbedder initialized successfully");
157        Ok(())
158    }
159
160    /// Mean pooling with attention mask.
161    fn mean_pooling(
162        &self,
163        token_embeddings: &Tensor,
164        attention_mask: &Tensor,
165    ) -> Result<Tensor, EmbedError> {
166        // Expand attention mask to match embedding dimensions
167        let mask = attention_mask
168            .unsqueeze(2)
169            .map_err(|e| EmbedError::Inference(format!("unsqueeze failed: {e}")))?
170            .broadcast_as(token_embeddings.shape())
171            .map_err(|e| EmbedError::Inference(format!("broadcast failed: {e}")))?
172            .to_dtype(DType::F32)
173            .map_err(|e| EmbedError::Inference(format!("dtype conversion failed: {e}")))?;
174
175        // Masked sum
176        let masked = token_embeddings
177            .mul(&mask)
178            .map_err(|e| EmbedError::Inference(format!("mul failed: {e}")))?;
179
180        let sum = masked
181            .sum(1)
182            .map_err(|e| EmbedError::Inference(format!("sum failed: {e}")))?;
183
184        // Count non-masked tokens
185        let mask_sum = mask
186            .sum(1)
187            .map_err(|e| EmbedError::Inference(format!("mask sum failed: {e}")))?
188            .clamp(1e-9, f64::MAX)
189            .map_err(|e| EmbedError::Inference(format!("clamp failed: {e}")))?;
190
191        // Mean
192        let mean = sum
193            .div(&mask_sum)
194            .map_err(|e| EmbedError::Inference(format!("div failed: {e}")))?;
195
196        Ok(mean)
197    }
198
199    /// L2 normalize embeddings.
200    fn normalize(&self, embeddings: &Tensor) -> Result<Tensor, EmbedError> {
201        let norm = embeddings
202            .sqr()
203            .map_err(|e| EmbedError::Inference(format!("sqr failed: {e}")))?
204            .sum_keepdim(1)
205            .map_err(|e| EmbedError::Inference(format!("sum_keepdim failed: {e}")))?
206            .sqrt()
207            .map_err(|e| EmbedError::Inference(format!("sqrt failed: {e}")))?
208            .clamp(1e-12, f64::MAX)
209            .map_err(|e| EmbedError::Inference(format!("clamp failed: {e}")))?;
210
211        let normalized = embeddings
212            .broadcast_div(&norm)
213            .map_err(|e| EmbedError::Inference(format!("div failed: {e}")))?;
214
215        Ok(normalized)
216    }
217
218    /// Encode a batch of texts.
219    async fn encode_batch(
220        &self,
221        texts: &[&str],
222        normalize: bool,
223    ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
224        // Ensure initialized
225        self.init().await?;
226
227        let tokenizer = self.tokenizer.read().await;
228        let tokenizer = tokenizer
229            .as_ref()
230            .ok_or_else(|| EmbedError::Inference("Tokenizer not loaded".to_string()))?;
231
232        let model = self.model.read().await;
233        let model = model
234            .as_ref()
235            .ok_or_else(|| EmbedError::Inference("Model not loaded".to_string()))?;
236
237        // Tokenize all texts
238        let encodings = tokenizer
239            .encode_batch(texts.to_vec(), true)
240            .map_err(|e| EmbedError::Inference(format!("Tokenization failed: {e}")))?;
241
242        // Find max length for padding
243        let max_len = encodings
244            .iter()
245            .map(tokenizers::Encoding::len)
246            .max()
247            .unwrap_or(0);
248        let max_len = max_len.min(MAX_TOKENS);
249
250        // Prepare input tensors
251        let mut input_ids_vec: Vec<u32> = Vec::new();
252        let mut attention_mask_vec: Vec<u32> = Vec::new();
253        let mut token_type_ids_vec: Vec<u32> = Vec::new();
254        let mut token_counts = Vec::new();
255
256        for encoding in &encodings {
257            let ids = encoding.get_ids();
258            let len = ids.len().min(max_len);
259            token_counts.push(len);
260
261            // Add IDs with padding
262            for i in 0..max_len {
263                if i < len {
264                    input_ids_vec.push(ids[i]);
265                    attention_mask_vec.push(1);
266                    token_type_ids_vec.push(0);
267                } else {
268                    input_ids_vec.push(0); // PAD token
269                    attention_mask_vec.push(0);
270                    token_type_ids_vec.push(0);
271                }
272            }
273        }
274
275        let batch_size = texts.len();
276
277        // Create tensors
278        let input_ids = Tensor::from_vec(input_ids_vec, (batch_size, max_len), &self.device)
279            .map_err(|e| {
280                EmbedError::Inference(format!("Failed to create input_ids tensor: {e}"))
281            })?;
282
283        let attention_mask =
284            Tensor::from_vec(attention_mask_vec, (batch_size, max_len), &self.device).map_err(
285                |e| EmbedError::Inference(format!("Failed to create attention_mask tensor: {e}")),
286            )?;
287
288        let token_type_ids =
289            Tensor::from_vec(token_type_ids_vec, (batch_size, max_len), &self.device).map_err(
290                |e| EmbedError::Inference(format!("Failed to create token_type_ids tensor: {e}")),
291            )?;
292
293        // Run model
294        let output = model
295            .forward(&input_ids, &token_type_ids, Some(&attention_mask))
296            .map_err(|e| EmbedError::Inference(format!("Model forward failed: {e}")))?;
297
298        // Mean pooling
299        let pooled = self.mean_pooling(&output, &attention_mask)?;
300
301        // Normalize if requested
302        let final_embeddings = if normalize {
303            self.normalize(&pooled)?
304        } else {
305            pooled
306        };
307
308        // Convert to Vec<EmbeddingOutput>
309        let mut results = Vec::with_capacity(batch_size);
310
311        for i in 0..batch_size {
312            let embedding = final_embeddings
313                .get(i)
314                .map_err(|e| EmbedError::Inference(format!("Failed to get embedding {i}: {e}")))?
315                .to_vec1::<f32>()
316                .map_err(|e| EmbedError::Inference(format!("Failed to convert to vec: {e}")))?;
317
318            results.push(EmbeddingOutput {
319                embedding,
320                token_count: token_counts[i],
321            });
322        }
323
324        Ok(results)
325    }
326}
327
328#[async_trait]
329impl Embedder for CandleEmbedder {
330    fn model_name(&self) -> &str {
331        MODEL_ID
332    }
333
334    fn dimension(&self) -> usize {
335        EMBEDDING_DIM
336    }
337
338    fn max_tokens(&self) -> usize {
339        MAX_TOKENS
340    }
341
342    fn modalities(&self) -> &[Modality] {
343        &[Modality::Text]
344    }
345
346    async fn embed_text(
347        &self,
348        texts: &[&str],
349        config: &EmbeddingConfig,
350    ) -> Result<Vec<EmbeddingOutput>, EmbedError> {
351        if texts.is_empty() {
352            return Ok(Vec::new());
353        }
354
355        debug!(
356            "Embedding {} texts with batch_size {}",
357            texts.len(),
358            config.batch_size
359        );
360
361        // Process in batches
362        let mut all_results = Vec::with_capacity(texts.len());
363
364        for chunk in texts.chunks(config.batch_size) {
365            let batch_results = self.encode_batch(chunk, config.normalize).await?;
366            all_results.extend(batch_results);
367        }
368
369        Ok(all_results)
370    }
371
372    async fn embed_query(
373        &self,
374        query: &str,
375        config: &EmbeddingConfig,
376    ) -> Result<EmbeddingOutput, EmbedError> {
377        // For GTE models, queries and documents use the same embedding process
378        // Some models use different prefixes, but GTE doesn't need that
379        let results = self.embed_text(&[query], config).await?;
380        results
381            .into_iter()
382            .next()
383            .ok_or_else(|| EmbedError::Inference("Empty embedding result".to_string()))
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[tokio::test]
    #[ignore] // Requires model download
    async fn test_candle_embedder() {
        let dir = tempdir().unwrap();
        let embedder = CandleEmbedder::new(dir.path().to_path_buf());

        embedder.init().await.unwrap();

        assert_eq!(embedder.model_name(), "thenlper/gte-small");
        assert_eq!(embedder.dimension(), 384);

        let cfg = EmbeddingConfig::default();
        let inputs = &["Hello world", "This is a test"];

        let outputs = embedder.embed_text(inputs, &cfg).await.unwrap();
        assert_eq!(outputs.len(), 2);
        for out in &outputs {
            assert_eq!(out.embedding.len(), 384);
        }

        // Check normalization (vectors should have unit length)
        let norm = outputs[0]
            .embedding
            .iter()
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }
}