- Add missing base64 import in providers/base.py - Remove unused base64 import from providers/openai_compatible.py - All tests now pass (19/19 image validation tests) - Code quality checks pass 100%
535 lines
19 KiB
Python
535 lines
19 KiB
Python
"""Base model provider interface and data classes."""
|
|
|
|
import base64
|
|
import binascii
|
|
import logging
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING, Any, Optional
|
|
|
|
if TYPE_CHECKING:
|
|
from tools.models import ToolModelCategory
|
|
|
|
from utils.file_types import IMAGES, get_image_mime_type
|
|
|
|
# Module-level logger; ModelProvider uses it for debug messages (e.g. temperature adjustments).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ProviderType(Enum):
    """Supported model provider types.

    Each member's value is a lowercase string identifier for the provider.
    """

    # NOTE(review): these string values may be used as external identifiers
    # (config, restriction policies) — confirm before renaming any of them.
    GOOGLE = "google"
    OPENAI = "openai"
    XAI = "xai"
    OPENROUTER = "openrouter"
    CUSTOM = "custom"
    DIAL = "dial"
|
|
|
|
|
|
class TemperatureConstraint(ABC):
    """Interface describing which sampling temperatures a model accepts.

    A concrete constraint answers four questions: is a temperature acceptable,
    what is the closest acceptable value, how is the rule described to a human,
    and which temperature does the model use by default.
    """

    @abstractmethod
    def validate(self, temperature: float) -> bool:
        """Return True when ``temperature`` satisfies this constraint."""

    @abstractmethod
    def get_corrected_value(self, temperature: float) -> float:
        """Return the acceptable temperature nearest to ``temperature``."""

    @abstractmethod
    def get_description(self) -> str:
        """Return a human-readable summary of this constraint."""

    @abstractmethod
    def get_default(self) -> float:
        """Return the temperature the model uses by default."""
|
|
|
|
|
|
class FixedTemperatureConstraint(TemperatureConstraint):
    """Constraint for models locked to a single temperature (e.g., O3).

    Any requested temperature is coerced to the fixed value, which is also the
    model's default.

    Args:
        value: The only temperature the model accepts.
    """

    def __init__(self, value: float):
        self.value = value

    def validate(self, temperature: float) -> bool:
        """Accept ``temperature`` if it equals the fixed value within 1e-6."""
        # A small tolerance absorbs floating-point noise instead of requiring exact equality.
        difference = temperature - self.value
        return abs(difference) < 1e-6

    def get_corrected_value(self, temperature: float) -> float:
        """Ignore the request and return the fixed value."""
        return self.value

    def get_description(self) -> str:
        """Describe the single supported temperature."""
        return f"Only supports temperature={self.value}"

    def get_default(self) -> float:
        """The fixed value doubles as the default."""
        return self.value
|
|
|
|
|
|
class RangeTemperatureConstraint(TemperatureConstraint):
    """For models supporting continuous temperature ranges.

    Args:
        min_temp: Lowest accepted temperature (inclusive).
        max_temp: Highest accepted temperature (inclusive).
        default: Default temperature. When omitted (None) the midpoint of
            ``[min_temp, max_temp]`` is used; an explicit 0.0 is honoured.
    """

    def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None):
        self.min_temp = min_temp
        self.max_temp = max_temp
        # Test against None rather than truthiness: `default or midpoint` would
        # silently replace an explicit default of 0.0 with the midpoint.
        self.default_temp = default if default is not None else (min_temp + max_temp) / 2

    def validate(self, temperature: float) -> bool:
        """Return True if ``temperature`` lies within the inclusive range."""
        return self.min_temp <= temperature <= self.max_temp

    def get_corrected_value(self, temperature: float) -> float:
        """Clamp ``temperature`` into ``[min_temp, max_temp]``."""
        return max(self.min_temp, min(self.max_temp, temperature))

    def get_description(self) -> str:
        """Describe the accepted inclusive range."""
        return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"

    def get_default(self) -> float:
        """Return the configured (or midpoint) default temperature."""
        return self.default_temp
