Merge branch 'main' into refactor-image-validation
This commit is contained in:
@@ -7,7 +7,10 @@ import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tools.models import ToolModelCategory
|
||||
|
||||
from utils.file_types import IMAGES, get_image_mime_type
|
||||
|
||||
@@ -123,10 +126,10 @@ def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint
|
||||
return FixedTemperatureConstraint(1.0)
|
||||
elif constraint_type == "discrete":
|
||||
# For models with specific allowed values - using common OpenAI values as default
|
||||
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.7)
|
||||
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
|
||||
else:
|
||||
# Default range constraint (for "range" or None)
|
||||
return RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||
return RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -159,24 +162,11 @@ class ModelCapabilities:
|
||||
# Custom model flag (for models that only work with custom endpoints)
|
||||
is_custom: bool = False # Whether this model requires custom API endpoints
|
||||
|
||||
# Temperature constraint object - preferred way to define temperature limits
|
||||
# Temperature constraint object - defines temperature limits and behavior
|
||||
temperature_constraint: TemperatureConstraint = field(
|
||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||
)
|
||||
|
||||
# Backward compatibility property for existing code
|
||||
@property
|
||||
def temperature_range(self) -> tuple[float, float]:
|
||||
"""Backward compatibility for existing code that uses temperature_range."""
|
||||
if isinstance(self.temperature_constraint, RangeTemperatureConstraint):
|
||||
return (self.temperature_constraint.min_temp, self.temperature_constraint.max_temp)
|
||||
elif isinstance(self.temperature_constraint, FixedTemperatureConstraint):
|
||||
return (self.temperature_constraint.value, self.temperature_constraint.value)
|
||||
elif isinstance(self.temperature_constraint, DiscreteTemperatureConstraint):
|
||||
values = self.temperature_constraint.allowed_values
|
||||
return (min(values), max(values))
|
||||
return (0.0, 2.0) # Fallback
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelResponse:
|
||||
@@ -220,7 +210,7 @@ class ModelProvider(ABC):
|
||||
prompt: str,
|
||||
model_name: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
temperature: float = 0.3,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
@@ -276,18 +266,15 @@ class ModelProvider(ABC):
|
||||
if not capabilities.supports_temperature:
|
||||
return None
|
||||
|
||||
# Get temperature range
|
||||
min_temp, max_temp = capabilities.temperature_range
|
||||
# Use temperature constraint to get corrected value
|
||||
corrected_temp = capabilities.temperature_constraint.get_corrected_value(requested_temperature)
|
||||
|
||||
# Clamp to valid range
|
||||
if requested_temperature < min_temp:
|
||||
logger.debug(f"Clamping temperature from {requested_temperature} to {min_temp} for model {model_name}")
|
||||
return min_temp
|
||||
elif requested_temperature > max_temp:
|
||||
logger.debug(f"Clamping temperature from {requested_temperature} to {max_temp} for model {model_name}")
|
||||
return max_temp
|
||||
else:
|
||||
return requested_temperature
|
||||
if corrected_temp != requested_temperature:
|
||||
logger.debug(
|
||||
f"Adjusting temperature from {requested_temperature} to {corrected_temp} for model {model_name}"
|
||||
)
|
||||
|
||||
return corrected_temp
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not determine effective temperature for {model_name}: {e}")
|
||||
@@ -302,10 +289,10 @@ class ModelProvider(ABC):
|
||||
"""
|
||||
capabilities = self.get_capabilities(model_name)
|
||||
|
||||
# Validate temperature
|
||||
min_temp, max_temp = capabilities.temperature_range
|
||||
if not min_temp <= temperature <= max_temp:
|
||||
raise ValueError(f"Temperature {temperature} out of range [{min_temp}, {max_temp}] for model {model_name}")
|
||||
# Validate temperature using constraint
|
||||
if not capabilities.temperature_constraint.validate(temperature):
|
||||
constraint_desc = capabilities.temperature_constraint.get_description()
|
||||
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
|
||||
|
||||
@abstractmethod
|
||||
def supports_thinking_mode(self, model_name: str) -> bool:
|
||||
@@ -520,3 +507,28 @@ class ModelProvider(ABC):
|
||||
"""
|
||||
# Base implementation: no resources to clean up
|
||||
return
|
||||
|
||||
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
||||
"""Get the preferred model from this provider for a given category.
|
||||
|
||||
Args:
|
||||
category: The tool category requiring a model
|
||||
allowed_models: Pre-filtered list of model names that are allowed by restrictions
|
||||
|
||||
Returns:
|
||||
Model name if this provider has a preference, None otherwise
|
||||
"""
|
||||
# Default implementation - providers can override with specific logic
|
||||
return None
|
||||
|
||||
def get_model_registry(self) -> Optional[dict[str, Any]]:
|
||||
"""Get the model registry for providers that maintain one.
|
||||
|
||||
This is a hook method for providers like CustomProvider that maintain
|
||||
a dynamic model registry.
|
||||
|
||||
Returns:
|
||||
Model registry dict or None if not applicable
|
||||
"""
|
||||
# Default implementation - most providers don't have a registry
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user