refactor: move shared provider datatypes (ProviderType, ModelCapabilities, ModelResponse, temperature constraints) out of providers/base.py into a new providers/shared package and expand docstrings
providers/__init__.py
@@ -1,11 +1,12 @@
 """Model provider abstractions for supporting multiple AI providers."""
 
-from .base import ModelCapabilities, ModelProvider, ModelResponse
+from .base import ModelProvider
 from .gemini import GeminiModelProvider
 from .openai_compatible import OpenAICompatibleProvider
 from .openai_provider import OpenAIModelProvider
 from .openrouter import OpenRouterProvider
 from .registry import ModelProviderRegistry
+from .shared import ModelCapabilities, ModelResponse
 
 __all__ = [
     "ModelProvider",
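Because the package root still re-exports the moved datatypes, call sites that import them from `providers` are unaffected. A quick sanity check (a sketch, assuming the package is importable):

```python
# The refactor changes where the classes are defined, not where they are exported.
from providers import ModelCapabilities, ModelProvider, ModelResponse

# ModelCapabilities now lives in the new shared package.
print(ModelCapabilities.__module__)  # providers.shared.model_capabilities
```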
providers/base.py
@@ -1,12 +1,10 @@
-"""Base model provider interface and data classes."""
+"""Base interfaces and common behaviour for model providers."""
 
 import base64
 import binascii
 import logging
 import os
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, field
-from enum import Enum
 from typing import TYPE_CHECKING, Any, Optional
 
 if TYPE_CHECKING:
@@ -14,179 +12,20 @@ if TYPE_CHECKING:
 
 from utils.file_types import IMAGES, get_image_mime_type
 
+from .shared import ModelCapabilities, ModelResponse, ProviderType
+
 logger = logging.getLogger(__name__)
 
 
-class ProviderType(Enum):
-    """Supported model provider types."""
-
-    GOOGLE = "google"
-    OPENAI = "openai"
-    XAI = "xai"
-    OPENROUTER = "openrouter"
-    CUSTOM = "custom"
-    DIAL = "dial"
-
-
-class TemperatureConstraint(ABC):
-    """Abstract base class for temperature constraints."""
-
-    @abstractmethod
-    def validate(self, temperature: float) -> bool:
-        """Check if temperature is valid."""
-        pass
-
-    @abstractmethod
-    def get_corrected_value(self, temperature: float) -> float:
-        """Get nearest valid temperature."""
-        pass
-
-    @abstractmethod
-    def get_description(self) -> str:
-        """Get human-readable description of constraint."""
-        pass
-
-    @abstractmethod
-    def get_default(self) -> float:
-        """Get model's default temperature."""
-        pass
-
-
-class FixedTemperatureConstraint(TemperatureConstraint):
-    """For models that only support one temperature value (e.g., O3)."""
-
-    def __init__(self, value: float):
-        self.value = value
-
-    def validate(self, temperature: float) -> bool:
-        return abs(temperature - self.value) < 1e-6  # Handle floating point precision
-
-    def get_corrected_value(self, temperature: float) -> float:
-        return self.value
-
-    def get_description(self) -> str:
-        return f"Only supports temperature={self.value}"
-
-    def get_default(self) -> float:
-        return self.value
-
-
-class RangeTemperatureConstraint(TemperatureConstraint):
-    """For models supporting continuous temperature ranges."""
-
-    def __init__(self, min_temp: float, max_temp: float, default: float = None):
-        self.min_temp = min_temp
-        self.max_temp = max_temp
-        self.default_temp = default or (min_temp + max_temp) / 2
-
-    def validate(self, temperature: float) -> bool:
-        return self.min_temp <= temperature <= self.max_temp
-
-    def get_corrected_value(self, temperature: float) -> float:
-        return max(self.min_temp, min(self.max_temp, temperature))
-
-    def get_description(self) -> str:
-        return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
-
-    def get_default(self) -> float:
-        return self.default_temp
-
-
-class DiscreteTemperatureConstraint(TemperatureConstraint):
-    """For models supporting only specific temperature values."""
-
-    def __init__(self, allowed_values: list[float], default: float = None):
-        self.allowed_values = sorted(allowed_values)
-        self.default_temp = default or allowed_values[len(allowed_values) // 2]
-
-    def validate(self, temperature: float) -> bool:
-        return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
-
-    def get_corrected_value(self, temperature: float) -> float:
-        return min(self.allowed_values, key=lambda x: abs(x - temperature))
-
-    def get_description(self) -> str:
-        return f"Supports temperatures: {self.allowed_values}"
-
-    def get_default(self) -> float:
-        return self.default_temp
-
-
-def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
-    """Create temperature constraint object from configuration string.
-
-    Args:
-        constraint_type: Type of constraint ("fixed", "range", "discrete")
-
-    Returns:
-        TemperatureConstraint object based on configuration
-    """
-    if constraint_type == "fixed":
-        # Fixed temperature models (O3/O4) only support temperature=1.0
-        return FixedTemperatureConstraint(1.0)
-    elif constraint_type == "discrete":
-        # For models with specific allowed values - using common OpenAI values as default
-        return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
-    else:
-        # Default range constraint (for "range" or None)
-        return RangeTemperatureConstraint(0.0, 2.0, 0.3)
-
-
-@dataclass
-class ModelCapabilities:
-    """Capabilities and constraints for a specific model."""
-
-    provider: ProviderType
-    model_name: str
-    friendly_name: str  # Human-friendly name like "Gemini" or "OpenAI"
-    context_window: int  # Total context window size in tokens
-    max_output_tokens: int  # Maximum output tokens per request
-    supports_extended_thinking: bool = False
-    supports_system_prompts: bool = True
-    supports_streaming: bool = True
-    supports_function_calling: bool = False
-    supports_images: bool = False  # Whether model can process images
-    max_image_size_mb: float = 0.0  # Maximum total size for all images in MB
-    supports_temperature: bool = True  # Whether model accepts temperature parameter in API calls
-
-    # Additional fields for comprehensive model information
-    description: str = ""  # Human-readable description of the model
-    aliases: list[str] = field(default_factory=list)  # Alternative names/shortcuts for the model
-
-    # JSON mode support (for providers that support structured output)
-    supports_json_mode: bool = False
-
-    # Thinking mode support (for models with thinking capabilities)
-    max_thinking_tokens: int = 0  # Maximum thinking tokens for extended reasoning models
-
-    # Custom model flag (for models that only work with custom endpoints)
-    is_custom: bool = False  # Whether this model requires custom API endpoints
-
-    # Temperature constraint object - defines temperature limits and behavior
-    temperature_constraint: TemperatureConstraint = field(
-        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
-    )
-
-
-@dataclass
-class ModelResponse:
-    """Response from a model provider."""
-
-    content: str
-    usage: dict[str, int] = field(default_factory=dict)  # input_tokens, output_tokens, total_tokens
-    model_name: str = ""
-    friendly_name: str = ""  # Human-friendly name like "Gemini" or "OpenAI"
-    provider: ProviderType = ProviderType.GOOGLE
-    metadata: dict[str, Any] = field(default_factory=dict)  # Provider-specific metadata
-
-    @property
-    def total_tokens(self) -> int:
-        """Get total tokens used."""
-        return self.usage.get("total_tokens", 0)
-
-
 class ModelProvider(ABC):
-    """Abstract base class for model providers."""
+    """Defines the contract implemented by every model provider backend.
+
+    Subclasses adapt third-party SDKs into the MCP server by exposing
+    capability metadata, request execution, and token counting through a
+    consistent interface. Shared helper methods (temperature validation,
+    alias resolution, image handling, etc.) live here so individual providers
+    only need to focus on provider-specific details.
+    """
 
     # All concrete providers must define their supported models
     SUPPORTED_MODELS: dict[str, Any] = {}
 
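The new docstring describes the pre-flight checks the base class performs. A minimal sketch of the temperature flow it implies (the wrapper helper in `base.py` is not shown in this diff, so the check is spelled out inline; the numbers are illustrative, not the real catalogue entry):

```python
from providers.shared import ModelCapabilities, ProviderType

caps = ModelCapabilities(
    provider=ProviderType.OPENAI,
    model_name="o3",
    friendly_name="OpenAI",
    context_window=200_000,   # illustrative values
    max_output_tokens=65_536,
)

requested = 2.5
constraint = caps.temperature_constraint  # defaults to RangeTemperatureConstraint(0.0, 2.0, 0.3)
if not constraint.validate(requested):
    # Out-of-range values are coerced before any traffic reaches the remote API.
    requested = constraint.get_corrected_value(requested)  # -> 2.0
```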
providers/custom.py
@@ -4,15 +4,15 @@ import logging
 import os
 from typing import Optional
 
-from .base import (
+from .openai_compatible import OpenAICompatibleProvider
+from .openrouter_registry import OpenRouterModelRegistry
+from .shared import (
     FixedTemperatureConstraint,
     ModelCapabilities,
     ModelResponse,
     ProviderType,
     RangeTemperatureConstraint,
 )
-from .openai_compatible import OpenAICompatibleProvider
-from .openrouter_registry import OpenRouterModelRegistry
 
 # Temperature inference patterns
 _TEMP_UNSUPPORTED_PATTERNS = [
@@ -30,10 +30,13 @@ _TEMP_UNSUPPORTED_KEYWORDS = [
 
 
 class CustomProvider(OpenAICompatibleProvider):
-    """Custom API provider for local models.
+    """Adapter for self-hosted or local OpenAI-compatible endpoints.
 
-    Supports local inference servers like Ollama, vLLM, LM Studio,
-    and any OpenAI-compatible API endpoint.
+    The provider reuses the :mod:`providers.shared` registry to surface
+    user-defined aliases and capability metadata. It also normalises
+    Ollama-style version tags (``model:latest``) and enforces the same
+    restriction policies used by cloud providers, ensuring consistent
+    behaviour regardless of where the model is hosted.
     """
 
     FRIENDLY_NAME = "Custom API"
providers/dial.py
@@ -6,22 +6,24 @@ import threading
 import time
 from typing import Optional
 
-from .base import (
+from .openai_compatible import OpenAICompatibleProvider
+from .shared import (
     ModelCapabilities,
     ModelResponse,
     ProviderType,
     create_temperature_constraint,
 )
-from .openai_compatible import OpenAICompatibleProvider
 
 logger = logging.getLogger(__name__)
 
 
 class DIALModelProvider(OpenAICompatibleProvider):
-    """DIAL provider using OpenAI-compatible API.
+    """Client for the DIAL (Data & AI Layer) aggregation service.
 
-    DIAL provides access to various AI models through a unified API interface.
-    Supports GPT, Claude, Gemini, and other models via DIAL deployments.
+    DIAL exposes several third-party models behind a single OpenAI-compatible
+    endpoint. This provider wraps the service, publishes capability metadata
+    for the known deployments, and centralises retry/backoff settings tailored
+    to DIAL's latency characteristics.
     """
 
     FRIENDLY_NAME = "DIAL"
providers/gemini.py
@@ -11,13 +11,24 @@ if TYPE_CHECKING:
     from google import genai
     from google.genai import types
 
-from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint
+from .base import ModelProvider
+from .shared import (
+    ModelCapabilities,
+    ModelResponse,
+    ProviderType,
+    create_temperature_constraint,
+)
 
 logger = logging.getLogger(__name__)
 
 
 class GeminiModelProvider(ModelProvider):
-    """Google Gemini model provider implementation."""
+    """First-party Gemini integration built on the official Google SDK.
+
+    The provider advertises detailed thinking-mode budgets, handles optional
+    custom endpoints, and performs image pre-processing before forwarding a
+    request to the Gemini APIs.
+    """
 
     # Model configurations using ModelCapabilities objects
     SUPPORTED_MODELS = {
providers/openai_compatible.py
@@ -11,21 +11,21 @@ from urllib.parse import urlparse
 
 from openai import OpenAI
 
-from .base import (
+from .base import ModelProvider
+from .shared import (
     ModelCapabilities,
-    ModelProvider,
     ModelResponse,
     ProviderType,
 )
 
 
 class OpenAICompatibleProvider(ModelProvider):
-    """Base class for any provider using an OpenAI-compatible API.
+    """Shared implementation for OpenAI API lookalikes.
 
-    This includes:
-    - Direct OpenAI API
-    - OpenRouter
-    - Any other OpenAI-compatible endpoint
+    The class owns HTTP client configuration (timeouts, proxy hardening,
+    custom headers) and normalises the OpenAI SDK responses into
+    :class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
+    provide capability metadata and any provider-specific request tweaks.
     """
 
     DEFAULT_HEADERS = {}
providers/openai_provider.py
@@ -6,19 +6,24 @@ from typing import TYPE_CHECKING, Optional
 if TYPE_CHECKING:
     from tools.models import ToolModelCategory
 
-from .base import (
+from .openai_compatible import OpenAICompatibleProvider
+from .shared import (
     ModelCapabilities,
     ModelResponse,
     ProviderType,
     create_temperature_constraint,
 )
-from .openai_compatible import OpenAICompatibleProvider
 
 logger = logging.getLogger(__name__)
 
 
 class OpenAIModelProvider(OpenAICompatibleProvider):
-    """Official OpenAI API provider (api.openai.com)."""
+    """Implementation that talks to api.openai.com using rich model metadata.
+
+    In addition to the built-in catalogue, the provider can surface models
+    defined in ``conf/custom_models.json`` (for organisations running their own
+    OpenAI-compatible gateways) while still respecting restriction policies.
+    """
 
     # Model configurations using ModelCapabilities objects
     SUPPORTED_MODELS = {
providers/openrouter.py
@@ -4,21 +4,22 @@ import logging
 import os
 from typing import Optional
 
-from .base import (
+from .openai_compatible import OpenAICompatibleProvider
+from .openrouter_registry import OpenRouterModelRegistry
+from .shared import (
     ModelCapabilities,
     ModelResponse,
     ProviderType,
     RangeTemperatureConstraint,
 )
-from .openai_compatible import OpenAICompatibleProvider
-from .openrouter_registry import OpenRouterModelRegistry
 
 
 class OpenRouterProvider(OpenAICompatibleProvider):
-    """OpenRouter unified API provider.
+    """Client for OpenRouter's multi-model aggregation service.
 
-    OpenRouter provides access to multiple AI models through a single API endpoint.
-    See https://openrouter.ai for available models and pricing.
+    OpenRouter surfaces dozens of upstream vendors. This provider layers alias
+    resolution, restriction-aware filtering, and sensible capability defaults
+    on top of the generic OpenAI-compatible plumbing.
     """
 
     FRIENDLY_NAME = "OpenRouter"
providers/openrouter_registry.py
@@ -9,7 +9,7 @@ from typing import Optional
 # Import handled via importlib.resources.files() calls directly
 from utils.file_utils import read_json_file
 
-from .base import (
+from .shared import (
     ModelCapabilities,
     ProviderType,
     create_temperature_constraint,
@@ -17,7 +17,13 @@ from .base import (
 
 
 class OpenRouterModelRegistry:
-    """Registry for managing OpenRouter model configurations and aliases."""
+    """Loads and validates the OpenRouter/custom model catalogue.
+
+    The registry parses ``conf/custom_models.json`` (or an override supplied via
+    environment variable), builds case-insensitive alias maps, and exposes
+    :class:`~providers.shared.ModelCapabilities` objects used by several
+    providers.
+    """
 
     def __init__(self, config_path: Optional[str] = None):
         """Initialize the registry.
@@ -263,6 +269,11 @@ class OpenRouterModelRegistry:
         # Registry now returns ModelCapabilities directly
         return self.resolve(name_or_alias)
 
+    def get_model_config(self, name_or_alias: str) -> Optional[ModelCapabilities]:
+        """Backward-compatible wrapper used by providers and older tests."""
+
+        return self.resolve(name_or_alias)
+
     def list_models(self) -> list[str]:
         """List all available model names."""
         return list(self.model_map.keys())
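A sketch of typical registry usage based only on the methods shown above (the alias `"flash"` is invented for illustration; real aliases come from `conf/custom_models.json`):

```python
from providers.openrouter_registry import OpenRouterModelRegistry

registry = OpenRouterModelRegistry()   # parses conf/custom_models.json by default
print(registry.list_models())          # canonical model names from the catalogue

caps = registry.resolve("flash")       # alias lookup; "flash" is a made-up example
if caps is not None:
    print(caps.model_name, caps.context_window)

# get_model_config() is the backward-compatible spelling of resolve()
assert registry.get_model_config("flash") == caps
```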
providers/registry.py
@@ -4,14 +4,20 @@ import logging
 import os
 from typing import TYPE_CHECKING, Optional
 
-from .base import ModelProvider, ProviderType
+from .base import ModelProvider
+from .shared import ProviderType
 
 if TYPE_CHECKING:
     from tools.models import ToolModelCategory
 
 
 class ModelProviderRegistry:
-    """Registry for managing model providers."""
+    """Singleton that caches provider instances and coordinates priority order.
+
+    Responsibilities include resolving API keys from the environment, lazily
+    instantiating providers, and choosing the best provider for a model based
+    on restriction policies and provider priority.
+    """
 
     _instance = None
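Typical interaction with the singleton, mirroring registration calls that appear elsewhere in this commit:

```python
from providers.openai_provider import OpenAIModelProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType

# Register the provider class; the registry instantiates it lazily and
# resolves its API key from the environment on first use.
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

provider = ModelProviderRegistry.get_provider(ProviderType.OPENAI)
```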
providers/shared/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+"""Shared data structures and helpers for model providers."""
+
+from .model_capabilities import ModelCapabilities
+from .model_response import ModelResponse
+from .provider_type import ProviderType
+from .temperature import (
+    DiscreteTemperatureConstraint,
+    FixedTemperatureConstraint,
+    RangeTemperatureConstraint,
+    TemperatureConstraint,
+    create_temperature_constraint,
+)
+
+__all__ = [
+    "ModelCapabilities",
+    "ModelResponse",
+    "ProviderType",
+    "TemperatureConstraint",
+    "FixedTemperatureConstraint",
+    "RangeTemperatureConstraint",
+    "DiscreteTemperatureConstraint",
+    "create_temperature_constraint",
+]
providers/shared/model_capabilities.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+"""Dataclass describing the feature set of a model exposed by a provider."""
+
+from dataclasses import dataclass, field
+
+from .provider_type import ProviderType
+from .temperature import RangeTemperatureConstraint, TemperatureConstraint
+
+__all__ = ["ModelCapabilities"]
+
+
+@dataclass
+class ModelCapabilities:
+    """Static capabilities and constraints for a provider-managed model."""
+
+    provider: ProviderType
+    model_name: str
+    friendly_name: str
+    context_window: int
+    max_output_tokens: int
+    supports_extended_thinking: bool = False
+    supports_system_prompts: bool = True
+    supports_streaming: bool = True
+    supports_function_calling: bool = False
+    supports_images: bool = False
+    max_image_size_mb: float = 0.0
+    supports_temperature: bool = True
+    description: str = ""
+    aliases: list[str] = field(default_factory=list)
+    supports_json_mode: bool = False
+    max_thinking_tokens: int = 0
+    is_custom: bool = False
+    temperature_constraint: TemperatureConstraint = field(
+        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
+    )
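Constructing a capability record is plain dataclass usage; only the five leading fields are required. The model below is invented for illustration:

```python
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint

caps = ModelCapabilities(
    provider=ProviderType.OPENROUTER,
    model_name="vendor/example-model",   # hypothetical catalogue entry
    friendly_name="Example",
    context_window=128_000,
    max_output_tokens=8_192,
    supports_images=True,
    max_image_size_mb=20.0,
    aliases=["example"],
    temperature_constraint=create_temperature_constraint("range"),
)
assert caps.supports_temperature  # defaults still apply to unset fields
```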
providers/shared/model_response.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+"""Dataclass used to normalise provider SDK responses."""
+
+from dataclasses import dataclass, field
+from typing import Any
+
+from .provider_type import ProviderType
+
+__all__ = ["ModelResponse"]
+
+
+@dataclass
+class ModelResponse:
+    """Portable representation of a provider completion."""
+
+    content: str
+    usage: dict[str, int] = field(default_factory=dict)
+    model_name: str = ""
+    friendly_name: str = ""
+    provider: ProviderType = ProviderType.GOOGLE
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def total_tokens(self) -> int:
+        """Return the total token count if the provider reported usage data."""
+
+        return self.usage.get("total_tokens", 0)
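Providers populate this dataclass regardless of which SDK produced the completion; `total_tokens` degrades gracefully when usage reporting is absent:

```python
from providers.shared import ModelResponse, ProviderType

resp = ModelResponse(
    content="Hello!",
    usage={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
    model_name="gemini-2.5-flash",
    friendly_name="Gemini",
    provider=ProviderType.GOOGLE,
)
assert resp.total_tokens == 15
assert ModelResponse(content="no usage reported").total_tokens == 0
```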
providers/shared/provider_type.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+"""Enumeration describing which backend owns a given model."""
+
+from enum import Enum
+
+__all__ = ["ProviderType"]
+
+
+class ProviderType(Enum):
+    """Canonical identifiers for every supported provider backend."""
+
+    GOOGLE = "google"
+    OPENAI = "openai"
+    XAI = "xai"
+    OPENROUTER = "openrouter"
+    CUSTOM = "custom"
+    DIAL = "dial"
providers/shared/temperature.py (new file, 121 lines)
@@ -0,0 +1,121 @@
+"""Helper types for validating model temperature parameters."""
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+__all__ = [
+    "TemperatureConstraint",
+    "FixedTemperatureConstraint",
+    "RangeTemperatureConstraint",
+    "DiscreteTemperatureConstraint",
+    "create_temperature_constraint",
+]
+
+
+class TemperatureConstraint(ABC):
+    """Contract for temperature validation used by `ModelCapabilities`.
+
+    Concrete providers describe their temperature behaviour by creating
+    subclasses that expose three operations:
+
+    * `validate` – decide whether a requested temperature is acceptable.
+    * `get_corrected_value` – coerce out-of-range values into a safe default.
+    * `get_description` – provide a human readable error message for users.
+
+    Providers call these hooks before sending traffic to the underlying API so
+    that unsupported temperatures never reach the remote service.
+    """
+
+    @abstractmethod
+    def validate(self, temperature: float) -> bool:
+        """Return ``True`` when the temperature may be sent to the backend."""
+
+    @abstractmethod
+    def get_corrected_value(self, temperature: float) -> float:
+        """Return a valid substitute for an out-of-range temperature."""
+
+    @abstractmethod
+    def get_description(self) -> str:
+        """Describe the acceptable range to include in error messages."""
+
+    @abstractmethod
+    def get_default(self) -> float:
+        """Return the default temperature for the model."""
+
+
+class FixedTemperatureConstraint(TemperatureConstraint):
+    """Constraint for models that enforce an exact temperature (for example O3)."""
+
+    def __init__(self, value: float):
+        self.value = value
+
+    def validate(self, temperature: float) -> bool:
+        return abs(temperature - self.value) < 1e-6  # Handle floating point precision
+
+    def get_corrected_value(self, temperature: float) -> float:
+        return self.value
+
+    def get_description(self) -> str:
+        return f"Only supports temperature={self.value}"
+
+    def get_default(self) -> float:
+        return self.value
+
+
+class RangeTemperatureConstraint(TemperatureConstraint):
+    """Constraint for providers that expose a continuous min/max temperature range."""
+
+    def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None):
+        self.min_temp = min_temp
+        self.max_temp = max_temp
+        self.default_temp = default or (min_temp + max_temp) / 2
+
+    def validate(self, temperature: float) -> bool:
+        return self.min_temp <= temperature <= self.max_temp
+
+    def get_corrected_value(self, temperature: float) -> float:
+        return max(self.min_temp, min(self.max_temp, temperature))
+
+    def get_description(self) -> str:
+        return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
+
+    def get_default(self) -> float:
+        return self.default_temp
+
+
+class DiscreteTemperatureConstraint(TemperatureConstraint):
+    """Constraint for models that permit a discrete list of temperature values."""
+
+    def __init__(self, allowed_values: list[float], default: Optional[float] = None):
+        self.allowed_values = sorted(allowed_values)
+        self.default_temp = default or allowed_values[len(allowed_values) // 2]
+
+    def validate(self, temperature: float) -> bool:
+        return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
+
+    def get_corrected_value(self, temperature: float) -> float:
+        return min(self.allowed_values, key=lambda x: abs(x - temperature))
+
+    def get_description(self) -> str:
+        return f"Supports temperatures: {self.allowed_values}"
+
+    def get_default(self) -> float:
+        return self.default_temp
+
+
+def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
+    """Factory that yields the appropriate constraint for a model configuration.
+
+    The JSON configuration stored in ``conf/custom_models.json`` references this
+    helper via human-readable strings. Providers feed those values into this
+    function so that runtime logic can rely on strongly typed constraint
+    objects.
+    """
+
+    if constraint_type == "fixed":
+        # Fixed temperature models (O3/O4) only support temperature=1.0
+        return FixedTemperatureConstraint(1.0)
+    if constraint_type == "discrete":
+        # For models with specific allowed values - using common OpenAI values as default
+        return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
+    # Default range constraint (for "range" or None)
+    return RangeTemperatureConstraint(0.0, 2.0, 0.3)
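A quick demonstration of the three constraint flavours via the factory (values follow the defaults hard-coded above):

```python
from providers.shared import create_temperature_constraint

discrete = create_temperature_constraint("discrete")
assert not discrete.validate(0.5)                # not in [0.0, 0.3, 0.7, 1.0, 1.5, 2.0]
assert discrete.get_corrected_value(0.5) == 0.3  # snaps to the nearest allowed value
assert discrete.get_default() == 0.3

fixed = create_temperature_constraint("fixed")
assert fixed.validate(1.0) and fixed.get_corrected_value(0.2) == 1.0

ranged = create_temperature_constraint("range")  # also the fallback for unknown strings
assert ranged.get_description() == "Supports temperature range [0.0, 2.0]"
```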
providers/xai.py
@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Optional
 if TYPE_CHECKING:
     from tools.models import ToolModelCategory
 
-from .base import (
+from .openai_compatible import OpenAICompatibleProvider
+from .shared import (
     ModelCapabilities,
     ModelResponse,
     ProviderType,
     create_temperature_constraint,
 )
-from .openai_compatible import OpenAICompatibleProvider
 
 logger = logging.getLogger(__name__)
 
 
 class XAIModelProvider(OpenAICompatibleProvider):
-    """X.AI GROK API provider (api.x.ai)."""
+    """Integration for X.AI's GROK models exposed over an OpenAI-style API.
+
+    Publishes capability metadata for the officially supported deployments and
+    maps tool-category preferences to the appropriate GROK model.
+    """
 
     FRIENDLY_NAME = "X.AI"
server.py
@@ -412,12 +412,12 @@ def configure_providers():
         value = os.getenv(key)
         logger.debug(f" {key}: {'[PRESENT]' if value else '[MISSING]'}")
 
     from providers import ModelProviderRegistry
-    from providers.base import ProviderType
     from providers.custom import CustomProvider
     from providers.dial import DIALModelProvider
     from providers.gemini import GeminiModelProvider
     from providers.openai_provider import OpenAIModelProvider
     from providers.openrouter import OpenRouterProvider
+    from providers.shared import ProviderType
     from providers.xai import XAIModelProvider
     from utils.model_restrictions import get_restriction_service
conftest.py
@@ -34,9 +34,9 @@ if sys.platform == "win32":
 
 # Register providers for all tests
 from providers import ModelProviderRegistry  # noqa: E402
-from providers.base import ProviderType  # noqa: E402
 from providers.gemini import GeminiModelProvider  # noqa: E402
 from providers.openai_provider import OpenAIModelProvider  # noqa: E402
+from providers.shared import ProviderType  # noqa: E402
 from providers.xai import XAIModelProvider  # noqa: E402
 
 # Register providers at test startup
@@ -109,7 +109,7 @@ def mock_provider_availability(request, monkeypatch):
         return
 
     # Ensure providers are registered (in case other tests cleared the registry)
-    from providers.base import ProviderType
+    from providers.shared import ProviderType
 
     registry = ModelProviderRegistry()
 
@@ -197,3 +197,19 @@ def mock_provider_availability(request, monkeypatch):
         return False
 
     monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
+
+
+@pytest.fixture(autouse=True)
+def clear_model_restriction_env(monkeypatch):
+    """Ensure per-test isolation from user-defined model restriction env vars."""
+
+    restriction_vars = [
+        "OPENAI_ALLOWED_MODELS",
+        "GOOGLE_ALLOWED_MODELS",
+        "XAI_ALLOWED_MODELS",
+        "OPENROUTER_ALLOWED_MODELS",
+        "DIAL_ALLOWED_MODELS",
+    ]
+
+    for var in restriction_vars:
+        monkeypatch.delenv(var, raising=False)
@@ -2,7 +2,7 @@
 
 from unittest.mock import Mock
 
-from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
+from providers.shared import ModelCapabilities, ProviderType, RangeTemperatureConstraint
 
 
 def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576):

@@ -8,9 +8,9 @@ both alias names and their target models, preventing policy bypass vulnerabiliti
 import os
 from unittest.mock import patch
 
-from providers.base import ProviderType
 from providers.gemini import GeminiModelProvider
 from providers.openai_provider import OpenAIModelProvider
+from providers.shared import ProviderType
 from utils.model_restrictions import ModelRestrictionService

@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
-from providers.base import ProviderType
 from providers.registry import ModelProviderRegistry
+from providers.shared import ProviderType
 from tools.analyze import AnalyzeTool
 from tools.chat import ChatTool
 from tools.debug import DebugIssueTool

@@ -6,8 +6,8 @@ from unittest.mock import patch
 
 import pytest
 
-from providers.base import ProviderType
 from providers.registry import ModelProviderRegistry
+from providers.shared import ProviderType
 
 
 @pytest.mark.no_mock_provider

@@ -4,8 +4,8 @@ import os
 
 import pytest
 
-from providers.base import ProviderType
 from providers.registry import ModelProviderRegistry
+from providers.shared import ProviderType
 from tools.models import ToolModelCategory
@@ -14,9 +14,9 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
-from providers.base import ProviderType
 from providers.gemini import GeminiModelProvider
 from providers.openai_provider import OpenAIModelProvider
+from providers.shared import ProviderType
 from utils.model_restrictions import ModelRestrictionService

@@ -86,7 +86,7 @@ class TestCustomOpenAITemperatureParameterFix:
             mock_registry_class.return_value = mock_registry
 
             # Mock get_model_config to return our test model
-            from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
+            from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
 
             test_capabilities = ModelCapabilities(
                 provider=ProviderType.OPENAI,
@@ -170,7 +170,7 @@ class TestCustomOpenAITemperatureParameterFix:
             mock_registry_class.return_value = mock_registry
 
             # Mock get_model_config to return a model that supports temperature
-            from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
+            from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
 
             test_capabilities = ModelCapabilities(
                 provider=ProviderType.OPENAI,
@@ -227,7 +227,7 @@ class TestCustomOpenAITemperatureParameterFix:
             mock_registry = Mock()
             mock_registry_class.return_value = mock_registry
 
-            from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
+            from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
 
             test_capabilities = ModelCapabilities(
                 provider=ProviderType.OPENAI,
@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
 import pytest
 
 from providers import ModelProviderRegistry
-from providers.base import ProviderType
 from providers.custom import CustomProvider
+from providers.shared import ProviderType
 
 
 class TestCustomProvider:

@@ -5,8 +5,8 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
-from providers.base import ProviderType
 from providers.dial import DIALModelProvider
+from providers.shared import ProviderType
 
 
 class TestDIALProvider:

@@ -8,7 +8,8 @@ from unittest.mock import Mock, patch
 
 import pytest
 
-from providers.base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType
+from providers.base import ModelProvider
+from providers.shared import ModelCapabilities, ModelResponse, ProviderType
 
 
 class MinimalTestProvider(ModelProvider):

@@ -9,8 +9,8 @@ from unittest.mock import Mock, patch
 
 import pytest
 
-from providers.base import ProviderType
 from providers.registry import ModelProviderRegistry
+from providers.shared import ProviderType
 
 
 class TestIntelligentFallback:

@@ -41,7 +41,7 @@ def test_issue_245_custom_openai_temperature_ignored():
         mock_registry = Mock()
         mock_registry_class.return_value = mock_registry
 
-        from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
+        from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
 
         # This is what the user configured in their custom_models.json
         custom_config = ModelCapabilities(
@@ -5,8 +5,9 @@ import os
 import unittest
 from unittest.mock import MagicMock, patch
 
-from providers.base import ModelProvider, ProviderType
+from providers.base import ModelProvider
 from providers.registry import ModelProviderRegistry
+from providers.shared import ProviderType
 from tools.listmodels import ListModelsTool

@@ -214,7 +214,7 @@ class TestModelEnumeration:
 
         # Rebuild the provider registry with OpenRouter registered
         ModelProviderRegistry._instance = None
-        from providers.base import ProviderType
+        from providers.shared import ProviderType
 
         ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

@@ -9,8 +9,8 @@ This test specifically targets the bug where:
 
 from unittest.mock import Mock, patch
 
-from providers.base import ProviderType
 from providers.openrouter import OpenRouterProvider
+from providers.shared import ProviderType
 from tools.consensus import ConsensusTool

@@ -5,9 +5,9 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
-from providers.base import ProviderType
 from providers.gemini import GeminiModelProvider
 from providers.openai_provider import OpenAIModelProvider
+from providers.shared import ProviderType
 from utils.model_restrictions import ModelRestrictionService

@@ -10,7 +10,7 @@ They prove that our fix was necessary and actually addresses real problems.
 
 from unittest.mock import MagicMock, patch
 
-from providers.base import ProviderType
+from providers.shared import ProviderType
 from utils.model_restrictions import ModelRestrictionService
@@ -3,8 +3,8 @@
 import os
 from unittest.mock import MagicMock, patch
 
-from providers.base import ProviderType
 from providers.openai_provider import OpenAIModelProvider
+from providers.shared import ProviderType
 
 
 class TestOpenAIProvider:

@@ -5,9 +5,9 @@ from unittest.mock import Mock, patch
 
 import pytest
 
-from providers.base import ProviderType
 from providers.openrouter import OpenRouterProvider
 from providers.registry import ModelProviderRegistry
+from providers.shared import ProviderType
 
 
 class TestOpenRouterProvider:

@@ -6,8 +6,8 @@ import tempfile
 
 import pytest
 
-from providers.base import ModelCapabilities, ProviderType
 from providers.openrouter_registry import OpenRouterModelRegistry
+from providers.shared import ModelCapabilities, ProviderType
 
 
 class TestOpenRouterModelRegistry:
@@ -213,7 +213,7 @@ class TestOpenRouterModelRegistry:
 
     def test_model_with_all_capabilities(self):
         """Test model with all capability flags."""
-        from providers.base import create_temperature_constraint
+        from providers.shared import create_temperature_constraint
 
         caps = ModelCapabilities(
             provider=ProviderType.OPENROUTER,
@@ -13,8 +13,8 @@ from unittest.mock import Mock
 
 import pytest
 
-from providers.base import ProviderType
 from providers.registry import ModelProviderRegistry
+from providers.shared import ProviderType
 from tools.chat import ChatTool
 from tools.shared.base_models import ToolRequest

@@ -10,9 +10,9 @@ from unittest.mock import Mock, patch
 
 import pytest
 
-from providers.base import ProviderType
 from providers.gemini import GeminiModelProvider
 from providers.openai_provider import OpenAIModelProvider
+from providers.shared import ProviderType
 
 
 class TestProviderUTF8Encoding(unittest.TestCase):
@@ -177,7 +177,7 @@ class TestProviderUTF8Encoding(unittest.TestCase):
 
     def test_model_response_utf8_serialization(self):
         """Test UTF-8 serialization of model responses."""
-        from providers.base import ModelResponse
+        from providers.shared import ModelResponse
 
         response = ModelResponse(
             content="Development successful! Code generated successfully. 🎉✅",

@@ -6,9 +6,9 @@ from unittest.mock import Mock, patch
 import pytest
 
 from providers import ModelProviderRegistry, ModelResponse
-from providers.base import ProviderType
 from providers.gemini import GeminiModelProvider
 from providers.openai_provider import OpenAIModelProvider
+from providers.shared import ProviderType
 
 
 class TestModelProviderRegistry:
@@ -185,7 +185,7 @@ class TestSupportedModelsAliases:
         for provider in providers:
             for model_name, config in provider.SUPPORTED_MODELS.items():
                 # All values must be ModelCapabilities objects, not strings or dicts
                from providers.shared import ModelCapabilities
 
                 assert isinstance(config, ModelCapabilities), (
                     f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
@@ -10,8 +10,8 @@ import os
 
 import pytest
 
-from providers.base import ProviderType
 from providers.registry import ModelProviderRegistry
+from providers.shared import ProviderType
 from tools.debug import DebugIssueTool

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
-from providers.base import ProviderType
+from providers.shared import ProviderType
 from providers.xai import XAIModelProvider
 
 
@@ -265,7 +265,7 @@ class TestXAIProvider:
         assert "grok-3-fast" in provider.SUPPORTED_MODELS
 
         # Check model configs have required fields
-        from providers.base import ModelCapabilities
+        from providers.shared import ModelCapabilities
 
         grok4_config = provider.SUPPORTED_MODELS["grok-4"]
         assert isinstance(grok4_config, ModelCapabilities)

@@ -22,9 +22,9 @@ def inject_transport(monkeypatch, cassette_path: str):
         transport = inject_transport(monkeypatch, "path/to/cassette.json")
     """
     # Ensure OpenAI provider is registered - always needed for transport injection
-    from providers.base import ProviderType
    from providers.openai_provider import OpenAIModelProvider
     from providers.registry import ModelProviderRegistry
+    from providers.shared import ProviderType
 
     # Always register OpenAI provider for transport tests (API key might be dummy)
     ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
tools/listmodels.py
@@ -79,9 +79,9 @@ class ListModelsTool(BaseTool):
 
         Returns:
             Formatted list of models by provider
         """
-        from providers.base import ProviderType
         from providers.openrouter_registry import OpenRouterModelRegistry
         from providers.registry import ModelProviderRegistry
+        from providers.shared import ProviderType
 
         output_lines = ["# Available AI Models\n"]
@@ -162,8 +162,8 @@ class ListModelsTool(BaseTool):
 
         try:
             # Get OpenRouter provider from registry to properly apply restrictions
-            from providers.base import ProviderType
             from providers.registry import ModelProviderRegistry
+            from providers.shared import ProviderType
 
             provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
             if provider:

@@ -1341,7 +1341,7 @@ When recommending searches, be specific about what information you need and why
         # Apply 40MB cap for custom models if needed
         effective_limit_mb = max_size_mb
         try:
-            from providers.base import ProviderType
+            from providers.shared import ProviderType
 
             # ModelCapabilities dataclass has provider field defined
             if capabilities.provider == ProviderType.CUSTOM:

@@ -306,8 +306,8 @@ class VersionTool(BaseTool):
 
         # Check for configured providers
         try:
-            from providers.base import ProviderType
             from providers.registry import ModelProviderRegistry
+            from providers.shared import ProviderType
 
             provider_status = []

utils/model_restrictions.py
@@ -24,7 +24,7 @@ import logging
 import os
 from typing import Optional
 
-from providers.base import ProviderType
+from providers.shared import ProviderType
 
 logger = logging.getLogger(__name__)