ragfs_extract/
vision.rs

1//! Vision model captioning for images.
2//!
3//! This module provides infrastructure for generating captions from images
4//! using vision models. With the `vision` feature enabled, a BLIP-based
5//! captioner is available. Otherwise, only a placeholder implementation exists.
6
7use async_trait::async_trait;
8use std::path::PathBuf;
9use std::sync::Arc;
10use thiserror::Error;
11use tokio::sync::RwLock;
12use tracing::debug;
13
14#[cfg(feature = "vision")]
15use candle_core::{DType, Device, IndexOp, Module, Tensor};
16#[cfg(feature = "vision")]
17use candle_nn::VarBuilder;
18#[cfg(feature = "vision")]
19use candle_transformers::models::blip;
20#[cfg(feature = "vision")]
21use hf_hub::{Repo, RepoType, api::tokio::Api};
22#[cfg(feature = "vision")]
23use tokenizers::Tokenizer;
24#[cfg(feature = "vision")]
25use tracing::info;
26
/// Error type for vision captioning operations.
///
/// Covers every stage of the pipeline: model download/load, image
/// decoding/preprocessing, and autoregressive caption generation.
#[derive(Debug, Error)]
pub enum CaptionError {
    /// Model loading failed (download, tokenizer, config, or weights).
    #[error("model loading failed: {0}")]
    ModelLoad(String),

    /// Image preprocessing failed (decode, resize, or tensor conversion).
    #[error("image preprocessing failed: {0}")]
    ImagePreprocess(String),

    /// Caption generation failed during inference or token decoding.
    #[error("caption generation failed: {0}")]
    Generation(String),

    /// A captioning method was called before `init()` completed successfully.
    #[error("model not initialized")]
    NotInitialized,

    /// Underlying IO error, converted automatically via `?`.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}
50
/// Trait for vision-based image captioning.
///
/// Implementations must be shareable across tasks (`Send + Sync`). Call
/// [`ImageCaptioner::init`] and await success before expecting captions.
#[async_trait]
pub trait ImageCaptioner: Send + Sync {
    /// Initialize the captioner (load model, etc.).
    ///
    /// # Errors
    /// Returns a [`CaptionError`] if model loading fails.
    async fn init(&self) -> Result<(), CaptionError>;

    /// Generate a caption for raw image bytes.
    ///
    /// `Ok(None)` means "no caption available" (as opposed to a failure).
    ///
    /// # Errors
    /// Returns a [`CaptionError`] on preprocessing or generation failure.
    async fn caption(&self, image_data: &[u8]) -> Result<Option<String>, CaptionError>;

    /// Check if the captioner is initialized.
    async fn is_initialized(&self) -> bool;

