refactor: clean temperature inference

This commit is contained in:
Fahad
2025-10-02 10:41:05 +04:00
parent 6d237d0970
commit 9c11ecc4bf
2 changed files with 80 additions and 58 deletions

View File

@@ -6,27 +6,7 @@ from typing import Optional
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import (
FixedTemperatureConstraint,
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
# Temperature inference patterns.
# Name fragments of models known to reject the `temperature` parameter; used
# as a heuristic fallback when a model has no explicit capability entry.
_TEMP_UNSUPPORTED_PATTERNS = [
    "o1",
    "o3",
    "o4",  # OpenAI O-series models
    "deepseek-reasoner",
    "deepseek-r1",
    "r1",  # DeepSeek reasoner models
]
# Bare substrings that mark reasoning-only variants regardless of exact name.
_TEMP_UNSUPPORTED_KEYWORDS = [
    "reasoner",  # DeepSeek reasoner variants
]
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
class CustomProvider(OpenAICompatibleProvider):
@@ -179,8 +159,10 @@ class CustomProvider(OpenAICompatibleProvider):
"Consider adding to custom_models.json for specific capabilities."
)
# Infer temperature support from model name for better defaults
supports_temperature, temperature_reason = self._infer_temperature_support(resolved_name)
# Infer temperature behaviour for generic capability fallback
supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
resolved_name
)
logging.warning(
f"Model '{resolved_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
@@ -200,11 +182,7 @@ class CustomProvider(OpenAICompatibleProvider):
supports_streaming=True,
supports_function_calling=False, # Conservative default
supports_temperature=supports_temperature,
temperature_constraint=(
FixedTemperatureConstraint(1.0)
if not supports_temperature
else RangeTemperatureConstraint(0.0, 2.0, 0.7)
),
temperature_constraint=temperature_constraint,
)
# Mark as generic for validation purposes
@@ -212,36 +190,6 @@ class CustomProvider(OpenAICompatibleProvider):
return capabilities
def _infer_temperature_support(self, model_name: str) -> tuple[bool, str]:
    """Infer temperature support from model name patterns.

    Heuristic only: the lowercased model identifier is compared against
    known non-temperature-supporting name patterns and keywords.

    Returns:
        Tuple of (supports_temperature, reason_for_decision)
    """
    name = model_name.lower()

    def _matches(pattern: str) -> bool:
        # Exact name, dash-prefixed/suffixed variants, provider-qualified
        # forms (openai/..., deepseek/...), or an embedded path/dash segment.
        return (
            name == pattern
            or name.startswith((f"{pattern}-", f"openai/{pattern}", f"deepseek/{pattern}"))
            or name.endswith(f"-{pattern}")
            or f"/{pattern}" in name
            or f"-{pattern}-" in name
        )

    for pattern in _TEMP_UNSUPPORTED_PATTERNS:
        if _matches(pattern):
            return False, f"detected non-temperature-supporting model pattern '{pattern}'"

    for keyword in _TEMP_UNSUPPORTED_KEYWORDS:
        if keyword in name:
            return False, f"detected non-temperature-supporting keyword '{keyword}'"

    # Most custom/local models accept temperature, so default to supporting it.
    return True, "default assumption for unknown custom models"
def get_provider_type(self) -> ProviderType:
    """Return the provider type identifier for this provider (always CUSTOM)."""
    return ProviderType.CUSTOM

View File

@@ -11,6 +11,21 @@ __all__ = [
"create_temperature_constraint",
]
# Common heuristics for determining temperature support when explicit
# capabilities are unavailable (e.g., custom/local models). Sets are used
# because membership/iteration order does not matter here.
_TEMP_UNSUPPORTED_PATTERNS = {
    "o1",
    "o3",
    "o4",  # OpenAI O-series reasoning models
    "deepseek-reasoner",
    "deepseek-r1",
    "r1",  # DeepSeek reasoner variants
}
# Bare substrings that mark reasoning-only variants regardless of exact name.
_TEMP_UNSUPPORTED_KEYWORDS = {
    "reasoner",  # Catch additional DeepSeek-style naming patterns
}
class TemperatureConstraint(ABC):
"""Contract for temperature validation used by `ModelCapabilities`.
@@ -41,6 +56,65 @@ class TemperatureConstraint(ABC):
def get_default(self) -> float:
"""Return the default temperature for the model."""
@staticmethod
def infer_support(model_name: str) -> tuple[bool, str]:
    """Heuristically determine whether a model supports temperature.

    The lowercased identifier is checked against known unsupported name
    patterns (exact, dash-delimited, or provider-qualified) and keywords.

    Returns:
        Tuple of (supports_temperature, reason) explaining the decision.
    """
    lowered = model_name.lower()

    for pattern in _TEMP_UNSUPPORTED_PATTERNS:
        hit = (
            lowered == pattern
            or lowered.startswith((f"{pattern}-", f"openai/{pattern}", f"deepseek/{pattern}"))
            or lowered.endswith(f"-{pattern}")
            or f"/{pattern}" in lowered
            or f"-{pattern}-" in lowered
        )
        if hit:
            return False, f"detected pattern '{pattern}'"

    for keyword in _TEMP_UNSUPPORTED_KEYWORDS:
        if keyword in lowered:
            return False, f"detected keyword '{keyword}'"

    return True, "default assumption for models without explicit metadata"
@staticmethod
def resolve_settings(
    model_name: str,
    constraint_hint: Optional[str] = None,
) -> tuple[bool, "TemperatureConstraint", str]:
    """Derive temperature support and constraint for a model.

    Args:
        model_name: Canonical model identifier or alias.
        constraint_hint: Optional configuration hint (``"fixed"``,
            ``"range"``, ``"discrete"``). When provided, the hint is
            honoured directly and no name-based inference runs.

    Returns:
        Tuple ``(supports_temperature, constraint, diagnosis)``: whether
        temperature may be tuned, the constraint object to attach to
        :class:`ModelCapabilities`, and the reasoning behind the decision.
    """
    if constraint_hint:
        # Explicit configuration wins over heuristics; only a "fixed"
        # hint means the temperature cannot be tuned.
        return (
            constraint_hint != "fixed",
            create_temperature_constraint(constraint_hint),
            f"constraint hint '{constraint_hint}'",
        )

    supported, reason = TemperatureConstraint.infer_support(model_name)
    constraint: "TemperatureConstraint" = (
        RangeTemperatureConstraint(0.0, 2.0, 0.7)
        if supported
        else FixedTemperatureConstraint(1.0)
    )
    return supported, constraint, reason
class FixedTemperatureConstraint(TemperatureConstraint):
"""Constraint for models that enforce an exact temperature (for example O3)."""