ragfs_extract/
image.rs

1//! Image content extractor.
2//!
3//! Extracts image files and prepares them for embedding.
4//! Optionally generates captions using a vision model.
5
6use crate::vision::ImageCaptioner;
7use async_trait::async_trait;
8use image::GenericImageView;
9use ragfs_core::{
10    ContentElement, ContentExtractor, ContentMetadataInfo, ExtractError, ExtractedContent,
11    ExtractedImage,
12};
13use std::path::Path;
14use std::sync::Arc;
15use tracing::{debug, warn};
16
17/// Extractor for image files.
18///
19/// Optionally uses a vision captioner to generate descriptions.
20pub struct ImageExtractor {
21    /// Optional vision captioner for generating image descriptions.
22    captioner: Option<Arc<dyn ImageCaptioner>>,
23}
24
25impl ImageExtractor {
26    /// Create a new image extractor without captioning.
27    #[must_use]
28    pub fn new() -> Self {
29        Self { captioner: None }
30    }
31
32    /// Create a new image extractor with vision captioning.
33    #[must_use]
34    pub fn with_captioner(captioner: Arc<dyn ImageCaptioner>) -> Self {
35        Self {
36            captioner: Some(captioner),
37        }
38    }
39}
40
41impl Default for ImageExtractor {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47#[async_trait]
48impl ContentExtractor for ImageExtractor {
49    fn supported_types(&self) -> &[&str] {
50        &[
51            "image/jpeg",
52            "image/png",
53            "image/gif",
54            "image/webp",
55            "image/bmp",
56            "image/tiff",
57        ]
58    }
59
60    fn can_extract_by_extension(&self, path: &Path) -> bool {
61        let extensions = ["jpg", "jpeg", "png", "gif", "webp", "bmp", "tiff", "tif"];
62
63        path.extension()
64            .and_then(|ext| ext.to_str())
65            .is_some_and(|ext| extensions.contains(&ext.to_lowercase().as_str()))
66    }
67
68    async fn extract(&self, path: &Path) -> Result<ExtractedContent, ExtractError> {
69        debug!("Extracting image: {:?}", path);
70
71        // Read image file
72        let bytes = tokio::fs::read(path).await?;
73
74        // Decode image to get metadata (blocking operation)
75        let (width, height, format) =
76            tokio::task::spawn_blocking(move || decode_image_metadata(&bytes))
77                .await
78                .map_err(|e| ExtractError::Failed(format!("Task join error: {e}")))?
79                .map_err(|e| ExtractError::Failed(format!("Image decode failed: {e}")))?;
80
81        // Get MIME type
82        let mime_type = mime_type_from_extension(path);
83
84        // Read bytes again for storage (we consumed them in decoding)
85        let data = tokio::fs::read(path).await?;
86
87        // Create text representation with metadata
88        let text = format!(
89            "Image: {} ({}x{}, {})",
90            path.file_name()
91                .and_then(|n| n.to_str())
92                .unwrap_or("unknown"),
93            width,
94            height,
95            format
96        );
97
98        // Generate caption if captioner is available
99        let caption = if let Some(ref captioner) = self.captioner {
100            if captioner.is_initialized().await {
101                match captioner.caption(&data).await {
102                    Ok(cap) => cap,
103                    Err(e) => {
104                        warn!("Caption generation failed for {:?}: {}", path, e);
105                        None
106                    }
107                }
108            } else {
109                debug!("Captioner not initialized, skipping caption generation");
110                None
111            }
112        } else {
113            None
114        };
115
116        // Create the extracted image
117        let extracted_image = ExtractedImage {
118            data,
119            mime_type: mime_type.clone(),
120            caption,
121            page: None,
122        };
123
124        Ok(ExtractedContent {
125            text,
126            elements: vec![ContentElement::Paragraph {
127                text: format!("{width}x{height} {format} image"),
128                byte_offset: 0,
129            }],
130            images: vec![extracted_image],
131            metadata: ContentMetadataInfo {
132                title: path
133                    .file_name()
134                    .and_then(|n| n.to_str())
135                    .map(std::string::ToString::to_string),
136                ..Default::default()
137            },
138        })
139    }
140}
141
142/// Decode image to get dimensions and format.
143fn decode_image_metadata(bytes: &[u8]) -> Result<(u32, u32, String), String> {
144    let img = image::load_from_memory(bytes).map_err(|e| format!("Failed to load image: {e}"))?;
145
146    let (width, height) = img.dimensions();
147
148    // Try to detect format
149    let format = image::guess_format(bytes).map_or_else(
150        |_| "unknown".to_string(),
151        |f| format!("{f:?}").to_lowercase(),
152    );
153
154    Ok((width, height, format))
155}
156
/// Map a path's extension (case-insensitively) to an image MIME type.
///
/// Paths with no extension, a non-UTF-8 extension, or an unrecognised one map
/// to `application/octet-stream`.
fn mime_type_from_extension(path: &Path) -> String {
    const FALLBACK: &str = "application/octet-stream";

    let Some(ext) = path.extension().and_then(|e| e.to_str()) else {
        return FALLBACK.to_owned();
    };

    let mime = match ext.to_lowercase().as_str() {
        "jpg" | "jpeg" => "image/jpeg",
        "png" => "image/png",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "bmp" => "image/bmp",
        "tiff" | "tif" => "image/tiff",
        _ => FALLBACK,
    };
    mime.to_owned()
}
174
#[cfg(test)]
mod tests {
    //! Tests for [`ImageExtractor`] and this module's helper functions.
    //!
    //! Fixture images are generated in memory with the `image` crate and written
    //! into a per-test temporary directory.

    use super::*;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Create a 2x2 red/green checkerboard PNG (RGBA) for testing.
    fn create_test_png() -> Vec<u8> {
        use image::{ImageBuffer, Rgba};

        let img: ImageBuffer<Rgba<u8>, Vec<u8>> = ImageBuffer::from_fn(2, 2, |x, y| {
            if (x + y) % 2 == 0 {
                Rgba([255, 0, 0, 255]) // Red
            } else {
                Rgba([0, 255, 0, 255]) // Green
            }
        });

        let mut bytes: Vec<u8> = Vec::new();
        let mut cursor = std::io::Cursor::new(&mut bytes);
        img.write_to(&mut cursor, image::ImageFormat::Png).unwrap();
        bytes
    }

    /// Create a 2x2 red/green checkerboard JPEG for testing.
    ///
    /// Uses RGB pixels because JPEG does not store an alpha channel.
    fn create_test_jpeg() -> Vec<u8> {
        use image::{ImageBuffer, Rgb};

        let img: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_fn(2, 2, |x, y| {
            if (x + y) % 2 == 0 {
                Rgb([255, 0, 0]) // Red
            } else {
                Rgb([0, 255, 0]) // Green
            }
        });

        let mut bytes: Vec<u8> = Vec::new();
        let mut cursor = std::io::Cursor::new(&mut bytes);
        img.write_to(&mut cursor, image::ImageFormat::Jpeg).unwrap();
        bytes
    }

    /// Write `bytes` to a file called `name` inside `dir` and return its path.
    fn write_fixture(dir: &Path, name: &str, bytes: &[u8]) -> PathBuf {
        let path = dir.join(name);
        std::fs::write(&path, bytes).expect("failed to write test fixture");
        path
    }

    #[test]
    fn test_new_extractor() {
        let extractor = ImageExtractor::new();
        assert!(!extractor.supported_types().is_empty());
    }

    #[test]
    fn test_default_implementation() {
        let extractor = ImageExtractor::default();
        assert!(!extractor.supported_types().is_empty());
    }

    #[test]
    fn test_supported_types_includes_common_formats() {
        let extractor = ImageExtractor::new();
        let types = extractor.supported_types();

        for expected in [
            "image/jpeg",
            "image/png",
            "image/gif",
            "image/webp",
            "image/bmp",
            "image/tiff",
        ] {
            assert!(types.contains(&expected), "missing {expected}");
        }
    }

    #[test]
    fn test_can_extract_by_extension() {
        let extractor = ImageExtractor::new();

        assert!(extractor.can_extract_by_extension(Path::new("photo.jpg")));
        assert!(extractor.can_extract_by_extension(Path::new("image.PNG")));
        assert!(extractor.can_extract_by_extension(Path::new("animation.gif")));
        assert!(extractor.can_extract_by_extension(Path::new("photo.webp")));
        assert!(extractor.can_extract_by_extension(Path::new("scan.bmp")));
        assert!(!extractor.can_extract_by_extension(Path::new("document.txt")));
        assert!(!extractor.can_extract_by_extension(Path::new("code.rs")));
    }

    #[test]
    fn test_can_extract_jpeg_variants() {
        let extractor = ImageExtractor::new();

        assert!(extractor.can_extract_by_extension(Path::new("photo.jpg")));
        assert!(extractor.can_extract_by_extension(Path::new("photo.jpeg")));
        assert!(extractor.can_extract_by_extension(Path::new("photo.JPG")));
        assert!(extractor.can_extract_by_extension(Path::new("photo.JPEG")));
    }

    #[test]
    fn test_can_extract_tiff_variants() {
        let extractor = ImageExtractor::new();

        assert!(extractor.can_extract_by_extension(Path::new("image.tiff")));
        assert!(extractor.can_extract_by_extension(Path::new("image.tif")));
    }

    #[test]
    fn test_cannot_extract_non_image() {
        let extractor = ImageExtractor::new();

        assert!(!extractor.can_extract_by_extension(Path::new("file.pdf")));
        assert!(!extractor.can_extract_by_extension(Path::new("file.mp4")));
        assert!(!extractor.can_extract_by_extension(Path::new("file.zip")));
    }

    #[test]
    fn test_cannot_extract_without_extension() {
        let extractor = ImageExtractor::new();

        assert!(!extractor.can_extract_by_extension(Path::new("Makefile")));
        assert!(!extractor.can_extract_by_extension(Path::new("")));
    }

    #[test]
    fn test_mime_type_from_extension() {
        assert_eq!(
            mime_type_from_extension(Path::new("test.jpg")),
            "image/jpeg"
        );
        assert_eq!(mime_type_from_extension(Path::new("test.png")), "image/png");
        assert_eq!(mime_type_from_extension(Path::new("test.gif")), "image/gif");
    }

    #[test]
    fn test_mime_type_from_extension_remaining_formats() {
        assert_eq!(
            mime_type_from_extension(Path::new("test.jpeg")),
            "image/jpeg"
        );
        assert_eq!(
            mime_type_from_extension(Path::new("test.webp")),
            "image/webp"
        );
        assert_eq!(mime_type_from_extension(Path::new("test.bmp")), "image/bmp");
        assert_eq!(
            mime_type_from_extension(Path::new("test.tiff")),
            "image/tiff"
        );
        assert_eq!(
            mime_type_from_extension(Path::new("test.tif")),
            "image/tiff"
        );
    }

    #[test]
    fn test_mime_type_from_extension_case_insensitive() {
        assert_eq!(
            mime_type_from_extension(Path::new("test.JPG")),
            "image/jpeg"
        );
        assert_eq!(mime_type_from_extension(Path::new("test.PNG")), "image/png");
    }

    #[test]
    fn test_mime_type_unknown_extension() {
        assert_eq!(
            mime_type_from_extension(Path::new("test.xyz")),
            "application/octet-stream"
        );
    }

    #[test]
    fn test_mime_type_without_extension() {
        assert_eq!(
            mime_type_from_extension(Path::new("README")),
            "application/octet-stream"
        );
    }

    #[tokio::test]
    async fn test_extract_png_image() {
        let temp_dir = tempdir().unwrap();
        let file_path = write_fixture(temp_dir.path(), "test.png", &create_test_png());

        let content = ImageExtractor::new()
            .extract(&file_path)
            .await
            .expect("PNG extraction should succeed");

        assert_eq!(content.text, "Image: test.png (2x2, png)");
        assert_eq!(content.images.len(), 1);
        assert_eq!(content.images[0].mime_type, "image/png");
    }

    #[tokio::test]
    async fn test_extract_jpeg_image() {
        let temp_dir = tempdir().unwrap();
        let file_path = write_fixture(temp_dir.path(), "test.jpg", &create_test_jpeg());

        let content = ImageExtractor::new()
            .extract(&file_path)
            .await
            .expect("JPEG extraction should succeed");

        assert_eq!(content.text, "Image: test.jpg (2x2, jpeg)");
        assert_eq!(content.images.len(), 1);
        assert_eq!(content.images[0].mime_type, "image/jpeg");
    }

    #[tokio::test]
    async fn test_extract_creates_paragraph_element() {
        let temp_dir = tempdir().unwrap();
        let file_path = write_fixture(temp_dir.path(), "photo.png", &create_test_png());

        let content = ImageExtractor::new()
            .extract(&file_path)
            .await
            .expect("PNG extraction should succeed");

        assert_eq!(content.elements.len(), 1);
        match &content.elements[0] {
            ragfs_core::ContentElement::Paragraph { text, byte_offset } => {
                assert_eq!(text, "2x2 png image");
                assert_eq!(*byte_offset, 0);
            }
            _ => panic!("Expected Paragraph element"),
        }
    }

    #[tokio::test]
    async fn test_extract_stores_image_data() {
        let temp_dir = tempdir().unwrap();
        let png_data = create_test_png();
        let file_path = write_fixture(temp_dir.path(), "image.png", &png_data);

        let content = ImageExtractor::new()
            .extract(&file_path)
            .await
            .expect("PNG extraction should succeed");

        assert_eq!(content.images.len(), 1);
        assert_eq!(content.images[0].data, png_data);
    }

    #[tokio::test]
    async fn test_extract_without_captioner_has_no_caption() {
        let temp_dir = tempdir().unwrap();
        let file_path = write_fixture(temp_dir.path(), "plain.png", &create_test_png());

        let content = ImageExtractor::new()
            .extract(&file_path)
            .await
            .expect("PNG extraction should succeed");

        assert!(content.images[0].caption.is_none());
        assert!(content.images[0].page.is_none());
    }

    #[tokio::test]
    async fn test_extract_sets_title_metadata() {
        let temp_dir = tempdir().unwrap();
        let file_path = write_fixture(temp_dir.path(), "my_photo.png", &create_test_png());

        let content = ImageExtractor::new()
            .extract(&file_path)
            .await
            .expect("PNG extraction should succeed");

        assert_eq!(content.metadata.title, Some("my_photo.png".to_string()));
    }

    #[tokio::test]
    async fn test_extract_nonexistent_file_fails() {
        let extractor = ImageExtractor::new();
        let result = extractor.extract(Path::new("/nonexistent/image.png")).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_extract_invalid_image_fails() {
        let temp_dir = tempdir().unwrap();
        // Non-image bytes behind an image extension must be rejected at decode time.
        let file_path = write_fixture(temp_dir.path(), "fake.png", b"This is not an image");

        let result = ImageExtractor::new().extract(&file_path).await;

        assert!(matches!(
            result,
            Err(ExtractError::Failed(ref msg)) if msg.contains("Image decode failed")
        ));
    }

    #[test]
    fn test_decode_image_metadata() {
        let (width, height, format) =
            decode_image_metadata(&create_test_png()).expect("valid PNG should decode");

        assert_eq!((width, height), (2, 2));
        assert_eq!(format, "png");
    }

    #[test]
    fn test_decode_image_metadata_jpeg() {
        let (width, height, format) =
            decode_image_metadata(&create_test_jpeg()).expect("valid JPEG should decode");

        assert_eq!((width, height), (2, 2));
        assert_eq!(format, "jpeg");
    }

    #[test]
    fn test_decode_image_metadata_invalid() {
        assert!(decode_image_metadata(b"not an image").is_err());
        assert!(decode_image_metadata(b"").is_err());
    }
}