|
|
|
|
|
|
class DiscreteTemperatureConstraint(TemperatureConstraint):
    """For models supporting only specific temperature values.

    Args:
        allowed_values: Accepted temperatures; stored as a sorted copy.
        default: Default temperature. When omitted (None), the middle element
            of the sorted values is used; an explicit 0.0 is honoured.

    Raises:
        ValueError: If ``allowed_values`` is empty.
    """

    def __init__(self, allowed_values: list[float], default: Optional[float] = None):
        # sorted() copies, so later mutation of the caller's list has no effect here.
        self.allowed_values = sorted(allowed_values)
        if not self.allowed_values:
            # Fail fast with a clear message; otherwise indexing below (or min() in
            # get_corrected_value) would fail later with an obscure error.
            raise ValueError("allowed_values must contain at least one temperature")
        if default is None:
            # Use the median of the *sorted* values so the fallback does not depend
            # on the order the caller listed them in.
            default = self.allowed_values[len(self.allowed_values) // 2]
        # Checking `is None` (not truthiness) keeps an explicit default of 0.0.
        self.default_temp = default

    def validate(self, temperature: float) -> bool:
        """Return True if ``temperature`` matches an allowed value within 1e-6."""
        return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)

    def get_corrected_value(self, temperature: float) -> float:
        """Snap ``temperature`` to the nearest allowed value (first wins on ties)."""
        return min(self.allowed_values, key=lambda x: abs(x - temperature))

    def get_description(self) -> str:
        """List the allowed temperatures."""
        return f"Supports temperatures: {self.allowed_values}"

    def get_default(self) -> float:
        """Return the configured (or median) default temperature."""
        return self.default_temp
|
|
|
|
|
|
def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
    """Build the TemperatureConstraint matching a configuration keyword.

    Args:
        constraint_type: "fixed", "discrete", or "range". Any other value
            (including None) falls back to the range constraint.

    Returns:
        A newly constructed TemperatureConstraint for the requested kind.
    """
    # O3/O4-style models accept only temperature=1.0.
    if constraint_type == "fixed":
        return FixedTemperatureConstraint(1.0)

    # Common OpenAI temperature steps, defaulting to 0.3.
    if constraint_type == "discrete":
        return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)

    # "range", None, or anything unrecognised: continuous [0.0, 2.0] with default 0.3.
    return RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
|
|
|
|
|
@dataclass
class ModelCapabilities:
    """Capabilities and constraints for a specific model.

    The first five fields have no default and must always be supplied; all
    others are optional. Field order is also the positional parameter order of
    the generated ``__init__``, so fields must not be reordered.
    """

    provider: ProviderType  # Provider type (see ProviderType)
    model_name: str  # Model identifier
    friendly_name: str  # Human-friendly name like "Gemini" or "OpenAI"
    context_window: int  # Total context window size in tokens
    max_output_tokens: int  # Maximum output tokens per request
    supports_extended_thinking: bool = False
    supports_system_prompts: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_images: bool = False  # Whether model can process images
    max_image_size_mb: float = 0.0  # Maximum total size for all images in MB
    # Whether model accepts temperature parameter in API calls; when False,
    # ModelProvider.get_effective_temperature returns None for this model.
    supports_temperature: bool = True

    # Additional fields for comprehensive model information
    description: str = ""  # Human-readable description of the model
    # Alternative names/shortcuts for the model; ModelProvider._resolve_model_name
    # matches them case-insensitively.
    aliases: list[str] = field(default_factory=list)

    # JSON mode support (for providers that support structured output)
    supports_json_mode: bool = False

    # Thinking mode support (for models with thinking capabilities)
    max_thinking_tokens: int = 0  # Maximum thinking tokens for extended reasoning models

    # Custom model flag (for models that only work with custom endpoints)
    is_custom: bool = False  # Whether this model requires custom API endpoints

    # Temperature constraint object - defines temperature limits and behavior.
    # default_factory builds a fresh RangeTemperatureConstraint(0.0, 2.0, 0.3) per
    # instance: inclusive range [0.0, 2.0] with default 0.3.
    temperature_constraint: TemperatureConstraint = field(
        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
    )