    /// Get the model name (used for logging/metadata).
    fn model_name(&self) -> &str;
}
66
/// Placeholder vision captioner that returns no captions.
///
/// This is a no-op implementation that can be used when vision captioning
/// is not available or not desired. It always returns `None` for captions.
pub struct PlaceholderCaptioner {
    // Flipped to true by `init`; only consulted by `is_initialized`.
    initialized: Arc<RwLock<bool>>,
}
74
75impl PlaceholderCaptioner {
76    /// Create a new placeholder captioner.
77    #[must_use]
78    pub fn new() -> Self {
79        Self {
80            initialized: Arc::new(RwLock::new(false)),
81        }
82    }
83}
84
impl Default for PlaceholderCaptioner {
    /// Equivalent to [`PlaceholderCaptioner::new`].
    fn default() -> Self {
        Self::new()
    }
}
90
91#[async_trait]
92impl ImageCaptioner for PlaceholderCaptioner {
93    async fn init(&self) -> Result<(), CaptionError> {
94        let mut initialized = self.initialized.write().await;
95        *initialized = true;
96        debug!("Placeholder captioner initialized (no-op)");
97        Ok(())
98    }
99
100    async fn caption(&self, _image_data: &[u8]) -> Result<Option<String>, CaptionError> {
101        // Placeholder returns no caption
102        Ok(None)
103    }
104
105    async fn is_initialized(&self) -> bool {
106        *self.initialized.read().await
107    }
108
109    fn model_name(&self) -> &str {
110        "placeholder"
111    }
112}
113
/// Configuration for vision captioning.
#[derive(Debug, Clone)]
pub struct CaptionConfig {
    /// Enable captioning (default: false until model is implemented).
    /// When false, the BLIP captioner short-circuits and returns no caption.
    pub enabled: bool,
    /// Use quantized model for lower memory usage.
    /// NOTE(review): currently this selects BF16 weights (reduced precision),
    /// not integer quantization — confirm the flag name matches intent.
    pub quantized: bool,
    /// Maximum tokens to generate (per-caption decoding budget).
    pub max_tokens: usize,
    /// Cache directory for model files.
    pub cache_dir: PathBuf,
}
126
127impl Default for CaptionConfig {
128    fn default() -> Self {
129        Self {
130            enabled: false,
131            quantized: false,
132            max_tokens: 100,
133            cache_dir: PathBuf::from("~/.local/share/ragfs/models"),
134        }
135    }
136}
137
138// ============================================================================
139// BLIP Captioner (requires "vision" feature)
140// ============================================================================
141
/// BLIP model identifier on HuggingFace Hub.
#[cfg(feature = "vision")]
#[allow(clippy::doc_markdown)]
const BLIP_MODEL_ID: &str = "Salesforce/blip-image-captioning-base";

/// Square edge length in pixels that images are resized to for BLIP
/// preprocessing (aspect ratio is not preserved).
#[cfg(feature = "vision")]
const BLIP_IMAGE_SIZE: u32 = 384;
150
/// BLIP-based image captioner using Candle.
///
/// Uses the `Salesforce/blip-image-captioning-base` model from HuggingFace Hub.
/// Requires the `vision` feature to be enabled.
///
/// Model and tokenizer are loaded lazily by `init`; all shared state lives
/// behind `Arc<RwLock<...>>` so the captioner can be shared across tasks.
#[cfg(feature = "vision")]
#[allow(clippy::doc_markdown)]
pub struct BlipCaptioner {
    /// Device to run inference on (CUDA device 0 if available, else CPU)
    device: Device,
    /// BLIP model; `None` until `init` completes
    model: Arc<RwLock<Option<blip::BlipForConditionalGeneration>>>,
    /// Text tokenizer; `None` until `init` completes
    tokenizer: Arc<RwLock<Option<Tokenizer>>>,
    /// Cache directory for model files
    /// NOTE(review): currently unused — downloads go through hf-hub's own
    /// cache directory; either wire this up or drop the field.
    #[allow(dead_code)]
    cache_dir: PathBuf,
    /// Whether model is initialized
    initialized: Arc<RwLock<bool>>,
    /// Configuration (enablement, precision, generation budget)
    config: CaptionConfig,
}
172
#[cfg(feature = "vision")]
impl BlipCaptioner {
    /// Create a new BLIP captioner with the default [`CaptionConfig`].
    ///
    /// `cache_dir` is where model files are cached. Prefers CUDA device 0
    /// when available, falling back to CPU.
    #[must_use]
    pub fn new(cache_dir: PathBuf) -> Self {
        // Delegate so device selection and field wiring exist in one place
        // (previously duplicated between `new` and `with_config`).
        Self::with_config(cache_dir, CaptionConfig::default())
    }

    /// Create with custom configuration.
    #[must_use]
    pub fn with_config(cache_dir: PathBuf, config: CaptionConfig) -> Self {
        let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
        info!("BlipCaptioner using device: {:?}", device);

        Self {
            device,
            model: Arc::new(RwLock::new(None)),
            tokenizer: Arc::new(RwLock::new(None)),
            cache_dir,
            initialized: Arc::new(RwLock::new(false)),
            config,
        }
    }

    /// Preprocess raw image bytes into a normalized `[1, 3, H, W]` `f32`
    /// tensor (H = W = `BLIP_IMAGE_SIZE`) on `self.device`.
    ///
    /// # Errors
    /// Returns [`CaptionError::ImagePreprocess`] if the bytes cannot be
    /// decoded as an image or the tensor cannot be constructed.
    fn preprocess_image(&self, image_data: &[u8]) -> Result<Tensor, CaptionError> {
        // Decode image (format auto-detected from the byte stream)
        let img = image::load_from_memory(image_data)
            .map_err(|e| CaptionError::ImagePreprocess(format!("Failed to decode image: {e}")))?;

        // Resize to 384x384, deliberately ignoring aspect ratio
        let img = img.resize_exact(
            BLIP_IMAGE_SIZE,
            BLIP_IMAGE_SIZE,
            image::imageops::FilterType::Triangle,
        );

        // Convert to RGB and normalize
        let img = img.to_rgb8();
        let (width, height) = (img.width() as usize, img.height() as usize);

        // CLIP normalization values (standard ImageNet values, kept as-is for reference)
        #[allow(clippy::excessive_precision, clippy::unreadable_literal)]
        let mean = [0.48145466f32, 0.4578275, 0.40821073];
        #[allow(clippy::excessive_precision, clippy::unreadable_literal)]
        let std = [0.26862954f32, 0.26130258, 0.27577711];

        // Fill a channel-major [C, H, W] buffer: index = c*H*W + y*W + x
        let mut data = vec![0f32; 3 * width * height];
        for (x, y, pixel) in img.enumerate_pixels() {
            let x = x as usize;
            let y = y as usize;
            for c in 0..3 {
                let val = f32::from(pixel[c]) / 255.0;
                let normalized = (val - mean[c]) / std[c];
                data[c * height * width + y * width + x] = normalized;
            }
        }

        let tensor = Tensor::from_vec(data, (3, height, width), &self.device)
            .map_err(|e| CaptionError::ImagePreprocess(format!("Tensor creation failed: {e}")))?
            .unsqueeze(0) // Add batch dimension [1, C, H, W]
            .map_err(|e| CaptionError::ImagePreprocess(format!("Unsqueeze failed: {e}")))?;

        Ok(tensor)
    }

