1use async_trait::async_trait;
8use std::path::PathBuf;
9use std::sync::Arc;
10use thiserror::Error;
11use tokio::sync::RwLock;
12use tracing::debug;
13
14#[cfg(feature = "vision")]
15use candle_core::{DType, Device, IndexOp, Module, Tensor};
16#[cfg(feature = "vision")]
17use candle_nn::VarBuilder;
18#[cfg(feature = "vision")]
19use candle_transformers::models::blip;
20#[cfg(feature = "vision")]
21use hf_hub::{Repo, RepoType, api::tokio::Api};
22#[cfg(feature = "vision")]
23use tokenizers::Tokenizer;
24#[cfg(feature = "vision")]
25use tracing::info;
26
/// Errors produced while loading a caption model, preprocessing an
/// image, or generating a caption.
#[derive(Debug, Error)]
pub enum CaptionError {
    /// Downloading or constructing the caption model failed.
    #[error("model loading failed: {0}")]
    ModelLoad(String),

    /// The input bytes could not be decoded/normalized for the model.
    #[error("image preprocessing failed: {0}")]
    ImagePreprocess(String),

    /// The model forward pass or token decoding failed.
    #[error("caption generation failed: {0}")]
    Generation(String),

    /// `caption` was called before a successful `init`.
    #[error("model not initialized")]
    NotInitialized,

    /// Underlying filesystem error (auto-converted from `std::io::Error`).
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}
50
/// Abstraction over image-captioning backends.
///
/// Implementations must be shareable across tasks (`Send + Sync`);
/// state mutation happens behind interior mutability, so all methods
/// take `&self`.
#[async_trait]
pub trait ImageCaptioner: Send + Sync {
    /// Prepare the captioner (e.g. download and load model weights).
    ///
    /// # Errors
    /// Returns a [`CaptionError`] if the model cannot be loaded.
    async fn init(&self) -> Result<(), CaptionError>;

    /// Generate a caption for the given encoded image bytes.
    ///
    /// Returns `Ok(None)` when the backend declines to caption (e.g.
    /// disabled, or an empty result).
    ///
    /// # Errors
    /// Returns a [`CaptionError`] on preprocessing or generation failure.
    async fn caption(&self, image_data: &[u8]) -> Result<Option<String>, CaptionError>;

    /// Whether `init` has completed successfully.
    async fn is_initialized(&self) -> bool;

