Address PR #192 review comments
- Fix TOCTOU race condition by removing os.path.exists() check before file open - Move imports (base64, binascii, os, utils.file_types) to top of file - Replace broad Exception catch with specific binascii.Error for base64 decoding - Maintain proper error handling and test compatibility
This commit is contained in:
@@ -1,11 +1,16 @@
|
|||||||
"""Base model provider interface and data classes."""
|
"""Base model provider interface and data classes."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import binascii
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from utils.file_types import IMAGES, get_image_mime_type
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -459,11 +464,6 @@ class ModelProvider(ABC):
|
|||||||
# Validate with custom size limit
|
# Validate with custom size limit
|
||||||
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
|
image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
|
||||||
"""
|
"""
|
||||||
import base64
|
|
||||||
import os
|
|
||||||
|
|
||||||
from utils.file_types import IMAGES, get_image_mime_type
|
|
||||||
|
|
||||||
# Use default if not specified
|
# Use default if not specified
|
||||||
if max_size_mb is None:
|
if max_size_mb is None:
|
||||||
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
|
max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB
|
||||||
@@ -484,12 +484,18 @@ class ModelProvider(ABC):
|
|||||||
# Decode base64 data
|
# Decode base64 data
|
||||||
try:
|
try:
|
||||||
image_bytes = base64.b64decode(data)
|
image_bytes = base64.b64decode(data)
|
||||||
except Exception as e:
|
except binascii.Error as e:
|
||||||
raise ValueError(f"Invalid base64 data: {e}")
|
raise ValueError(f"Invalid base64 data: {e}")
|
||||||
else:
|
else:
|
||||||
# Handle file path
|
# Handle file path
|
||||||
if not os.path.exists(image_path):
|
# Read file first to check if it exists
|
||||||
|
try:
|
||||||
|
with open(image_path, "rb") as f:
|
||||||
|
image_bytes = f.read()
|
||||||
|
except FileNotFoundError:
|
||||||
raise ValueError(f"Image file not found: {image_path}")
|
raise ValueError(f"Image file not found: {image_path}")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to read image file: {e}")
|
||||||
|
|
||||||
# Validate extension
|
# Validate extension
|
||||||
ext = os.path.splitext(image_path)[1].lower()
|
ext = os.path.splitext(image_path)[1].lower()
|
||||||
@@ -499,13 +505,6 @@ class ModelProvider(ABC):
|
|||||||
# Get MIME type
|
# Get MIME type
|
||||||
mime_type = get_image_mime_type(ext)
|
mime_type = get_image_mime_type(ext)
|
||||||
|
|
||||||
# Read file
|
|
||||||
try:
|
|
||||||
with open(image_path, "rb") as f:
|
|
||||||
image_bytes = f.read()
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Failed to read image file: {e}")
|
|
||||||
|
|
||||||
# Validate size
|
# Validate size
|
||||||
size_mb = len(image_bytes) / (1024 * 1024)
|
size_mb = len(image_bytes) / (1024 * 1024)
|
||||||
if size_mb > max_size_mb:
|
if size_mb > max_size_mb:
|
||||||
|
|||||||
Reference in New Issue
Block a user