diff --git a/providers/base.py b/providers/base.py index 991642c..796c034 100644 --- a/providers/base.py +++ b/providers/base.py @@ -1,11 +1,16 @@ """Base model provider interface and data classes.""" +import base64 +import binascii import logging +import os from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum from typing import Any, Optional +from utils.file_types import IMAGES, get_image_mime_type + logger = logging.getLogger(__name__) @@ -459,11 +464,6 @@ class ModelProvider(ABC): # Validate with custom size limit image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0) """ - import base64 - import os - - from utils.file_types import IMAGES, get_image_mime_type - # Use default if not specified if max_size_mb is None: max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB @@ -484,12 +484,18 @@ class ModelProvider(ABC): # Decode base64 data try: image_bytes = base64.b64decode(data) - except Exception as e: + except binascii.Error as e: raise ValueError(f"Invalid base64 data: {e}") else: # Handle file path - if not os.path.exists(image_path): + # Read file first to check if it exists + try: + with open(image_path, "rb") as f: + image_bytes = f.read() + except FileNotFoundError: raise ValueError(f"Image file not found: {image_path}") + except Exception as e: + raise ValueError(f"Failed to read image file: {e}") # Validate extension ext = os.path.splitext(image_path)[1].lower() @@ -499,13 +505,6 @@ class ModelProvider(ABC): # Get MIME type mime_type = get_image_mime_type(ext) - # Read file - try: - with open(image_path, "rb") as f: - image_bytes = f.read() - except Exception as e: - raise ValueError(f"Failed to read image file: {e}") - # Validate size size_mb = len(image_bytes) / (1024 * 1024) if size_mb > max_size_mb: