1use crate::vision::ImageCaptioner;
7use async_trait::async_trait;
8use image::GenericImageView;
9use ragfs_core::{
10 ContentElement, ContentExtractor, ContentMetadataInfo, ExtractError, ExtractedContent,
11 ExtractedImage,
12};
13use std::path::Path;
14use std::sync::Arc;
15use tracing::{debug, warn};
16
/// Content extractor for image files.
///
/// Produces a short textual description of the image (file name, dimensions,
/// detected format) together with the raw image bytes, and can optionally
/// attach a caption generated by an [`ImageCaptioner`].
pub struct ImageExtractor {
    /// Optional captioner used to generate captions during extraction.
    /// `None` means extraction never attempts captioning.
    captioner: Option<Arc<dyn ImageCaptioner>>,
}
24
25impl ImageExtractor {
26 #[must_use]
28 pub fn new() -> Self {
29 Self { captioner: None }
30 }
31
32 #[must_use]
34 pub fn with_captioner(captioner: Arc<dyn ImageCaptioner>) -> Self {
35 Self {
36 captioner: Some(captioner),
37 }
38 }
39}
40
41impl Default for ImageExtractor {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47#[async_trait]
48impl ContentExtractor for ImageExtractor {
49 fn supported_types(&self) -> &[&str] {
50 &[
51 "image/jpeg",
52 "image/png",
53 "image/gif",
54 "image/webp",
55 "image/bmp",
56 "image/tiff",
57 ]
58 }
59
60 fn can_extract_by_extension(&self, path: &Path) -> bool {
61 let extensions = ["jpg", "jpeg", "png", "gif", "webp", "bmp", "tiff", "tif"];
62
63 path.extension()
64 .and_then(|ext| ext.to_str())
65 .is_some_and(|ext| extensions.contains(&ext.to_lowercase().as_str()))
66 }
67
68 async fn extract(&self, path: &Path) -> Result<ExtractedContent, ExtractError> {
69 debug!("Extracting image: {:?}", path);
70
71 let bytes = tokio::fs::read(path).await?;
73
74 let (width, height, format) =
76 tokio::task::spawn_blocking(move || decode_image_metadata(&bytes))
77 .await
78 .map_err(|e| ExtractError::Failed(format!("Task join error: {e}")))?
79 .map_err(|e| ExtractError::Failed(format!("Image decode failed: {e}")))?;
80
81 let mime_type = mime_type_from_extension(path);
83
84 let data = tokio::fs::read(path).await?;
86
87 let text = format!(
89 "Image: {} ({}x{}, {})",
90 path.file_name()
91 .and_then(|n| n.to_str())
92 .unwrap_or("unknown"),
93 width,
94 height,
95 format
96 );
97
98 let caption = if let Some(ref captioner) = self.captioner {
100 if captioner.is_initialized().await {
101 match captioner.caption(&data).await {
102 Ok(cap) => cap,
103 Err(e) => {
104 warn!("Caption generation failed for {:?}: {}", path, e);
105 None
106 }
107 }
108 } else {
109 debug!("Captioner not initialized, skipping caption generation");
110 None
111 }
112 } else {
113 None
114 };
115
116 let extracted_image = ExtractedImage {
118 data,
119 mime_type: mime_type.clone(),
120 caption,
121 page: None,
122 };
123
124 Ok(ExtractedContent {
125 text,
126 elements: vec![ContentElement::Paragraph {
127 text: format!("{width}x{height} {format} image"),
128 byte_offset: 0,
129 }],
130 images: vec![extracted_image],
131 metadata: ContentMetadataInfo {
132 title: path
133 .file_name()
134 .and_then(|n| n.to_str())
135 .map(std::string::ToString::to_string),
136 ..Default::default()
137 },
138 })
139 }
140}
141
142fn decode_image_metadata(bytes: &[u8]) -> Result<(u32, u32, String), String> {
144 let img = image::load_from_memory(bytes).map_err(|e| format!("Failed to load image: {e}"))?;
145
146 let (width, height) = img.dimensions();
147
148 let format = image::guess_format(bytes).map_or_else(
150 |_| "unknown".to_string(),
151 |f| format!("{f:?}").to_lowercase(),
152 );
153
154 Ok((width, height, format))
155}
156
/// Maps the file extension of `path` to an image MIME type.
///
/// Matching is case-insensitive. A missing, non-UTF-8, or unrecognized
/// extension yields `"application/octet-stream"`.
fn mime_type_from_extension(path: &Path) -> String {
    const FALLBACK: &str = "application/octet-stream";

    let Some(extension) = path.extension().and_then(|ext| ext.to_str()) else {
        return FALLBACK.to_string();
    };

    let mime = match extension.to_lowercase().as_str() {
        "jpg" | "jpeg" => "image/jpeg",
        "png" => "image/png",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "bmp" => "image/bmp",
        "tiff" | "tif" => "image/tiff",
        _ => FALLBACK,
    };

    mime.to_string()
}
174
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Encodes a 2x2 red/green checkerboard as PNG and returns the bytes.
    fn create_test_png() -> Vec<u8> {
        use image::{ImageBuffer, Rgba};

        let checkerboard: ImageBuffer<Rgba<u8>, Vec<u8>> =
            ImageBuffer::from_fn(2, 2, |x, y| match (x + y) % 2 {
                0 => Rgba([255, 0, 0, 255]),
                _ => Rgba([0, 255, 0, 255]),
            });

        let mut encoded = std::io::Cursor::new(Vec::new());
        checkerboard
            .write_to(&mut encoded, image::ImageFormat::Png)
            .unwrap();
        encoded.into_inner()
    }

    /// Encodes a 2x2 red/green checkerboard as JPEG and returns the bytes.
    fn create_test_jpeg() -> Vec<u8> {
        use image::{ImageBuffer, Rgb};

        let checkerboard: ImageBuffer<Rgb<u8>, Vec<u8>> =
            ImageBuffer::from_fn(2, 2, |x, y| match (x + y) % 2 {
                0 => Rgb([255, 0, 0]),
                _ => Rgb([0, 255, 0]),
            });

        let mut encoded = std::io::Cursor::new(Vec::new());
        checkerboard
            .write_to(&mut encoded, image::ImageFormat::Jpeg)
            .unwrap();
        encoded.into_inner()
    }

    /// Writes `bytes` to a file called `name` inside a fresh temporary
    /// directory. The returned directory guard must be kept alive for as long
    /// as the path is used.
    fn write_fixture(name: &str, bytes: &[u8]) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempdir().unwrap();
        let path = dir.path().join(name);
        std::fs::write(&path, bytes).unwrap();
        (dir, path)
    }

    #[test]
    fn test_new_extractor() {
        assert!(!ImageExtractor::new().supported_types().is_empty());
    }

    #[test]
    fn test_default_implementation() {
        assert!(!ImageExtractor::default().supported_types().is_empty());
    }

    #[test]
    fn test_supported_types_includes_common_formats() {
        let extractor = ImageExtractor::new();
        let types = extractor.supported_types();

        for mime in ["image/jpeg", "image/png", "image/gif", "image/webp"] {
            assert!(types.contains(&mime), "{mime}");
        }
    }

    #[test]
    fn test_can_extract_by_extension() {
        let extractor = ImageExtractor::new();

        for accepted in ["photo.jpg", "image.PNG", "animation.gif", "photo.webp"] {
            assert!(
                extractor.can_extract_by_extension(Path::new(accepted)),
                "{accepted}"
            );
        }
        for rejected in ["document.txt", "code.rs"] {
            assert!(
                !extractor.can_extract_by_extension(Path::new(rejected)),
                "{rejected}"
            );
        }
    }

    #[test]
    fn test_can_extract_jpeg_variants() {
        let extractor = ImageExtractor::new();

        for name in ["photo.jpg", "photo.jpeg", "photo.JPG", "photo.JPEG"] {
            assert!(extractor.can_extract_by_extension(Path::new(name)), "{name}");
        }
    }

    #[test]
    fn test_can_extract_tiff_variants() {
        let extractor = ImageExtractor::new();

        for name in ["image.tiff", "image.tif"] {
            assert!(extractor.can_extract_by_extension(Path::new(name)), "{name}");
        }
    }

    #[test]
    fn test_cannot_extract_non_image() {
        let extractor = ImageExtractor::new();

        for name in ["file.pdf", "file.mp4", "file.zip"] {
            assert!(!extractor.can_extract_by_extension(Path::new(name)), "{name}");
        }
    }

    #[test]
    fn test_mime_type_from_extension() {
        let cases = [
            ("test.jpg", "image/jpeg"),
            ("test.png", "image/png"),
            ("test.gif", "image/gif"),
        ];

        for (name, expected) in cases {
            assert_eq!(mime_type_from_extension(Path::new(name)), expected);
        }
    }

    #[test]
    fn test_mime_type_from_extension_case_insensitive() {
        let cases = [("test.JPG", "image/jpeg"), ("test.PNG", "image/png")];

        for (name, expected) in cases {
            assert_eq!(mime_type_from_extension(Path::new(name)), expected);
        }
    }

    #[test]
    fn test_mime_type_unknown_extension() {
        let mime = mime_type_from_extension(Path::new("test.xyz"));
        assert_eq!(mime, "application/octet-stream");
    }

    #[tokio::test]
    async fn test_extract_png_image() {
        let (_dir, file_path) = write_fixture("test.png", &create_test_png());

        let result = ImageExtractor::new().extract(&file_path).await;

        assert!(result.is_ok());
        let content = result.unwrap();
        assert!(content.text.contains("test.png"));
        // The fixture is a 2x2 image, so the dimensions must appear in the text.
        assert!(content.text.contains("2x2"));
        assert_eq!(content.images.len(), 1);
        assert_eq!(content.images[0].mime_type, "image/png");
    }

    #[tokio::test]
    async fn test_extract_jpeg_image() {
        let (_dir, file_path) = write_fixture("test.jpg", &create_test_jpeg());

        let result = ImageExtractor::new().extract(&file_path).await;

        assert!(result.is_ok());
        let content = result.unwrap();
        assert!(content.text.contains("test.jpg"));
        assert_eq!(content.images.len(), 1);
        assert_eq!(content.images[0].mime_type, "image/jpeg");
    }

    #[tokio::test]
    async fn test_extract_creates_paragraph_element() {
        let (_dir, file_path) = write_fixture("photo.png", &create_test_png());

        let content = ImageExtractor::new().extract(&file_path).await.unwrap();

        assert_eq!(content.elements.len(), 1);
        let ContentElement::Paragraph { text, .. } = &content.elements[0] else {
            panic!("Expected Paragraph element");
        };
        assert!(text.contains("2x2"));
        assert!(text.contains("image"));
    }

    #[tokio::test]
    async fn test_extract_stores_image_data() {
        let png_data = create_test_png();
        let (_dir, file_path) = write_fixture("image.png", &png_data);

        let content = ImageExtractor::new().extract(&file_path).await.unwrap();

        assert_eq!(content.images.len(), 1);
        assert_eq!(content.images[0].data, png_data);
    }

    #[tokio::test]
    async fn test_extract_sets_title_metadata() {
        let (_dir, file_path) = write_fixture("my_photo.png", &create_test_png());

        let content = ImageExtractor::new().extract(&file_path).await.unwrap();

        assert_eq!(content.metadata.title, Some("my_photo.png".to_string()));
    }

    #[tokio::test]
    async fn test_extract_nonexistent_file_fails() {
        let missing = Path::new("/nonexistent/image.png");

        let result = ImageExtractor::new().extract(missing).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_extract_invalid_image_fails() {
        // A `.png` name with non-image contents must be rejected by decoding.
        let (_dir, file_path) = write_fixture("fake.png", b"This is not an image");

        let result = ImageExtractor::new().extract(&file_path).await;

        assert!(result.is_err());
    }

    #[test]
    fn test_decode_image_metadata() {
        let result = decode_image_metadata(&create_test_png());

        assert!(result.is_ok());
        let (width, height, format) = result.unwrap();
        assert_eq!((width, height), (2, 2));
        assert!(format.contains("png"));
    }

    #[test]
    fn test_decode_image_metadata_invalid() {
        assert!(decode_image_metadata(b"not an image").is_err());
    }
}