refactor: clean up temperature inference
@@ -6,27 +6,7 @@ from typing import Optional
 from .openai_compatible import OpenAICompatibleProvider
 from .openrouter_registry import OpenRouterModelRegistry
-from .shared import (
-    FixedTemperatureConstraint,
-    ModelCapabilities,
-    ModelResponse,
-    ProviderType,
-    RangeTemperatureConstraint,
-)
-
-# Temperature inference patterns
-_TEMP_UNSUPPORTED_PATTERNS = [
-    "o1",
-    "o3",
-    "o4",  # OpenAI O-series models
-    "deepseek-reasoner",
-    "deepseek-r1",
-    "r1",  # DeepSeek reasoner models
-]
-
-_TEMP_UNSUPPORTED_KEYWORDS = [
-    "reasoner",  # DeepSeek reasoner variants
-]
+from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint


 class CustomProvider(OpenAICompatibleProvider):
@@ -179,8 +159,10 @@ class CustomProvider(OpenAICompatibleProvider):
                 "Consider adding to custom_models.json for specific capabilities."
             )

-            # Infer temperature support from model name for better defaults
-            supports_temperature, temperature_reason = self._infer_temperature_support(resolved_name)
+            # Infer temperature behaviour for generic capability fallback
+            supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings(
+                resolved_name
+            )

             logging.warning(
                 f"Model '{resolved_name}' not found in custom_models.json. Using generic capabilities with inferred settings. "
@@ -200,11 +182,7 @@ class CustomProvider(OpenAICompatibleProvider):
                 supports_streaming=True,
                 supports_function_calling=False,  # Conservative default
                 supports_temperature=supports_temperature,
-                temperature_constraint=(
-                    FixedTemperatureConstraint(1.0)
-                    if not supports_temperature
-                    else RangeTemperatureConstraint(0.0, 2.0, 0.7)
-                ),
+                temperature_constraint=temperature_constraint,
             )

             # Mark as generic for validation purposes
@@ -212,36 +190,6 @@ class CustomProvider(OpenAICompatibleProvider):

             return capabilities

-    def _infer_temperature_support(self, model_name: str) -> tuple[bool, str]:
-        """Infer temperature support from model name patterns.
-
-        Returns:
-            Tuple of (supports_temperature, reason_for_decision)
-        """
-        model_lower = model_name.lower()
-
-        # Check for specific model patterns that don't support temperature
-        for pattern in _TEMP_UNSUPPORTED_PATTERNS:
-            conditions = (
-                pattern == model_lower,
-                model_lower.startswith(f"{pattern}-"),
-                model_lower.startswith(f"openai/{pattern}"),
-                model_lower.startswith(f"deepseek/{pattern}"),
-                model_lower.endswith(f"-{pattern}"),
-                f"/{pattern}" in model_lower,
-                f"-{pattern}-" in model_lower,
-            )
-            if any(conditions):
-                return False, f"detected non-temperature-supporting model pattern '{pattern}'"
-
-        # Check for specific keywords that indicate non-supporting variants
-        for keyword in _TEMP_UNSUPPORTED_KEYWORDS:
-            if keyword in model_lower:
-                return False, f"detected non-temperature-supporting keyword '{keyword}'"
-
-        # Default to supporting temperature for most models
-        return True, "default assumption for unknown custom models"
-
     def get_provider_type(self) -> ProviderType:
         """Get the provider type."""
         return ProviderType.CUSTOM
@@ -11,6 +11,21 @@ __all__ = [
     "create_temperature_constraint",
 ]

+# Common heuristics for determining temperature support when explicit
+# capabilities are unavailable (e.g., custom/local models).
+_TEMP_UNSUPPORTED_PATTERNS = {
+    "o1",
+    "o3",
+    "o4",  # OpenAI O-series reasoning models
+    "deepseek-reasoner",
+    "deepseek-r1",
+    "r1",  # DeepSeek reasoner variants
+}
+
+_TEMP_UNSUPPORTED_KEYWORDS = {
+    "reasoner",  # Catch additional DeepSeek-style naming patterns
+}
+

 class TemperatureConstraint(ABC):
     """Contract for temperature validation used by `ModelCapabilities`.
@@ -41,6 +56,65 @@ class TemperatureConstraint(ABC):
     def get_default(self) -> float:
         """Return the default temperature for the model."""

+    @staticmethod
+    def infer_support(model_name: str) -> tuple[bool, str]:
+        """Heuristically determine whether a model supports temperature."""
+
+        model_lower = model_name.lower()
+
+        for pattern in _TEMP_UNSUPPORTED_PATTERNS:
+            conditions = (
+                pattern == model_lower,
+                model_lower.startswith(f"{pattern}-"),
+                model_lower.startswith(f"openai/{pattern}"),
+                model_lower.startswith(f"deepseek/{pattern}"),
+                model_lower.endswith(f"-{pattern}"),
+                f"/{pattern}" in model_lower,
+                f"-{pattern}-" in model_lower,
+            )
+            if any(conditions):
+                return False, f"detected pattern '{pattern}'"
+
+        for keyword in _TEMP_UNSUPPORTED_KEYWORDS:
+            if keyword in model_lower:
+                return False, f"detected keyword '{keyword}'"
+
+        return True, "default assumption for models without explicit metadata"
+
+    @staticmethod
+    def resolve_settings(
+        model_name: str,
+        constraint_hint: Optional[str] = None,
+    ) -> tuple[bool, "TemperatureConstraint", str]:
+        """Derive temperature support and constraint for a model.
+
+        Args:
+            model_name: Canonical model identifier or alias.
+            constraint_hint: Optional configuration hint (``"fixed"``,
+                ``"range"``, ``"discrete"``). When provided, the hint is
+                honoured directly.
+
+        Returns:
+            Tuple ``(supports_temperature, constraint, diagnosis)`` describing
+            whether temperature may be tuned, the constraint object that should
+            be attached to :class:`ModelCapabilities`, and the reasoning behind
+            the decision.
+        """
+
+        if constraint_hint:
+            constraint = create_temperature_constraint(constraint_hint)
+            supports_temperature = constraint_hint != "fixed"
+            reason = f"constraint hint '{constraint_hint}'"
+            return supports_temperature, constraint, reason
+
+        supports_temperature, reason = TemperatureConstraint.infer_support(model_name)
+        if supports_temperature:
+            constraint: TemperatureConstraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
+        else:
+            constraint = FixedTemperatureConstraint(1.0)
+
+        return supports_temperature, constraint, reason
+

 class FixedTemperatureConstraint(TemperatureConstraint):
     """Constraint for models that enforce an exact temperature (for example O3)."""
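For reference, a minimal sketch of how the consolidated helper is expected to behave, based only on the heuristics visible in this diff. The import path providers.shared is an assumption, as is the local model name used in the last call; adjust both to match the repository.

# Sketch (assumed import path) exercising TemperatureConstraint.resolve_settings
# as introduced in this commit; the expected results follow the diff's heuristics.
from providers.shared import (  # assumed module path
    FixedTemperatureConstraint,
    RangeTemperatureConstraint,
    TemperatureConstraint,
)

# O-series / DeepSeek R1-style names match _TEMP_UNSUPPORTED_PATTERNS,
# so temperature tuning is disabled and pinned to FixedTemperatureConstraint(1.0).
supported, constraint, reason = TemperatureConstraint.resolve_settings("openai/o3-mini")
assert supported is False
assert isinstance(constraint, FixedTemperatureConstraint)

# Unknown custom models fall through to the permissive default range (0.0 to 2.0, default 0.7).
supported, constraint, reason = TemperatureConstraint.resolve_settings("llama-3.1-70b")
assert supported is True
assert isinstance(constraint, RangeTemperatureConstraint)

# An explicit constraint hint (e.g. from custom_models.json-style config) wins over inference;
# "my-local-model" is a hypothetical name for illustration only.
supported, constraint, reason = TemperatureConstraint.resolve_settings(
    "my-local-model", constraint_hint="fixed"
)
assert supported is False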