    /// Human-readable identifier of the underlying model.
    fn model_name(&self) -> &str;
}
66
/// No-op captioner: `caption` always returns `Ok(None)`.
///
/// Useful when the `vision` feature is disabled or in tests.
pub struct PlaceholderCaptioner {
    // Flipped to `true` by `init`; read by `is_initialized`.
    initialized: Arc<RwLock<bool>>,
}
74
75impl PlaceholderCaptioner {
76 #[must_use]
78 pub fn new() -> Self {
79 Self {
80 initialized: Arc::new(RwLock::new(false)),
81 }
82 }
83}
84
85impl Default for PlaceholderCaptioner {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91#[async_trait]
92impl ImageCaptioner for PlaceholderCaptioner {
93 async fn init(&self) -> Result<(), CaptionError> {
94 let mut initialized = self.initialized.write().await;
95 *initialized = true;
96 debug!("Placeholder captioner initialized (no-op)");
97 Ok(())
98 }
99
100 async fn caption(&self, _image_data: &[u8]) -> Result<Option<String>, CaptionError> {
101 Ok(None)
103 }
104
105 async fn is_initialized(&self) -> bool {
106 *self.initialized.read().await
107 }
108
109 fn model_name(&self) -> &str {
110 "placeholder"
111 }
112}
113
/// Configuration for image captioning.
#[derive(Debug, Clone)]
pub struct CaptionConfig {
    /// Master switch: when `false`, `caption` returns `Ok(None)` without
    /// running the model.
    pub enabled: bool,
    /// Load weights as BF16 instead of F32 (lower memory use).
    pub quantized: bool,
    /// Maximum number of tokens to generate per caption.
    pub max_tokens: usize,
    /// Directory intended for caching model files.
    pub cache_dir: PathBuf,
}
126
127impl Default for CaptionConfig {
128 fn default() -> Self {
129 Self {
130 enabled: false,
131 quantized: false,
132 max_tokens: 100,
133 cache_dir: PathBuf::from("~/.local/share/ragfs/models"),
134 }
135 }
136}
137
#[cfg(feature = "vision")]
/// Hugging Face Hub id of the BLIP base captioning checkpoint.
#[allow(clippy::doc_markdown)]
const BLIP_MODEL_ID: &str = "Salesforce/blip-image-captioning-base";

#[cfg(feature = "vision")]
/// Square input resolution (pixels) fed to the BLIP vision encoder.
const BLIP_IMAGE_SIZE: u32 = 384;
150
#[cfg(feature = "vision")]
/// Image captioner backed by the BLIP model, executed locally via candle.
///
/// Model weights and tokenizer are fetched from the Hugging Face Hub
/// during `init`.
#[allow(clippy::doc_markdown)]
pub struct BlipCaptioner {
    /// CUDA device 0 when available, otherwise CPU.
    device: Device,
    /// Loaded model; `None` until `init` succeeds. Generation takes the
    /// write lock because the decoder's forward needs a mutable borrow.
    model: Arc<RwLock<Option<blip::BlipForConditionalGeneration>>>,
    /// Loaded tokenizer; `None` until `init` succeeds.
    tokenizer: Arc<RwLock<Option<Tokenizer>>>,
    // NOTE(review): currently unused — downloads go through hf-hub's own
    // cache rather than this directory; confirm intended behavior.
    #[allow(dead_code)]
    cache_dir: PathBuf,
    /// Set to `true` once model and tokenizer are both loaded.
    initialized: Arc<RwLock<bool>>,
    /// Generation settings (enable flag, dtype, token budget).
    config: CaptionConfig,
}
172
#[cfg(feature = "vision")]
impl BlipCaptioner {
    /// Create a captioner with the default [`CaptionConfig`].
    ///
    /// Picks CUDA device 0 when available, falling back to CPU.
    #[must_use]
    pub fn new(cache_dir: PathBuf) -> Self {
        // Delegate so device selection/logging lives in one place.
        Self::with_config(cache_dir, CaptionConfig::default())
    }

    /// Create a captioner with an explicit configuration.
    #[must_use]
    pub fn with_config(cache_dir: PathBuf, config: CaptionConfig) -> Self {
        let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
        info!("BlipCaptioner using device: {:?}", device);

        Self {
            device,
            model: Arc::new(RwLock::new(None)),
            tokenizer: Arc::new(RwLock::new(None)),
            cache_dir,
            initialized: Arc::new(RwLock::new(false)),
            config,
        }
    }

    /// Decode `image_data`, resize to the BLIP input resolution, and
    /// normalize channel-wise into a `(1, 3, H, W)` float tensor.
    ///
    /// # Errors
    /// [`CaptionError::ImagePreprocess`] if decoding or tensor creation
    /// fails.
    fn preprocess_image(&self, image_data: &[u8]) -> Result<Tensor, CaptionError> {
        let img = image::load_from_memory(image_data)
            .map_err(|e| CaptionError::ImagePreprocess(format!("Failed to decode image: {e}")))?;

        let img = img.resize_exact(
            BLIP_IMAGE_SIZE,
            BLIP_IMAGE_SIZE,
            image::imageops::FilterType::Triangle,
        );

        let img = img.to_rgb8();
        let (width, height) = (img.width() as usize, img.height() as usize);

        // CLIP-style per-channel normalization constants (also used by
        // the BLIP image processor).
        #[allow(clippy::excessive_precision, clippy::unreadable_literal)]
        let mean = [0.48145466f32, 0.4578275, 0.40821073];
        #[allow(clippy::excessive_precision, clippy::unreadable_literal)]
        let std = [0.26862954f32, 0.26130258, 0.27577711];

        // CHW layout: data[c][y][x] = (pixel / 255 - mean[c]) / std[c].
        let mut data = vec![0f32; 3 * width * height];
        for (x, y, pixel) in img.enumerate_pixels() {
            let x = x as usize;
            let y = y as usize;
            for c in 0..3 {
                let val = f32::from(pixel[c]) / 255.0;
                let normalized = (val - mean[c]) / std[c];
                data[c * height * width + y * width + x] = normalized;
            }
        }

        let tensor = Tensor::from_vec(data, (3, height, width), &self.device)
            .map_err(|e| CaptionError::ImagePreprocess(format!("Tensor creation failed: {e}")))?
            .unsqueeze(0)
            .map_err(|e| CaptionError::ImagePreprocess(format!("Unsqueeze failed: {e}")))?;

        Ok(tensor)
    }

    /// Greedily decode a caption from the preprocessed image tensor.
    ///
    /// Holds the model write lock for the whole generation: the text
    /// decoder's forward pass requires `&mut` because it maintains an
    /// internal KV cache across calls.
    ///
    /// # Errors
    /// [`CaptionError::NotInitialized`] if `init` has not completed,
    /// [`CaptionError::Generation`] on any model/tokenizer failure.
    async fn generate_caption(&self, image_tensor: &Tensor) -> Result<String, CaptionError> {
        let mut model_guard = self.model.write().await;
        let model = model_guard.as_mut().ok_or(CaptionError::NotInitialized)?;

        let tokenizer_guard = self.tokenizer.read().await;
        let tokenizer = tokenizer_guard
            .as_ref()
            .ok_or(CaptionError::NotInitialized)?;

        // Clear any KV cache left over from a previous caption; stale
        // entries would corrupt attention for this image.
        model.reset_kv_cache();

        let image_embeds = model
            .vision_model()
            .forward(image_tensor)
            .map_err(|e| CaptionError::Generation(format!("Vision forward failed: {e}")))?;

        // NOTE(review): candle's BLIP example seeds generation with the
        // dedicated decoder start token id 30522 rather than [CLS];
        // confirm which the checkpoint expects.
        let mut token_ids = vec![tokenizer.token_to_id("[CLS]").unwrap_or(101)];
        let eos_token_id = tokenizer.token_to_id("[SEP]").unwrap_or(102);
        let max_tokens = self.config.max_tokens;

        for step in 0..max_tokens {
            // The decoder caches keys/values across calls, so after the
            // first step we must feed only the newest token; re-feeding
            // the full history would append duplicate positions to the
            // cache and corrupt the attention context.
            let context = if step == 0 {
                &token_ids[..]
            } else {
                &token_ids[token_ids.len() - 1..]
            };
            let input_ids = Tensor::new(context, &self.device)
                .map_err(|e| CaptionError::Generation(format!("Token tensor failed: {e}")))?
                .unsqueeze(0)
                .map_err(|e| CaptionError::Generation(format!("Unsqueeze failed: {e}")))?;

            let logits = model
                .text_decoder()
                .forward(&input_ids, &image_embeds)
                .map_err(|e| CaptionError::Generation(format!("Text decoder failed: {e}")))?;

            // Logits of the last position in this chunk predict the
            // next token.
            let seq_len = logits
                .dim(1)
                .map_err(|e| CaptionError::Generation(format!("Dim failed: {e}")))?;
            let next_token_logits = logits
                .i((.., seq_len - 1, ..))
                .map_err(|e| CaptionError::Generation(format!("Index failed: {e}")))?;

            // Greedy decoding: pick the highest-scoring token.
            let next_token = next_token_logits
                .argmax(candle_core::D::Minus1)
                .map_err(|e| CaptionError::Generation(format!("Argmax failed: {e}")))?
                .to_scalar::<u32>()
                .map_err(|e| CaptionError::Generation(format!("Scalar failed: {e}")))?;

            if next_token == eos_token_id {
                break;
            }

            token_ids.push(next_token);
        }

        let caption = tokenizer
            .decode(&token_ids, true)
            .map_err(|e| CaptionError::Generation(format!("Decode failed: {e}")))?;

        Ok(caption.trim().to_string())
    }
}
314
#[cfg(feature = "vision")]
#[async_trait]
impl ImageCaptioner for BlipCaptioner {
    /// Download (if needed) and load the BLIP weights and tokenizer.
    ///
    /// Idempotent: returns immediately when already initialized. The
    /// `initialized` write lock is held for the entire load so two
    /// concurrent callers cannot both observe "not initialized" and
    /// download/construct the model twice.
    ///
    /// # Errors
    /// [`CaptionError::ModelLoad`] if any download, parse, or
    /// model-construction step fails.
    async fn init(&self) -> Result<(), CaptionError> {
        // Taking the write lock up front (rather than a read-check then
        // a separate write) closes the check-then-act race on
        // concurrent `init` calls.
        let mut initialized = self.initialized.write().await;
        if *initialized {
            return Ok(());
        }

        info!("Initializing BlipCaptioner with model: {}", BLIP_MODEL_ID);

        let api = Api::new()
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to create HF API: {e}")))?;

        let repo = api.repo(Repo::new(BLIP_MODEL_ID.to_string(), RepoType::Model));

        debug!("Downloading tokenizer...");
        let tokenizer_path = repo
            .get("tokenizer.json")
            .await
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to download tokenizer: {e}")))?;

        debug!("Downloading config...");
        let config_path = repo
            .get("config.json")
            .await
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to download config: {e}")))?;

        debug!("Downloading model weights...");
        let weights_path = repo
            .get("model.safetensors")
            .await
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to download weights: {e}")))?;

        debug!("Loading tokenizer...");
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to load tokenizer: {e}")))?;

        debug!("Loading config...");
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to read config: {e}")))?;
        let config: blip::Config = serde_json::from_str(&config_str)
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to parse config: {e}")))?;

        debug!("Loading model weights...");
        // BF16 trades a little precision for roughly half the memory.
        let dtype = if self.config.quantized {
            DType::BF16
        } else {
            DType::F32
        };

        // SAFETY: mmap of the safetensors file is sound provided the
        // file is not truncated or rewritten while mapped — assumed to
        // hold for hf-hub's cache files; verify if the cache dir is
        // ever shared with concurrent writers.
        #[allow(unsafe_code)]
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], dtype, &self.device)
                .map_err(|e| CaptionError::ModelLoad(format!("Failed to load weights: {e}")))?
        };

        let model = blip::BlipForConditionalGeneration::new(&config, vb)
            .map_err(|e| CaptionError::ModelLoad(format!("Failed to create BLIP model: {e}")))?;

        // Publish tokenizer and model, then flip the flag (still under
        // the `initialized` write lock held since the top of `init`).
        *self.tokenizer.write().await = Some(tokenizer);
        *self.model.write().await = Some(model);
        *initialized = true;

        info!("BlipCaptioner initialized successfully");
        Ok(())
    }

    /// Caption an encoded image.
    ///
    /// Returns `Ok(None)` when captioning is disabled in the config or
    /// when the model produces an empty string.
    ///
    /// # Errors
    /// [`CaptionError::NotInitialized`] if `init` has not completed;
    /// otherwise preprocessing/generation errors.
    async fn caption(&self, image_data: &[u8]) -> Result<Option<String>, CaptionError> {
        if !self.is_initialized().await {
            return Err(CaptionError::NotInitialized);
        }

        if !self.config.enabled {
            return Ok(None);
        }

        debug!("Generating caption for image ({} bytes)", image_data.len());

        let image_tensor = self.preprocess_image(image_data)?;
        let caption = self.generate_caption(&image_tensor).await?;

        if caption.is_empty() {
            Ok(None)
        } else {
            Ok(Some(caption))
        }
    }

    async fn is_initialized(&self) -> bool {
        *self.initialized.read().await
    }

    fn model_name(&self) -> &str {
        BLIP_MODEL_ID
    }
}
435
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_placeholder_captioner_new() {
        // A freshly constructed captioner starts uninitialized.
        let c = PlaceholderCaptioner::new();
        assert!(!c.is_initialized().await);
    }

    #[tokio::test]
    async fn test_placeholder_captioner_init() {
        // init succeeds and flips the initialized flag.
        let c = PlaceholderCaptioner::new();
        assert!(c.init().await.is_ok());
        assert!(c.is_initialized().await);
    }

    #[tokio::test]
    async fn test_placeholder_captioner_returns_none() {
        // The placeholder never captions, even after init.
        let c = PlaceholderCaptioner::new();
        c.init().await.unwrap();
        let caption = c.caption(b"fake image data").await.unwrap();
        assert!(caption.is_none());
    }

    #[tokio::test]
    async fn test_placeholder_captioner_model_name() {
        assert_eq!(PlaceholderCaptioner::new().model_name(), "placeholder");
    }

    #[test]
    fn test_caption_config_default() {
        // Defaults: disabled, full precision, 100-token budget.
        let CaptionConfig {
            enabled,
            quantized,
            max_tokens,
            ..
        } = CaptionConfig::default();
        assert!(!enabled);
        assert!(!quantized);
        assert_eq!(max_tokens, 100);
    }

    #[test]
    fn test_caption_error_display() {
        assert_eq!(
            CaptionError::NotInitialized.to_string(),
            "model not initialized"
        );
        let msg = CaptionError::ModelLoad("test error".to_string()).to_string();
        assert!(msg.contains("model loading failed"));
    }
}