Simulation tests to confirm threading and history traversal. Chain of communication and branching validation tests from live simulation. Temperature enforcement per model.
223 lines
7.5 KiB
Python
223 lines
7.5 KiB
Python
"""Base model provider interface and data classes."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Optional, Any, Tuple
|
|
from enum import Enum
|
|
|
|
|
|
class ProviderType(Enum):
    """Enumeration of the model providers known to this module.

    Each member's value is the lowercase string identifier of the provider.
    """

    # Member values are plain strings, so ProviderType("google") round-trips.
    GOOGLE = "google"
    OPENAI = "openai"
|
|
|
|
|
|
class TemperatureConstraint(ABC):
    """Interface describing which sampling temperatures a model accepts.

    Concrete subclasses must implement all four methods below; the class
    itself cannot be instantiated.
    """

    @abstractmethod
    def validate(self, temperature: float) -> bool:
        """Return whether ``temperature`` is acceptable under this constraint."""

    @abstractmethod
    def get_corrected_value(self, temperature: float) -> float:
        """Return the nearest valid temperature to ``temperature``."""

    @abstractmethod
    def get_description(self) -> str:
        """Return a human-readable description of the constraint."""

    @abstractmethod
    def get_default(self) -> float:
        """Return the model's default temperature."""
|
|
|
|
|
|
class FixedTemperatureConstraint(TemperatureConstraint):
    """Constraint for models that accept exactly one temperature (e.g., O3)."""

    def __init__(self, value: float):
        """Store ``value`` as the single supported temperature."""
        self.value = value

    def validate(self, temperature: float) -> bool:
        """Return True when ``temperature`` matches the fixed value.

        The comparison uses an absolute tolerance of 1e-6 so that floating
        point rounding does not cause spurious rejections.
        """
        deviation = abs(temperature - self.value)
        return deviation < 1e-6

    def get_corrected_value(self, temperature: float) -> float:
        """Return the fixed value, whatever ``temperature`` was requested."""
        return self.value

    def get_description(self) -> str:
        """Return a human-readable description of the constraint."""
        return f"Only supports temperature={self.value}"

    def get_default(self) -> float:
        """Return the fixed value, which is also the default."""
        return self.value
|
|
|
|
|
|
class RangeTemperatureConstraint(TemperatureConstraint):
    """For models supporting continuous temperature ranges."""

    def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None):
        """Create a constraint covering the closed interval [min_temp, max_temp].

        Args:
            min_temp: Lowest accepted temperature (inclusive).
            max_temp: Highest accepted temperature (inclusive).
            default: Default temperature. When omitted (``None``), the
                midpoint of the range is used.
        """
        self.min_temp = min_temp
        self.max_temp = max_temp
        # Compare against None explicitly: 0.0 is a legitimate default but is
        # falsy, so ``default or ...`` would silently replace it with the
        # midpoint.
        if default is None:
            self.default_temp = (min_temp + max_temp) / 2
        else:
            self.default_temp = default

    def validate(self, temperature: float) -> bool:
        """Return True when ``temperature`` lies within the inclusive range."""
        return self.min_temp <= temperature <= self.max_temp

    def get_corrected_value(self, temperature: float) -> float:
        """Clamp ``temperature`` into [min_temp, max_temp]."""
        return max(self.min_temp, min(self.max_temp, temperature))

    def get_description(self) -> str:
        """Return a human-readable description of the constraint."""
        return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"

    def get_default(self) -> float:
        """Return the default temperature chosen at construction."""
        return self.default_temp
|
|
|
|
|
|
class DiscreteTemperatureConstraint(TemperatureConstraint):
    """For models supporting only specific temperature values."""

    def __init__(self, allowed_values: List[float], default: Optional[float] = None):
        """Create a constraint from an explicit list of accepted temperatures.

        Args:
            allowed_values: The accepted temperatures, in any order. A sorted
                copy is stored; the caller's list is not modified.
            default: Default temperature. When omitted (``None``), the middle
                element of the sorted values is used.

        Raises:
            IndexError: If ``allowed_values`` is empty and no default is given.
        """
        self.allowed_values = sorted(allowed_values)
        # Compare against None explicitly: 0.0 is a legitimate default but is
        # falsy, so ``default or ...`` would silently discard it.
        if default is None:
            # Index the sorted copy so the fallback is the middle value
            # regardless of the order the caller listed the values in.
            self.default_temp = self.allowed_values[len(self.allowed_values) // 2]
        else:
            self.default_temp = default

    def validate(self, temperature: float) -> bool:
        """Return True when ``temperature`` is within 1e-6 of an allowed value."""
        return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)

    def get_corrected_value(self, temperature: float) -> float:
        """Return the allowed value closest to ``temperature``."""
        return min(self.allowed_values, key=lambda x: abs(x - temperature))

    def get_description(self) -> str:
        """Return a human-readable description of the constraint."""
        return f"Supports temperatures: {self.allowed_values}"

    def get_default(self) -> float:
        """Return the default temperature chosen at construction."""
        return self.default_temp