|
|
|
|
|
|
@dataclass
class ModelResponse:
    """Response from a model provider.

    Only ``content`` is required; every other field has a default. Field order
    is also the positional parameter order of the generated ``__init__``.
    """

    content: str  # Generated text returned by the model
    usage: dict[str, int] = field(default_factory=dict)  # input_tokens, output_tokens, total_tokens
    model_name: str = ""
    friendly_name: str = ""  # Human-friendly name like "Gemini" or "OpenAI"
    # NOTE(review): defaults to GOOGLE when not supplied — confirm every provider passes its own type.
    provider: ProviderType = ProviderType.GOOGLE
    metadata: dict[str, Any] = field(default_factory=dict)  # Provider-specific metadata

    @property
    def total_tokens(self) -> int:
        """Return ``usage["total_tokens"]``, or 0 when that key is absent.

        This does not compute a total from input/output token counts.
        """
        return self.usage.get("total_tokens", 0)
|
|
|
|
|
|
class ModelProvider(ABC):
    """Abstract base class for model providers.

    Subclasses implement capability lookup, content generation, token counting,
    provider type, model-name validation and thinking-mode support. The concrete
    helpers defined here (temperature handling, alias resolution, model listing,
    image validation) build on the ``get_model_configurations()`` and
    ``get_all_model_aliases()`` hook methods.
    """

    # All concrete providers must define their supported models.
    # NOTE: this dict object is shared by every subclass that does not assign its
    # own; subclasses should rebind SUPPORTED_MODELS rather than mutate it.
    SUPPORTED_MODELS: dict[str, Any] = {}

    # Default maximum image size in MB, used by validate_image() when no limit is given.
    DEFAULT_MAX_IMAGE_SIZE_MB = 20.0

    def __init__(self, api_key: str, **kwargs):
        """Initialize the provider with API key and optional configuration.

        Args:
            api_key: Credential, stored as ``self.api_key``.
            **kwargs: Provider-specific configuration, stored as ``self.config``.
        """
        self.api_key = api_key
        self.config = kwargs

    @abstractmethod
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model."""
        pass

    @abstractmethod
    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using the model.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature (0-2)
            max_output_tokens: Maximum tokens to generate
            **kwargs: Provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """
        pass

    @abstractmethod
    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens for the given text using the specified model's tokenizer."""
        pass

    @abstractmethod
    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        pass

    @abstractmethod
    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported by this provider."""
        pass

    def get_effective_temperature(self, model_name: str, requested_temperature: float) -> Optional[float]:
        """Get the effective temperature to use for a model given a requested temperature.

        This method handles:
        - Models that don't support temperature (returns None)
        - Fixed temperature models (returns the fixed value)
        - Range models (clamps to min/max) and discrete models (snaps to the nearest allowed value),
          via the model's ``temperature_constraint.get_corrected_value``

        Args:
            model_name: The model to get temperature for
            requested_temperature: The temperature requested by the user/tool

        Returns:
            The effective temperature to use, or None if temperature shouldn't be passed.
            If capabilities cannot be determined (any exception), the error is logged at
            debug level and ``requested_temperature`` is returned unchanged.
        """
        try:
            capabilities = self.get_capabilities(model_name)

            # Check if model supports temperature at all
            if not capabilities.supports_temperature:
                return None

            # Use temperature constraint to get corrected value
            corrected_temp = capabilities.temperature_constraint.get_corrected_value(requested_temperature)

            if corrected_temp != requested_temperature:
                # Lazy %-style arguments: the message is only formatted if DEBUG is enabled.
                logger.debug(
                    "Adjusting temperature from %s to %s for model %s",
                    requested_temperature,
                    corrected_temp,
                    model_name,
                )

            return corrected_temp

        except Exception as e:
            # Deliberate best-effort: an unknown model must not break the request.
            logger.debug("Could not determine effective temperature for %s: %s", model_name, e)
            # If we can't get capabilities, return the requested temperature
            return requested_temperature

    def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
        """Validate model parameters against capabilities.

        Only the temperature is checked, against the model's temperature constraint
        (this check runs even when ``supports_temperature`` is False). Extra keyword
        arguments are accepted but not validated here.

        Raises:
            ValueError: If parameters are invalid
        """
        capabilities = self.get_capabilities(model_name)

        # Validate temperature using constraint
        if not capabilities.temperature_constraint.validate(temperature):
            constraint_desc = capabilities.temperature_constraint.get_description()
            raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")

    @abstractmethod
    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode."""
        pass

    def get_model_configurations(self) -> dict[str, ModelCapabilities]:
        """Get model configurations for this provider.

        This is a hook method that subclasses can override to provide
        their model configurations from different sources.

        Returns:
            Dictionary mapping model names to their ModelCapabilities objects.
            Entries of SUPPORTED_MODELS whose value is not a ModelCapabilities
            are skipped.
        """
        # Return SUPPORTED_MODELS if it exists (must contain ModelCapabilities objects)
        if hasattr(self, "SUPPORTED_MODELS"):
            return {k: v for k, v in self.SUPPORTED_MODELS.items() if isinstance(v, ModelCapabilities)}
        return {}

    def get_all_model_aliases(self) -> dict[str, list[str]]:
        """Get all model aliases for this provider.

        This is a hook method that subclasses can override to provide
        aliases from different sources.

        Returns:
            Dictionary mapping model names to their list of aliases; models
            with no aliases are omitted.
        """
        # Default implementation extracts from ModelCapabilities objects
        aliases = {}
        for model_name, capabilities in self.get_model_configurations().items():
            if capabilities.aliases:
                aliases[model_name] = capabilities.aliases
        return aliases

    def _resolve_model_name(self, model_name: str) -> str:
        """Resolve model shorthand to full name.

        Lookup order: exact base-model match, then case-insensitive base-model
        match, then case-insensitive alias match. Uses the hook methods so it
        works with different model configuration sources.

        Args:
            model_name: Model name that may be an alias

        Returns:
            Resolved model name, or ``model_name`` unchanged if nothing matches
        """
        # Get model configurations from the hook method
        model_configs = self.get_model_configurations()

        # First check if it's already a base model name (case-sensitive exact match)
        if model_name in model_configs:
            return model_name

        # Check case-insensitively for both base models and aliases
        model_name_lower = model_name.lower()

        # Check base model names case-insensitively
        for base_model in model_configs:
            if base_model.lower() == model_name_lower:
                return base_model

        # Check aliases from the hook method
        all_aliases = self.get_all_model_aliases()
        for base_model, aliases in all_aliases.items():
            if any(alias.lower() == model_name_lower for alias in aliases):
                return base_model

        # If not found, return as-is
        return model_name

    def list_models(self, respect_restrictions: bool = True) -> list[str]:
        """Return a list of model names supported by this provider.

        This implementation uses the get_model_configurations() hook
        to support different model configuration sources. Base model names
        come first, followed by the aliases of the base models that were kept.

        Args:
            respect_restrictions: Whether to apply provider-specific restriction logic.

        Returns:
            List of model names available from this provider
        """
        from utils.model_restrictions import get_restriction_service

        restriction_service = get_restriction_service() if respect_restrictions else None
        models: list[str] = []
        # Base models that passed the restriction check. A set gives O(1) lookups and,
        # unlike `models`, never contains aliases, so only genuine base models qualify.
        allowed_base_models: set[str] = set()

        # Get model configurations from the hook method
        for model_name in self.get_model_configurations():
            # Check restrictions if enabled
            if restriction_service and not restriction_service.is_allowed(self.get_provider_type(), model_name):
                continue

            # Add the base model
            models.append(model_name)
            allowed_base_models.add(model_name)

        # Get aliases from the hook method
        for model_name, aliases in self.get_all_model_aliases().items():
            # Only add aliases for models that passed restriction check
            if model_name in allowed_base_models:
                models.extend(aliases)

        return models

    def list_all_known_models(self) -> list[str]:
        """Return all model names known by this provider, including alias targets.

        This is used for validation purposes to ensure restriction policies
        can validate against both aliases and their target model names.
        No restrictions are applied.

        Returns:
            Lowercased, de-duplicated model names and aliases (order unspecified)
        """
        all_models = set()

        # Add all base model names from the configuration hook
        for model_name in self.get_model_configurations():
            all_models.add(model_name.lower())

        # Add every alias from the alias hook
        for _model_name, aliases in self.get_all_model_aliases().items():
            for alias in aliases:
                all_models.add(alias.lower())

        return list(all_models)

    def validate_image(self, image_path: str, max_size_mb: Optional[float] = None) -> tuple[bytes, str]:
        """Provider-independent image validation.

        Args:
            image_path: Path to image file or data URL
            max_size_mb: Maximum allowed image size in MB (defaults to DEFAULT_MAX_IMAGE_SIZE_MB)

        Returns:
            Tuple of (image_bytes, mime_type)

        Raises:
            ValueError: If image is invalid. The underlying parsing/decoding/IO
                error, when there is one, is chained as ``__cause__``.

        Examples:
            # Validate a file path
            image_bytes, mime_type = provider.validate_image("/path/to/image.png")

            # Validate a data URL
            image_bytes, mime_type = provider.validate_image("data:image/png;base64,...")

            # Validate with custom size limit
            image_bytes, mime_type = provider.validate_image("/path/to/image.jpg", max_size_mb=10.0)
        """
        # Use default if not specified
        if max_size_mb is None:
            max_size_mb = self.DEFAULT_MAX_IMAGE_SIZE_MB

        if image_path.startswith("data:"):
            # Parse data URL: data:<mime>;base64,<payload>
            try:
                header, data = image_path.split(",", 1)
                mime_type = header.split(";")[0].split(":")[1]
            except (ValueError, IndexError) as e:
                raise ValueError(f"Invalid data URL format: {e}") from e

            # Validate MIME type using IMAGES constant
            valid_mime_types = [get_image_mime_type(ext) for ext in IMAGES]
            if mime_type not in valid_mime_types:
                raise ValueError(f"Unsupported image type: {mime_type}. Supported types: {', '.join(valid_mime_types)}")

            # Decode base64 data. Note: b64decode's default (validate=False) discards
            # non-alphabet characters rather than rejecting them.
            try:
                image_bytes = base64.b64decode(data)
            except binascii.Error as e:
                raise ValueError(f"Invalid base64 data: {e}") from e
        else:
            # Handle file path.
            # Read the file first so a missing file reports "not found" before the
            # extension check runs.
            try:
                with open(image_path, "rb") as f:
                    image_bytes = f.read()
            except FileNotFoundError as e:
                raise ValueError(f"Image file not found: {image_path}") from e
            except Exception as e:
                raise ValueError(f"Failed to read image file: {e}") from e

            # Validate extension
            ext = os.path.splitext(image_path)[1].lower()
            if ext not in IMAGES:
                raise ValueError(f"Unsupported image format: {ext}. Supported formats: {', '.join(sorted(IMAGES))}")

            # Get MIME type
            mime_type = get_image_mime_type(ext)

        # Validate size (applies to both data URLs and files)
        size_mb = len(image_bytes) / (1024 * 1024)
        if size_mb > max_size_mb:
            raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")

        return image_bytes, mime_type

    def close(self):
        """Clean up any resources held by the provider.

        Default implementation does nothing.
        Subclasses should override if they hold resources that need cleanup.
        """
        # Base implementation: no resources to clean up
        return

    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        """Get the preferred model from this provider for a given category.

        Args:
            category: The tool category requiring a model
            allowed_models: Pre-filtered list of model names that are allowed by restrictions

        Returns:
            Model name if this provider has a preference, None otherwise
        """
        # Default implementation - providers can override with specific logic
        return None

    def get_model_registry(self) -> Optional[dict[str, Any]]:
        """Get the model registry for providers that maintain one.

        This is a hook method for providers like CustomProvider that maintain
        a dynamic model registry.

        Returns:
            Model registry dict or None if not applicable
        """
        # Default implementation - most providers don't have a registry
        return None