refactor: code cleanup
This commit is contained in:
@@ -1,12 +1,10 @@
|
||||
"""Base model provider interface and data classes."""
|
||||
"""Base interfaces and common behaviour for model providers."""
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -14,179 +12,20 @@ if TYPE_CHECKING:
|
||||
|
||||
from utils.file_types import IMAGES, get_image_mime_type
|
||||
|
||||
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
"""Supported model provider types."""
|
||||
|
||||
GOOGLE = "google"
|
||||
OPENAI = "openai"
|
||||
XAI = "xai"
|
||||
OPENROUTER = "openrouter"
|
||||
CUSTOM = "custom"
|
||||
DIAL = "dial"
|
||||
|
||||
|
||||
class TemperatureConstraint(ABC):
|
||||
"""Abstract base class for temperature constraints."""
|
||||
|
||||
@abstractmethod
|
||||
def validate(self, temperature: float) -> bool:
|
||||
"""Check if temperature is valid."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
"""Get nearest valid temperature."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str:
|
||||
"""Get human-readable description of constraint."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_default(self) -> float:
|
||||
"""Get model's default temperature."""
|
||||
pass
|
||||
|
||||
|
||||
class FixedTemperatureConstraint(TemperatureConstraint):
|
||||
"""For models that only support one temperature value (e.g., O3)."""
|
||||
|
||||
def __init__(self, value: float):
|
||||
self.value = value
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return self.value
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Only supports temperature={self.value}"
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.value
|
||||
|
||||
|
||||
class RangeTemperatureConstraint(TemperatureConstraint):
|
||||
"""For models supporting continuous temperature ranges."""
|
||||
|
||||
def __init__(self, min_temp: float, max_temp: float, default: float = None):
|
||||
self.min_temp = min_temp
|
||||
self.max_temp = max_temp
|
||||
self.default_temp = default or (min_temp + max_temp) / 2
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return self.min_temp <= temperature <= self.max_temp
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return max(self.min_temp, min(self.max_temp, temperature))
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.default_temp
|
||||
|
||||
|
||||
class DiscreteTemperatureConstraint(TemperatureConstraint):
|
||||
"""For models supporting only specific temperature values."""
|
||||
|
||||
def __init__(self, allowed_values: list[float], default: float = None):
|
||||
self.allowed_values = sorted(allowed_values)
|
||||
self.default_temp = default or allowed_values[len(allowed_values) // 2]
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return min(self.allowed_values, key=lambda x: abs(x - temperature))
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Supports temperatures: {self.allowed_values}"
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.default_temp
|
||||
|
||||
|
||||
def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
|
||||
"""Create temperature constraint object from configuration string.
|
||||
|
||||
Args:
|
||||
constraint_type: Type of constraint ("fixed", "range", "discrete")
|
||||
|
||||
Returns:
|
||||
TemperatureConstraint object based on configuration
|
||||
"""
|
||||
if constraint_type == "fixed":
|
||||
# Fixed temperature models (O3/O4) only support temperature=1.0
|
||||
return FixedTemperatureConstraint(1.0)
|
||||
elif constraint_type == "discrete":
|
||||
# For models with specific allowed values - using common OpenAI values as default
|
||||
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
|
||||
else:
|
||||
# Default range constraint (for "range" or None)
|
||||
return RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCapabilities:
|
||||
"""Capabilities and constraints for a specific model."""
|
||||
|
||||
provider: ProviderType
|
||||
model_name: str
|
||||
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
|
||||
context_window: int # Total context window size in tokens
|
||||
max_output_tokens: int # Maximum output tokens per request
|
||||
supports_extended_thinking: bool = False
|
||||
supports_system_prompts: bool = True
|
||||
supports_streaming: bool = True
|
||||
supports_function_calling: bool = False
|
||||
supports_images: bool = False # Whether model can process images
|
||||
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
|
||||
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
|
||||
|
||||
# Additional fields for comprehensive model information
|
||||
description: str = "" # Human-readable description of the model
|
||||
aliases: list[str] = field(default_factory=list) # Alternative names/shortcuts for the model
|
||||
|
||||
# JSON mode support (for providers that support structured output)
|
||||
supports_json_mode: bool = False
|
||||
|
||||
# Thinking mode support (for models with thinking capabilities)
|
||||
max_thinking_tokens: int = 0 # Maximum thinking tokens for extended reasoning models
|
||||
|
||||
# Custom model flag (for models that only work with custom endpoints)
|
||||
is_custom: bool = False # Whether this model requires custom API endpoints
|
||||
|
||||
# Temperature constraint object - defines temperature limits and behavior
|
||||
temperature_constraint: TemperatureConstraint = field(
|
||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelResponse:
|
||||
"""Response from a model provider."""
|
||||
|
||||
content: str
|
||||
usage: dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
|
||||
model_name: str = ""
|
||||
friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI"
|
||||
provider: ProviderType = ProviderType.GOOGLE
|
||||
metadata: dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get total tokens used."""
|
||||
return self.usage.get("total_tokens", 0)
|
||||
|
||||
|
||||
class ModelProvider(ABC):
|
||||
"""Abstract base class for model providers."""
|
||||
"""Defines the contract implemented by every model provider backend.
|
||||
|
||||
Subclasses adapt third-party SDKs into the MCP server by exposing
|
||||
capability metadata, request execution, and token counting through a
|
||||
consistent interface. Shared helper methods (temperature validation,
|
||||
alias resolution, image handling, etc.) live here so individual providers
|
||||
only need to focus on provider-specific details.
|
||||
"""
|
||||
|
||||
# All concrete providers must define their supported models
|
||||
SUPPORTED_MODELS: dict[str, Any] = {}
|
||||
|
||||
Reference in New Issue
Block a user