WIP lots of new tests and validation scenarios
Simulation tests to confirm threading and history traversal Chain of communication and branching validation tests from live simulation Temperature enforcement per model
This commit is contained in:
@@ -12,6 +12,90 @@ class ProviderType(Enum):
|
||||
OPENAI = "openai"
|
||||
|
||||
|
||||
class TemperatureConstraint(ABC):
|
||||
"""Abstract base class for temperature constraints."""
|
||||
|
||||
@abstractmethod
|
||||
def validate(self, temperature: float) -> bool:
|
||||
"""Check if temperature is valid."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
"""Get nearest valid temperature."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str:
|
||||
"""Get human-readable description of constraint."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_default(self) -> float:
|
||||
"""Get model's default temperature."""
|
||||
pass
|
||||
|
||||
|
||||
class FixedTemperatureConstraint(TemperatureConstraint):
|
||||
"""For models that only support one temperature value (e.g., O3)."""
|
||||
|
||||
def __init__(self, value: float):
|
||||
self.value = value
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return self.value
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Only supports temperature={self.value}"
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.value
|
||||
|
||||
|
||||
class RangeTemperatureConstraint(TemperatureConstraint):
|
||||
"""For models supporting continuous temperature ranges."""
|
||||
|
||||
def __init__(self, min_temp: float, max_temp: float, default: float = None):
|
||||
self.min_temp = min_temp
|
||||
self.max_temp = max_temp
|
||||
self.default_temp = default or (min_temp + max_temp) / 2
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return self.min_temp <= temperature <= self.max_temp
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return max(self.min_temp, min(self.max_temp, temperature))
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.default_temp
|
||||
|
||||
|
||||
class DiscreteTemperatureConstraint(TemperatureConstraint):
|
||||
"""For models supporting only specific temperature values."""
|
||||
|
||||
def __init__(self, allowed_values: List[float], default: float = None):
|
||||
self.allowed_values = sorted(allowed_values)
|
||||
self.default_temp = default or allowed_values[len(allowed_values)//2]
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return min(self.allowed_values, key=lambda x: abs(x - temperature))
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Supports temperatures: {self.allowed_values}"
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.default_temp
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCapabilities:
|
||||
"""Capabilities and constraints for a specific model."""
|
||||
@@ -23,7 +107,24 @@ class ModelCapabilities:
|
||||
supports_system_prompts: bool = True
|
||||
supports_streaming: bool = True
|
||||
supports_function_calling: bool = False
|
||||
temperature_range: Tuple[float, float] = (0.0, 2.0)
|
||||
|
||||
# Temperature constraint object - preferred way to define temperature limits
|
||||
temperature_constraint: TemperatureConstraint = field(
|
||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||
)
|
||||
|
||||
# Backward compatibility property for existing code
|
||||
@property
|
||||
def temperature_range(self) -> Tuple[float, float]:
|
||||
"""Backward compatibility for existing code that uses temperature_range."""
|
||||
if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
|
||||
return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
|
||||
elif isinstance(self.temperature_constraint, FixedTemperatureConstraint):
|
||||
return (self.temperature_constraint.value, self.temperature_constraint.value)
|
||||
elif isinstance(self.temperature_constraint, DiscreteTemperatureConstraint):
|
||||
values = self.temperature_constraint.allowed_values
|
||||
return (min(values), max(values))
|
||||
return (0.0, 2.0) # Fallback
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -5,7 +5,13 @@ from typing import Dict, Optional, List
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from .base import ModelProvider, ModelResponse, ModelCapabilities, ProviderType
|
||||
from .base import (
|
||||
ModelProvider,
|
||||
ModelResponse,
|
||||
ModelCapabilities,
|
||||
ProviderType,
|
||||
RangeTemperatureConstraint
|
||||
)
|
||||
|
||||
|
||||
class GeminiModelProvider(ModelProvider):
|
||||
@@ -58,6 +64,9 @@ class GeminiModelProvider(ModelProvider):
|
||||
|
||||
config = self.SUPPORTED_MODELS[resolved_name]
|
||||
|
||||
# Gemini models support 0.0-2.0 temperature range
|
||||
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.GOOGLE,
|
||||
model_name=resolved_name,
|
||||
@@ -67,7 +76,7 @@ class GeminiModelProvider(ModelProvider):
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
temperature_range=(0.0, 2.0),
|
||||
temperature_constraint=temp_constraint,
|
||||
)
|
||||
|
||||
def generate_content(
|
||||
|
||||
@@ -6,7 +6,14 @@ import logging
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from .base import ModelProvider, ModelResponse, ModelCapabilities, ProviderType
|
||||
from .base import (
|
||||
ModelProvider,
|
||||
ModelResponse,
|
||||
ModelCapabilities,
|
||||
ProviderType,
|
||||
FixedTemperatureConstraint,
|
||||
RangeTemperatureConstraint
|
||||
)
|
||||
|
||||
|
||||
class OpenAIModelProvider(ModelProvider):
|
||||
@@ -51,6 +58,14 @@ class OpenAIModelProvider(ModelProvider):
|
||||
|
||||
config = self.SUPPORTED_MODELS[model_name]
|
||||
|
||||
# Define temperature constraints per model
|
||||
if model_name in ["o3", "o3-mini"]:
|
||||
# O3 models only support temperature=1.0
|
||||
temp_constraint = FixedTemperatureConstraint(1.0)
|
||||
else:
|
||||
# Other OpenAI models support 0.0-2.0 range
|
||||
temp_constraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.OPENAI,
|
||||
model_name=model_name,
|
||||
@@ -60,7 +75,7 @@ class OpenAIModelProvider(ModelProvider):
|
||||
supports_system_prompts=True,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
temperature_range=(0.0, 2.0),
|
||||
temperature_constraint=temp_constraint,
|
||||
)
|
||||
|
||||
def generate_content(
|
||||
|
||||
Reference in New Issue
Block a user