diff --git a/providers/custom.py b/providers/custom.py index a4bad33..2a96f3a 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -6,27 +6,7 @@ from typing import Optional from .openai_compatible import OpenAICompatibleProvider from .openrouter_registry import OpenRouterModelRegistry -from .shared import ( - FixedTemperatureConstraint, - ModelCapabilities, - ModelResponse, - ProviderType, - RangeTemperatureConstraint, -) - -# Temperature inference patterns -_TEMP_UNSUPPORTED_PATTERNS = [ - "o1", - "o3", - "o4", # OpenAI O-series models - "deepseek-reasoner", - "deepseek-r1", - "r1", # DeepSeek reasoner models -] - -_TEMP_UNSUPPORTED_KEYWORDS = [ - "reasoner", # DeepSeek reasoner variants -] +from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint class CustomProvider(OpenAICompatibleProvider): @@ -179,8 +159,10 @@ class CustomProvider(OpenAICompatibleProvider): "Consider adding to custom_models.json for specific capabilities." ) - # Infer temperature support from model name for better defaults - supports_temperature, temperature_reason = self._infer_temperature_support(resolved_name) + # Infer temperature behaviour for generic capability fallback + supports_temperature, temperature_constraint, temperature_reason = TemperatureConstraint.resolve_settings( + resolved_name + ) logging.warning( f"Model '{resolved_name}' not found in custom_models.json. Using generic capabilities with inferred settings. " @@ -200,11 +182,7 @@ class CustomProvider(OpenAICompatibleProvider): supports_streaming=True, supports_function_calling=False, # Conservative default supports_temperature=supports_temperature, - temperature_constraint=( - FixedTemperatureConstraint(1.0) - if not supports_temperature - else RangeTemperatureConstraint(0.0, 2.0, 0.7) - ), + temperature_constraint=temperature_constraint, ) # Mark as generic for validation purposes @@ -212,36 +190,6 @@ class CustomProvider(OpenAICompatibleProvider): return capabilities - def _infer_temperature_support(self, model_name: str) -> tuple[bool, str]: - """Infer temperature support from model name patterns. - - Returns: - Tuple of (supports_temperature, reason_for_decision) - """ - model_lower = model_name.lower() - - # Check for specific model patterns that don't support temperature - for pattern in _TEMP_UNSUPPORTED_PATTERNS: - conditions = ( - pattern == model_lower, - model_lower.startswith(f"{pattern}-"), - model_lower.startswith(f"openai/{pattern}"), - model_lower.startswith(f"deepseek/{pattern}"), - model_lower.endswith(f"-{pattern}"), - f"/{pattern}" in model_lower, - f"-{pattern}-" in model_lower, - ) - if any(conditions): - return False, f"detected non-temperature-supporting model pattern '{pattern}'" - - # Check for specific keywords that indicate non-supporting variants - for keyword in _TEMP_UNSUPPORTED_KEYWORDS: - if keyword in model_lower: - return False, f"detected non-temperature-supporting keyword '{keyword}'" - - # Default to supporting temperature for most models - return True, "default assumption for unknown custom models" - def get_provider_type(self) -> ProviderType: """Get the provider type.""" return ProviderType.CUSTOM diff --git a/providers/shared/temperature.py b/providers/shared/temperature.py index 6c6c9af..ec0adec 100644 --- a/providers/shared/temperature.py +++ b/providers/shared/temperature.py @@ -11,6 +11,21 @@ __all__ = [ "create_temperature_constraint", ] +# Common heuristics for determining temperature support when explicit +# capabilities are unavailable (e.g., custom/local models). +_TEMP_UNSUPPORTED_PATTERNS = { + "o1", + "o3", + "o4", # OpenAI O-series reasoning models + "deepseek-reasoner", + "deepseek-r1", + "r1", # DeepSeek reasoner variants +} + +_TEMP_UNSUPPORTED_KEYWORDS = { + "reasoner", # Catch additional DeepSeek-style naming patterns +} + class TemperatureConstraint(ABC): """Contract for temperature validation used by `ModelCapabilities`. @@ -41,6 +56,65 @@ class TemperatureConstraint(ABC): def get_default(self) -> float: """Return the default temperature for the model.""" + @staticmethod + def infer_support(model_name: str) -> tuple[bool, str]: + """Heuristically determine whether a model supports temperature.""" + + model_lower = model_name.lower() + + for pattern in _TEMP_UNSUPPORTED_PATTERNS: + conditions = ( + pattern == model_lower, + model_lower.startswith(f"{pattern}-"), + model_lower.startswith(f"openai/{pattern}"), + model_lower.startswith(f"deepseek/{pattern}"), + model_lower.endswith(f"-{pattern}"), + f"/{pattern}" in model_lower, + f"-{pattern}-" in model_lower, + ) + if any(conditions): + return False, f"detected pattern '{pattern}'" + + for keyword in _TEMP_UNSUPPORTED_KEYWORDS: + if keyword in model_lower: + return False, f"detected keyword '{keyword}'" + + return True, "default assumption for models without explicit metadata" + + @staticmethod + def resolve_settings( + model_name: str, + constraint_hint: Optional[str] = None, + ) -> tuple[bool, "TemperatureConstraint", str]: + """Derive temperature support and constraint for a model. + + Args: + model_name: Canonical model identifier or alias. + constraint_hint: Optional configuration hint (``"fixed"``, + ``"range"``, ``"discrete"``). When provided, the hint is + honoured directly. + + Returns: + Tuple ``(supports_temperature, constraint, diagnosis)`` describing + whether temperature may be tuned, the constraint object that should + be attached to :class:`ModelCapabilities`, and the reasoning behind + the decision. + """ + + if constraint_hint: + constraint = create_temperature_constraint(constraint_hint) + supports_temperature = constraint_hint != "fixed" + reason = f"constraint hint '{constraint_hint}'" + return supports_temperature, constraint, reason + + supports_temperature, reason = TemperatureConstraint.infer_support(model_name) + if supports_temperature: + constraint: TemperatureConstraint = RangeTemperatureConstraint(0.0, 2.0, 0.7) + else: + constraint = FixedTemperatureConstraint(1.0) + + return supports_temperature, constraint, reason + class FixedTemperatureConstraint(TemperatureConstraint): """Constraint for models that enforce an exact temperature (for example O3)."""