|
|
|
|
|
|
@dataclass
class ModelCapabilities:
    """Describes what a particular model supports and how it is constrained."""

    provider: ProviderType
    model_name: str
    friendly_name: str  # Human-friendly name like "Gemini" or "OpenAI"
    max_tokens: int
    supports_extended_thinking: bool = False
    supports_system_prompts: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False

    # Preferred way to express temperature limits. Built per instance by the
    # factory, defaulting to the range 0.0-2.0 with a default of 0.7.
    temperature_constraint: TemperatureConstraint = field(
        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
    )

    @property
    def temperature_range(self) -> Tuple[float, float]:
        """Return ``(min, max)`` derived from ``temperature_constraint``.

        Kept for existing code that still reads ``temperature_range``.
        For a fixed constraint both elements are the fixed value; for a
        discrete constraint they are the smallest and largest allowed values.
        Any other constraint type yields the fallback ``(0.0, 2.0)``.
        """
        constraint = self.temperature_constraint

        if isinstance(constraint, RangeTemperatureConstraint):
            return (constraint.min_temp, constraint.max_temp)

        if isinstance(constraint, FixedTemperatureConstraint):
            fixed = constraint.value
            return (fixed, fixed)

        if isinstance(constraint, DiscreteTemperatureConstraint):
            allowed = constraint.allowed_values
            return (min(allowed), max(allowed))

        # Unknown constraint type: fall back to the widest default range.
        return (0.0, 2.0)
|
|
|
|
|
|
@dataclass
class ModelResponse:
    """Container for what a model provider returned for one request."""

    content: str
    # Token accounting; keys: input_tokens, output_tokens, total_tokens
    usage: Dict[str, int] = field(default_factory=dict)
    model_name: str = ""
    # Human-friendly name like "Gemini" or "OpenAI"
    friendly_name: str = ""
    provider: ProviderType = ProviderType.GOOGLE
    # Provider-specific metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def total_tokens(self) -> int:
        """Return ``usage["total_tokens"]``, or 0 when that key is absent."""
        token_usage = self.usage
        return token_usage.get("total_tokens", 0)
|
|
|
|
|
|
class ModelProvider(ABC):
    """Abstract base class for model providers."""

    def __init__(self, api_key: str, **kwargs):
        """Initialize the provider with API key and optional configuration.

        Args:
            api_key: Credential stored on the instance as ``api_key``.
            **kwargs: Extra configuration, stored as the ``config`` dict.
        """
        self.api_key = api_key
        self.config = kwargs

    @abstractmethod
    def get_capabilities(self, model_name: str) -> ModelCapabilities:
        """Get capabilities for a specific model."""

    @abstractmethod
    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_output_tokens: Optional[int] = None,
        **kwargs
    ) -> ModelResponse:
        """Generate content using the model.

        Args:
            prompt: User prompt to send to the model
            model_name: Name of the model to use
            system_prompt: Optional system prompt for model behavior
            temperature: Sampling temperature (0-2)
            max_output_tokens: Maximum tokens to generate
            **kwargs: Provider-specific parameters

        Returns:
            ModelResponse with generated content and metadata
        """

    @abstractmethod
    def count_tokens(self, text: str, model_name: str) -> int:
        """Count tokens for the given text using the specified model's tokenizer."""

    @abstractmethod
    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""

    @abstractmethod
    def validate_model_name(self, model_name: str) -> bool:
        """Validate if the model name is supported by this provider."""

    def validate_parameters(
        self,
        model_name: str,
        temperature: float,
        **kwargs
    ) -> None:
        """Validate model parameters against capabilities.

        The temperature must lie inside the model's ``temperature_range`` and,
        when the capabilities expose a ``temperature_constraint``, must also
        be accepted by that constraint's ``validate`` method.

        Args:
            model_name: Model whose capabilities are looked up.
            temperature: Requested sampling temperature.
            **kwargs: Accepted for forward compatibility; not inspected here.

        Raises:
            ValueError: If parameters are invalid
        """
        capabilities = self.get_capabilities(model_name)

        # Validate temperature against the overall min/max envelope first.
        min_temp, max_temp = capabilities.temperature_range
        if not min_temp <= temperature <= max_temp:
            raise ValueError(
                f"Temperature {temperature} out of range [{min_temp}, {max_temp}] "
                f"for model {model_name}"
            )

        # The min/max envelope is not sufficient for constraints that only
        # allow specific values inside it, so defer to the constraint object
        # as well. getattr keeps capability objects without the attribute
        # working exactly as before.
        constraint = getattr(capabilities, "temperature_constraint", None)
        if constraint is not None and not constraint.validate(temperature):
            raise ValueError(
                f"Temperature {temperature} is not supported by model {model_name}: "
                f"{constraint.get_description()}"
            )

    @abstractmethod
    def supports_thinking_mode(self, model_name: str) -> bool:
        """Check if the model supports extended thinking mode."""