refactor: code cleanup
This commit is contained in:
23
providers/shared/__init__.py
Normal file
23
providers/shared/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Shared data structures and helpers for model providers."""
|
||||
|
||||
from .model_capabilities import ModelCapabilities
|
||||
from .model_response import ModelResponse
|
||||
from .provider_type import ProviderType
|
||||
from .temperature import (
|
||||
DiscreteTemperatureConstraint,
|
||||
FixedTemperatureConstraint,
|
||||
RangeTemperatureConstraint,
|
||||
TemperatureConstraint,
|
||||
create_temperature_constraint,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ModelCapabilities",
|
||||
"ModelResponse",
|
||||
"ProviderType",
|
||||
"TemperatureConstraint",
|
||||
"FixedTemperatureConstraint",
|
||||
"RangeTemperatureConstraint",
|
||||
"DiscreteTemperatureConstraint",
|
||||
"create_temperature_constraint",
|
||||
]
|
||||
34
providers/shared/model_capabilities.py
Normal file
34
providers/shared/model_capabilities.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Dataclass describing the feature set of a model exposed by a provider."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .provider_type import ProviderType
|
||||
from .temperature import RangeTemperatureConstraint, TemperatureConstraint
|
||||
|
||||
__all__ = ["ModelCapabilities"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCapabilities:
|
||||
"""Static capabilities and constraints for a provider-managed model."""
|
||||
|
||||
provider: ProviderType
|
||||
model_name: str
|
||||
friendly_name: str
|
||||
context_window: int
|
||||
max_output_tokens: int
|
||||
supports_extended_thinking: bool = False
|
||||
supports_system_prompts: bool = True
|
||||
supports_streaming: bool = True
|
||||
supports_function_calling: bool = False
|
||||
supports_images: bool = False
|
||||
max_image_size_mb: float = 0.0
|
||||
supports_temperature: bool = True
|
||||
description: str = ""
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
supports_json_mode: bool = False
|
||||
max_thinking_tokens: int = 0
|
||||
is_custom: bool = False
|
||||
temperature_constraint: TemperatureConstraint = field(
|
||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||
)
|
||||
26
providers/shared/model_response.py
Normal file
26
providers/shared/model_response.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Dataclass used to normalise provider SDK responses."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from .provider_type import ProviderType
|
||||
|
||||
__all__ = ["ModelResponse"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelResponse:
|
||||
"""Portable representation of a provider completion."""
|
||||
|
||||
content: str
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
model_name: str = ""
|
||||
friendly_name: str = ""
|
||||
provider: ProviderType = ProviderType.GOOGLE
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Return the total token count if the provider reported usage data."""
|
||||
|
||||
return self.usage.get("total_tokens", 0)
|
||||
16
providers/shared/provider_type.py
Normal file
16
providers/shared/provider_type.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Enumeration describing which backend owns a given model."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
__all__ = ["ProviderType"]
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
"""Canonical identifiers for every supported provider backend."""
|
||||
|
||||
GOOGLE = "google"
|
||||
OPENAI = "openai"
|
||||
XAI = "xai"
|
||||
OPENROUTER = "openrouter"
|
||||
CUSTOM = "custom"
|
||||
DIAL = "dial"
|
||||
121
providers/shared/temperature.py
Normal file
121
providers/shared/temperature.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Helper types for validating model temperature parameters."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
__all__ = [
|
||||
"TemperatureConstraint",
|
||||
"FixedTemperatureConstraint",
|
||||
"RangeTemperatureConstraint",
|
||||
"DiscreteTemperatureConstraint",
|
||||
"create_temperature_constraint",
|
||||
]
|
||||
|
||||
|
||||
class TemperatureConstraint(ABC):
|
||||
"""Contract for temperature validation used by `ModelCapabilities`.
|
||||
|
||||
Concrete providers describe their temperature behaviour by creating
|
||||
subclasses that expose three operations:
|
||||
* `validate` – decide whether a requested temperature is acceptable.
|
||||
* `get_corrected_value` – coerce out-of-range values into a safe default.
|
||||
* `get_description` – provide a human readable error message for users.
|
||||
|
||||
Providers call these hooks before sending traffic to the underlying API so
|
||||
that unsupported temperatures never reach the remote service.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def validate(self, temperature: float) -> bool:
|
||||
"""Return ``True`` when the temperature may be sent to the backend."""
|
||||
|
||||
@abstractmethod
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
"""Return a valid substitute for an out-of-range temperature."""
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str:
|
||||
"""Describe the acceptable range to include in error messages."""
|
||||
|
||||
@abstractmethod
|
||||
def get_default(self) -> float:
|
||||
"""Return the default temperature for the model."""
|
||||
|
||||
|
||||
class FixedTemperatureConstraint(TemperatureConstraint):
|
||||
"""Constraint for models that enforce an exact temperature (for example O3)."""
|
||||
|
||||
def __init__(self, value: float):
|
||||
self.value = value
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return self.value
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Only supports temperature={self.value}"
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.value
|
||||
|
||||
|
||||
class RangeTemperatureConstraint(TemperatureConstraint):
|
||||
"""Constraint for providers that expose a continuous min/max temperature range."""
|
||||
|
||||
def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None):
|
||||
self.min_temp = min_temp
|
||||
self.max_temp = max_temp
|
||||
self.default_temp = default or (min_temp + max_temp) / 2
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return self.min_temp <= temperature <= self.max_temp
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return max(self.min_temp, min(self.max_temp, temperature))
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.default_temp
|
||||
|
||||
|
||||
class DiscreteTemperatureConstraint(TemperatureConstraint):
|
||||
"""Constraint for models that permit a discrete list of temperature values."""
|
||||
|
||||
def __init__(self, allowed_values: list[float], default: Optional[float] = None):
|
||||
self.allowed_values = sorted(allowed_values)
|
||||
self.default_temp = default or allowed_values[len(allowed_values) // 2]
|
||||
|
||||
def validate(self, temperature: float) -> bool:
|
||||
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
|
||||
|
||||
def get_corrected_value(self, temperature: float) -> float:
|
||||
return min(self.allowed_values, key=lambda x: abs(x - temperature))
|
||||
|
||||
def get_description(self) -> str:
|
||||
return f"Supports temperatures: {self.allowed_values}"
|
||||
|
||||
def get_default(self) -> float:
|
||||
return self.default_temp
|
||||
|
||||
|
||||
def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
|
||||
"""Factory that yields the appropriate constraint for a model configuration.
|
||||
|
||||
The JSON configuration stored in ``conf/custom_models.json`` references this
|
||||
helper via human-readable strings. Providers feed those values into this
|
||||
function so that runtime logic can rely on strongly typed constraint
|
||||
objects.
|
||||
"""
|
||||
|
||||
if constraint_type == "fixed":
|
||||
# Fixed temperature models (O3/O4) only support temperature=1.0
|
||||
return FixedTemperatureConstraint(1.0)
|
||||
if constraint_type == "discrete":
|
||||
# For models with specific allowed values - using common OpenAI values as default
|
||||
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
|
||||
# Default range constraint (for "range" or None)
|
||||
return RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||
Reference in New Issue
Block a user