    /// Greedily generate a caption for a preprocessed image tensor.
    ///
    /// Runs the vision encoder once, then decodes tokens autoregressively
    /// (argmax at every step) until `[SEP]` is produced or
    /// `config.max_tokens` is reached.
    ///
    /// NOTE(review): the entire token prefix is re-fed to the text decoder
    /// on every step; if candle's BLIP text decoder keeps an internal KV
    /// cache this both wastes work and risks stale state between calls —
    /// confirm against candle's blip example before relying on repeated use.
    ///
    /// # Errors
    /// Returns [`CaptionError::NotInitialized`] if `init` has not completed,
    /// or [`CaptionError::Generation`] on any inference/decoding failure.
    async fn generate_caption(&self, image_tensor: &Tensor) -> Result<String, CaptionError> {
        // Use write lock since text_decoder may need mutable access
        let mut model_guard = self.model.write().await;
        let model = model_guard.as_mut().ok_or(CaptionError::NotInitialized)?;

        let tokenizer_guard = self.tokenizer.read().await;
        let tokenizer = tokenizer_guard
            .as_ref()
            .ok_or(CaptionError::NotInitialized)?;

        // Image embeddings are computed once and reused for every decode step
        let image_embeds = model
            .vision_model()
            .forward(image_tensor)
            .map_err(|e| CaptionError::Generation(format!("Vision forward failed: {e}")))?;

        // BLIP's text side is BERT-like: [CLS] opens, [SEP] closes a sequence
        let mut token_ids = vec![tokenizer.token_to_id("[CLS]").unwrap_or(101)]; // Default BERT [CLS] id

        let eos_token_id = tokenizer.token_to_id("[SEP]").unwrap_or(102);
        let max_tokens = self.config.max_tokens;

        // Autoregressive generation (greedy decoding)
        for _ in 0..max_tokens {
            let input_ids = Tensor::new(&token_ids[..], &self.device)
                .map_err(|e| CaptionError::Generation(format!("Token tensor failed: {e}")))?
                .unsqueeze(0)
                .map_err(|e| CaptionError::Generation(format!("Unsqueeze failed: {e}")))?;

            let logits = model
                .text_decoder()
                .forward(&input_ids, &image_embeds)
                .map_err(|e| CaptionError::Generation(format!("Text decoder failed: {e}")))?;

            // Only the last position's logits predict the next token
            let seq_len = logits
                .dim(1)
                .map_err(|e| CaptionError::Generation(format!("Dim failed: {e}")))?;
            let next_token_logits = logits
                .i((.., seq_len - 1, ..))
                .map_err(|e| CaptionError::Generation(format!("Index failed: {e}")))?;

            let next_token = next_token_logits
                .argmax(candle_core::D::Minus1)
                .map_err(|e| CaptionError::Generation(format!("Argmax failed: {e}")))?
                .to_scalar::<u32>()
                .map_err(|e| CaptionError::Generation(format!("Scalar failed: {e}")))?;

            if next_token == eos_token_id {
                break;
            }

            token_ids.push(next_token);
        }

        // Decode ids to text; `true` skips special tokens such as [CLS]
        let caption = tokenizer
            .decode(&token_ids, true)
            .map_err(|e| CaptionError::Generation(format!("Decode failed: {e}")))?;

        Ok(caption.trim().to_string())
    }
}
314
#[cfg(feature = "vision")]
#[async_trait]
impl ImageCaptioner for BlipCaptioner {
    /// Download (if needed) and load the BLIP model and tokenizer.
    ///
    /// # Errors
    /// Returns [`CaptionError::ModelLoad`] if any download, parse, or
    /// weight-loading step fails.
    async fn init(&self) -> Result<(), CaptionError> {
        // BUG FIX: the previous check-then-act (read lock, drop, then load)
        // let two concurrent callers both see `false` and download/load the
        // model twice. Hold the write lock for the whole initialization:
        // late callers block here, then observe `true` and return early.
        let mut initialized = self.initialized.write().await;
        if *initialized {
            return Ok(());
        }

        info!("Initializing BlipCaptioner with model: {}", BLIP_MODEL_ID);

        // Download model files from HuggingFace Hub (hf-hub caches locally)
        let api = Api::new()
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to create HF API: {e}")))?;

        let repo = api.repo(Repo::new(BLIP_MODEL_ID.to_string(), RepoType::Model));

        // Download tokenizer
        debug!("Downloading tokenizer...");
        let tokenizer_path = repo
            .get("tokenizer.json")
            .await
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to download tokenizer: {e}")))?;

        // Download model config
        debug!("Downloading config...");
        let config_path = repo
            .get("config.json")
            .await
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to download config: {e}")))?;

        // Download model weights
        debug!("Downloading model weights...");
        let weights_path = repo
            .get("model.safetensors")
            .await
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to download weights: {e}")))?;

        // Load tokenizer
        debug!("Loading tokenizer...");
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to load tokenizer: {e}")))?;

        // Load config
        debug!("Loading config...");
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to read config: {e}")))?;
        let config: blip::Config = serde_json::from_str(&config_str)
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to parse config: {e}")))?;

        // Load model weights.
        // NOTE(review): `quantized` selects BF16 — reduced precision, not
        // integer quantization; confirm the flag name matches intent.
        debug!("Loading model weights...");
        let dtype = if self.config.quantized {
            DType::BF16
        } else {
            DType::F32
        };

        // SAFETY: The safetensors file is downloaded from HuggingFace Hub and is trusted.
        #[allow(unsafe_code)]
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], dtype, &self.device)
                .map_err(|e| CaptionError::ModelLoad(format!("Failed to load weights: {e}")))?
        };

        let model = blip::BlipForConditionalGeneration::new(&config, vb)
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to create BLIP model: {e}")))?;

        // Publish the loaded components first, then flip the flag last, so
        // readers never observe `initialized == true` with missing parts.
        *self.tokenizer.write().await = Some(tokenizer);
        *self.model.write().await = Some(model);
        *initialized = true;

        info!("BlipCaptioner initialized successfully");
        Ok(())
    }

    /// Generate a caption for raw image bytes.
    ///
    /// Returns `Ok(None)` when captioning is disabled in the config or the
    /// model produced an empty caption.
    ///
    /// # Errors
    /// [`CaptionError::NotInitialized`] if `init` has not completed, or a
    /// preprocessing/generation error otherwise.
    async fn caption(&self, image_data: &[u8]) -> Result<Option<String>, CaptionError> {
        if !self.is_initialized().await {
            return Err(CaptionError::NotInitialized);
        }

        if !self.config.enabled {
            return Ok(None);
        }

        debug!("Generating caption for image ({} bytes)", image_data.len());

        // Preprocess image
        let image_tensor = self.preprocess_image(image_data)?;

        // Generate caption
        let caption = self.generate_caption(&image_tensor).await?;

        if caption.is_empty() {
            Ok(None)
        } else {
            Ok(Some(caption))
        }
    }

    async fn is_initialized(&self) -> bool {
        *self.initialized.read().await
    }

    fn model_name(&self) -> &str {
        BLIP_MODEL_ID
    }
}
435
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly constructed placeholder must report uninitialized.
    #[tokio::test]
    async fn test_placeholder_captioner_new() {
        let captioner = PlaceholderCaptioner::new();
        assert!(!captioner.is_initialized().await);
    }

    // `init` is a no-op that succeeds and flips the initialized flag.
    #[tokio::test]
    async fn test_placeholder_captioner_init() {
        let captioner = PlaceholderCaptioner::new();
        let result = captioner.init().await;
        assert!(result.is_ok());
        assert!(captioner.is_initialized().await);
    }

    // The placeholder never produces a caption, even for non-empty input.
    #[tokio::test]
    async fn test_placeholder_captioner_returns_none() {
        let captioner = PlaceholderCaptioner::new();
        captioner.init().await.unwrap();

        let result = captioner.caption(b"fake image data").await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    // Model name is the fixed "placeholder" identifier.
    #[tokio::test]
    async fn test_placeholder_captioner_model_name() {
        let captioner = PlaceholderCaptioner::new();
        assert_eq!(captioner.model_name(), "placeholder");
    }

    // Defaults: captioning off, full precision, 100-token budget.
    #[test]
    fn test_caption_config_default() {
        let config = CaptionConfig::default();
        assert!(!config.enabled);
        assert!(!config.quantized);
        assert_eq!(config.max_tokens, 100);
    }

    // `thiserror`-derived Display messages match the #[error] strings.
    #[test]
    fn test_caption_error_display() {
        let err = CaptionError::NotInitialized;
        assert_eq!(err.to_string(), "model not initialized");

        let err = CaptionError::ModelLoad("test error".to_string());
        assert!(err.to_string().contains("model loading failed"));
    }
}