refactor: code cleanup
This commit is contained in:
@@ -1,11 +1,12 @@
|
|||||||
"""Model provider abstractions for supporting multiple AI providers."""
|
"""Model provider abstractions for supporting multiple AI providers."""
|
||||||
|
|
||||||
from .base import ModelCapabilities, ModelProvider, ModelResponse
|
from .base import ModelProvider
|
||||||
from .gemini import GeminiModelProvider
|
from .gemini import GeminiModelProvider
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
from .openai_provider import OpenAIModelProvider
|
from .openai_provider import OpenAIModelProvider
|
||||||
from .openrouter import OpenRouterProvider
|
from .openrouter import OpenRouterProvider
|
||||||
from .registry import ModelProviderRegistry
|
from .registry import ModelProviderRegistry
|
||||||
|
from .shared import ModelCapabilities, ModelResponse
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ModelProvider",
|
"ModelProvider",
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
"""Base model provider interface and data classes."""
|
"""Base interfaces and common behaviour for model providers."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -14,179 +12,20 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
from utils.file_types import IMAGES, get_image_mime_type
|
from utils.file_types import IMAGES, get_image_mime_type
|
||||||
|
|
||||||
|
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProviderType(Enum):
|
|
||||||
"""Supported model provider types."""
|
|
||||||
|
|
||||||
GOOGLE = "google"
|
|
||||||
OPENAI = "openai"
|
|
||||||
XAI = "xai"
|
|
||||||
OPENROUTER = "openrouter"
|
|
||||||
CUSTOM = "custom"
|
|
||||||
DIAL = "dial"
|
|
||||||
|
|
||||||
|
|
||||||
class TemperatureConstraint(ABC):
|
|
||||||
"""Abstract base class for temperature constraints."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def validate(self, temperature: float) -> bool:
|
|
||||||
"""Check if temperature is valid."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_corrected_value(self, temperature: float) -> float:
|
|
||||||
"""Get nearest valid temperature."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_description(self) -> str:
|
|
||||||
"""Get human-readable description of constraint."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_default(self) -> float:
|
|
||||||
"""Get model's default temperature."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class FixedTemperatureConstraint(TemperatureConstraint):
|
|
||||||
"""For models that only support one temperature value (e.g., O3)."""
|
|
||||||
|
|
||||||
def __init__(self, value: float):
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
def validate(self, temperature: float) -> bool:
|
|
||||||
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
|
|
||||||
|
|
||||||
def get_corrected_value(self, temperature: float) -> float:
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return f"Only supports temperature={self.value}"
|
|
||||||
|
|
||||||
def get_default(self) -> float:
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
|
|
||||||
class RangeTemperatureConstraint(TemperatureConstraint):
|
|
||||||
"""For models supporting continuous temperature ranges."""
|
|
||||||
|
|
||||||
def __init__(self, min_temp: float, max_temp: float, default: float = None):
|
|
||||||
self.min_temp = min_temp
|
|
||||||
self.max_temp = max_temp
|
|
||||||
self.default_temp = default or (min_temp + max_temp) / 2
|
|
||||||
|
|
||||||
def validate(self, temperature: float) -> bool:
|
|
||||||
return self.min_temp <= temperature <= self.max_temp
|
|
||||||
|
|
||||||
def get_corrected_value(self, temperature: float) -> float:
|
|
||||||
return max(self.min_temp, min(self.max_temp, temperature))
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
|
|
||||||
|
|
||||||
def get_default(self) -> float:
|
|
||||||
return self.default_temp
|
|
||||||
|
|
||||||
|
|
||||||
class DiscreteTemperatureConstraint(TemperatureConstraint):
|
|
||||||
"""For models supporting only specific temperature values."""
|
|
||||||
|
|
||||||
def __init__(self, allowed_values: list[float], default: float = None):
|
|
||||||
self.allowed_values = sorted(allowed_values)
|
|
||||||
self.default_temp = default or allowed_values[len(allowed_values) // 2]
|
|
||||||
|
|
||||||
def validate(self, temperature: float) -> bool:
|
|
||||||
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
|
|
||||||
|
|
||||||
def get_corrected_value(self, temperature: float) -> float:
|
|
||||||
return min(self.allowed_values, key=lambda x: abs(x - temperature))
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return f"Supports temperatures: {self.allowed_values}"
|
|
||||||
|
|
||||||
def get_default(self) -> float:
|
|
||||||
return self.default_temp
|
|
||||||
|
|
||||||
|
|
||||||
def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
|
|
||||||
"""Create temperature constraint object from configuration string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
constraint_type: Type of constraint ("fixed", "range", "discrete")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TemperatureConstraint object based on configuration
|
|
||||||
"""
|
|
||||||
if constraint_type == "fixed":
|
|
||||||
# Fixed temperature models (O3/O4) only support temperature=1.0
|
|
||||||
return FixedTemperatureConstraint(1.0)
|
|
||||||
elif constraint_type == "discrete":
|
|
||||||
# For models with specific allowed values - using common OpenAI values as default
|
|
||||||
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
|
|
||||||
else:
|
|
||||||
# Default range constraint (for "range" or None)
|
|
||||||
return RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelCapabilities:
|
|
||||||
"""Capabilities and constraints for a specific model."""
|
|
||||||
|
|
||||||
provider: ProviderType
|
|
||||||
model_name: str
|
|
||||||
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
|
|
||||||
context_window: int # Total context window size in tokens
|
|
||||||
max_output_tokens: int # Maximum output tokens per request
|
|
||||||
supports_extended_thinking: bool = False
|
|
||||||
supports_system_prompts: bool = True
|
|
||||||
supports_streaming: bool = True
|
|
||||||
supports_function_calling: bool = False
|
|
||||||
supports_images: bool = False # Whether model can process images
|
|
||||||
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
|
|
||||||
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
|
|
||||||
|
|
||||||
# Additional fields for comprehensive model information
|
|
||||||
description: str = "" # Human-readable description of the model
|
|
||||||
aliases: list[str] = field(default_factory=list) # Alternative names/shortcuts for the model
|
|
||||||
|
|
||||||
# JSON mode support (for providers that support structured output)
|
|
||||||
supports_json_mode: bool = False
|
|
||||||
|
|
||||||
# Thinking mode support (for models with thinking capabilities)
|
|
||||||
max_thinking_tokens: int = 0 # Maximum thinking tokens for extended reasoning models
|
|
||||||
|
|
||||||
# Custom model flag (for models that only work with custom endpoints)
|
|
||||||
is_custom: bool = False # Whether this model requires custom API endpoints
|
|
||||||
|
|
||||||
# Temperature constraint object - defines temperature limits and behavior
|
|
||||||
temperature_constraint: TemperatureConstraint = field(
|
|
||||||
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelResponse:
|
|
||||||
"""Response from a model provider."""
|
|
||||||
|
|
||||||
content: str
|
|
||||||
usage: dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
|
|
||||||
model_name: str = ""
|
|
||||||
friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI"
|
|
||||||
provider: ProviderType = ProviderType.GOOGLE
|
|
||||||
metadata: dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
|
|
||||||
|
|
||||||
@property
|
|
||||||
def total_tokens(self) -> int:
|
|
||||||
"""Get total tokens used."""
|
|
||||||
return self.usage.get("total_tokens", 0)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProvider(ABC):
|
class ModelProvider(ABC):
|
||||||
"""Abstract base class for model providers."""
|
"""Defines the contract implemented by every model provider backend.
|
||||||
|
|
||||||
|
Subclasses adapt third-party SDKs into the MCP server by exposing
|
||||||
|
capability metadata, request execution, and token counting through a
|
||||||
|
consistent interface. Shared helper methods (temperature validation,
|
||||||
|
alias resolution, image handling, etc.) live here so individual providers
|
||||||
|
only need to focus on provider-specific details.
|
||||||
|
"""
|
||||||
|
|
||||||
# All concrete providers must define their supported models
|
# All concrete providers must define their supported models
|
||||||
SUPPORTED_MODELS: dict[str, Any] = {}
|
SUPPORTED_MODELS: dict[str, Any] = {}
|
||||||
|
|||||||
@@ -4,15 +4,15 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .base import (
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .openrouter_registry import OpenRouterModelRegistry
|
||||||
|
from .shared import (
|
||||||
FixedTemperatureConstraint,
|
FixedTemperatureConstraint,
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
RangeTemperatureConstraint,
|
RangeTemperatureConstraint,
|
||||||
)
|
)
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
|
||||||
from .openrouter_registry import OpenRouterModelRegistry
|
|
||||||
|
|
||||||
# Temperature inference patterns
|
# Temperature inference patterns
|
||||||
_TEMP_UNSUPPORTED_PATTERNS = [
|
_TEMP_UNSUPPORTED_PATTERNS = [
|
||||||
@@ -30,10 +30,13 @@ _TEMP_UNSUPPORTED_KEYWORDS = [
|
|||||||
|
|
||||||
|
|
||||||
class CustomProvider(OpenAICompatibleProvider):
|
class CustomProvider(OpenAICompatibleProvider):
|
||||||
"""Custom API provider for local models.
|
"""Adapter for self-hosted or local OpenAI-compatible endpoints.
|
||||||
|
|
||||||
Supports local inference servers like Ollama, vLLM, LM Studio,
|
The provider reuses the :mod:`providers.shared` registry to surface
|
||||||
and any OpenAI-compatible API endpoint.
|
user-defined aliases and capability metadata. It also normalises
|
||||||
|
Ollama-style version tags (``model:latest``) and enforces the same
|
||||||
|
restriction policies used by cloud providers, ensuring consistent
|
||||||
|
behaviour regardless of where the model is hosted.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FRIENDLY_NAME = "Custom API"
|
FRIENDLY_NAME = "Custom API"
|
||||||
|
|||||||
@@ -6,22 +6,24 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .base import (
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .shared import (
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
create_temperature_constraint,
|
create_temperature_constraint,
|
||||||
)
|
)
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DIALModelProvider(OpenAICompatibleProvider):
|
class DIALModelProvider(OpenAICompatibleProvider):
|
||||||
"""DIAL provider using OpenAI-compatible API.
|
"""Client for the DIAL (Data & AI Layer) aggregation service.
|
||||||
|
|
||||||
DIAL provides access to various AI models through a unified API interface.
|
DIAL exposes several third-party models behind a single OpenAI-compatible
|
||||||
Supports GPT, Claude, Gemini, and other models via DIAL deployments.
|
endpoint. This provider wraps the service, publishes capability metadata
|
||||||
|
for the known deployments, and centralises retry/backoff settings tailored
|
||||||
|
to DIAL's latency characteristics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FRIENDLY_NAME = "DIAL"
|
FRIENDLY_NAME = "DIAL"
|
||||||
|
|||||||
@@ -11,13 +11,24 @@ if TYPE_CHECKING:
|
|||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint
|
from .base import ModelProvider
|
||||||
|
from .shared import (
|
||||||
|
ModelCapabilities,
|
||||||
|
ModelResponse,
|
||||||
|
ProviderType,
|
||||||
|
create_temperature_constraint,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GeminiModelProvider(ModelProvider):
|
class GeminiModelProvider(ModelProvider):
|
||||||
"""Google Gemini model provider implementation."""
|
"""First-party Gemini integration built on the official Google SDK.
|
||||||
|
|
||||||
|
The provider advertises detailed thinking-mode budgets, handles optional
|
||||||
|
custom endpoints, and performs image pre-processing before forwarding a
|
||||||
|
request to the Gemini APIs.
|
||||||
|
"""
|
||||||
|
|
||||||
# Model configurations using ModelCapabilities objects
|
# Model configurations using ModelCapabilities objects
|
||||||
SUPPORTED_MODELS = {
|
SUPPORTED_MODELS = {
|
||||||
|
|||||||
@@ -11,21 +11,21 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from .base import (
|
from .base import ModelProvider
|
||||||
|
from .shared import (
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
ModelProvider,
|
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompatibleProvider(ModelProvider):
|
class OpenAICompatibleProvider(ModelProvider):
|
||||||
"""Base class for any provider using an OpenAI-compatible API.
|
"""Shared implementation for OpenAI API lookalikes.
|
||||||
|
|
||||||
This includes:
|
The class owns HTTP client configuration (timeouts, proxy hardening,
|
||||||
- Direct OpenAI API
|
custom headers) and normalises the OpenAI SDK responses into
|
||||||
- OpenRouter
|
:class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
|
||||||
- Any other OpenAI-compatible endpoint
|
provide capability metadata and any provider-specific request tweaks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_HEADERS = {}
|
DEFAULT_HEADERS = {}
|
||||||
|
|||||||
@@ -6,19 +6,24 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tools.models import ToolModelCategory
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
from .base import (
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .shared import (
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
create_temperature_constraint,
|
create_temperature_constraint,
|
||||||
)
|
)
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModelProvider(OpenAICompatibleProvider):
|
class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||||
"""Official OpenAI API provider (api.openai.com)."""
|
"""Implementation that talks to api.openai.com using rich model metadata.
|
||||||
|
|
||||||
|
In addition to the built-in catalogue, the provider can surface models
|
||||||
|
defined in ``conf/custom_models.json`` (for organisations running their own
|
||||||
|
OpenAI-compatible gateways) while still respecting restriction policies.
|
||||||
|
"""
|
||||||
|
|
||||||
# Model configurations using ModelCapabilities objects
|
# Model configurations using ModelCapabilities objects
|
||||||
SUPPORTED_MODELS = {
|
SUPPORTED_MODELS = {
|
||||||
|
|||||||
@@ -4,21 +4,22 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .base import (
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .openrouter_registry import OpenRouterModelRegistry
|
||||||
|
from .shared import (
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
RangeTemperatureConstraint,
|
RangeTemperatureConstraint,
|
||||||
)
|
)
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
|
||||||
from .openrouter_registry import OpenRouterModelRegistry
|
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterProvider(OpenAICompatibleProvider):
|
class OpenRouterProvider(OpenAICompatibleProvider):
|
||||||
"""OpenRouter unified API provider.
|
"""Client for OpenRouter's multi-model aggregation service.
|
||||||
|
|
||||||
OpenRouter provides access to multiple AI models through a single API endpoint.
|
OpenRouter surfaces dozens of upstream vendors. This provider layers alias
|
||||||
See https://openrouter.ai for available models and pricing.
|
resolution, restriction-aware filtering, and sensible capability defaults
|
||||||
|
on top of the generic OpenAI-compatible plumbing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FRIENDLY_NAME = "OpenRouter"
|
FRIENDLY_NAME = "OpenRouter"
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import Optional
|
|||||||
# Import handled via importlib.resources.files() calls directly
|
# Import handled via importlib.resources.files() calls directly
|
||||||
from utils.file_utils import read_json_file
|
from utils.file_utils import read_json_file
|
||||||
|
|
||||||
from .base import (
|
from .shared import (
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
create_temperature_constraint,
|
create_temperature_constraint,
|
||||||
@@ -17,7 +17,13 @@ from .base import (
|
|||||||
|
|
||||||
|
|
||||||
class OpenRouterModelRegistry:
|
class OpenRouterModelRegistry:
|
||||||
"""Registry for managing OpenRouter model configurations and aliases."""
|
"""Loads and validates the OpenRouter/custom model catalogue.
|
||||||
|
|
||||||
|
The registry parses ``conf/custom_models.json`` (or an override supplied via
|
||||||
|
environment variable), builds case-insensitive alias maps, and exposes
|
||||||
|
:class:`~providers.shared.ModelCapabilities` objects used by several
|
||||||
|
providers.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, config_path: Optional[str] = None):
|
def __init__(self, config_path: Optional[str] = None):
|
||||||
"""Initialize the registry.
|
"""Initialize the registry.
|
||||||
@@ -263,6 +269,11 @@ class OpenRouterModelRegistry:
|
|||||||
# Registry now returns ModelCapabilities directly
|
# Registry now returns ModelCapabilities directly
|
||||||
return self.resolve(name_or_alias)
|
return self.resolve(name_or_alias)
|
||||||
|
|
||||||
|
def get_model_config(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
||||||
|
"""Backward-compatible wrapper used by providers and older tests."""
|
||||||
|
|
||||||
|
return self.resolve(name_or_alias)
|
||||||
|
|
||||||
def list_models(self) -> list[str]:
|
def list_models(self) -> list[str]:
|
||||||
"""List all available model names."""
|
"""List all available model names."""
|
||||||
return list(self.model_map.keys())
|
return list(self.model_map.keys())
|
||||||
|
|||||||
@@ -4,14 +4,20 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from .base import ModelProvider, ProviderType
|
from .base import ModelProvider
|
||||||
|
from .shared import ProviderType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tools.models import ToolModelCategory
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
|
||||||
class ModelProviderRegistry:
|
class ModelProviderRegistry:
|
||||||
"""Registry for managing model providers."""
|
"""Singleton that caches provider instances and coordinates priority order.
|
||||||
|
|
||||||
|
Responsibilities include resolving API keys from the environment, lazily
|
||||||
|
instantiating providers, and choosing the best provider for a model based
|
||||||
|
on restriction policies and provider priority.
|
||||||
|
"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
|
|||||||
23
providers/shared/__init__.py
Normal file
23
providers/shared/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""Shared data structures and helpers for model providers."""
|
||||||
|
|
||||||
|
from .model_capabilities import ModelCapabilities
|
||||||
|
from .model_response import ModelResponse
|
||||||
|
from .provider_type import ProviderType
|
||||||
|
from .temperature import (
|
||||||
|
DiscreteTemperatureConstraint,
|
||||||
|
FixedTemperatureConstraint,
|
||||||
|
RangeTemperatureConstraint,
|
||||||
|
TemperatureConstraint,
|
||||||
|
create_temperature_constraint,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModelCapabilities",
|
||||||
|
"ModelResponse",
|
||||||
|
"ProviderType",
|
||||||
|
"TemperatureConstraint",
|
||||||
|
"FixedTemperatureConstraint",
|
||||||
|
"RangeTemperatureConstraint",
|
||||||
|
"DiscreteTemperatureConstraint",
|
||||||
|
"create_temperature_constraint",
|
||||||
|
]
|
||||||
34
providers/shared/model_capabilities.py
Normal file
34
providers/shared/model_capabilities.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Dataclass describing the feature set of a model exposed by a provider."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from .provider_type import ProviderType
|
||||||
|
from .temperature import RangeTemperatureConstraint, TemperatureConstraint
|
||||||
|
|
||||||
|
__all__ = ["ModelCapabilities"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelCapabilities:
|
||||||
|
"""Static capabilities and constraints for a provider-managed model."""
|
||||||
|
|
||||||
|
provider: ProviderType
|
||||||
|
model_name: str
|
||||||
|
friendly_name: str
|
||||||
|
context_window: int
|
||||||
|
max_output_tokens: int
|
||||||
|
supports_extended_thinking: bool = False
|
||||||
|
supports_system_prompts: bool = True
|
||||||
|
supports_streaming: bool = True
|
||||||
|
supports_function_calling: bool = False
|
||||||
|
supports_images: bool = False
|
||||||
|
max_image_size_mb: float = 0.0
|
||||||
|
supports_temperature: bool = True
|
||||||
|
description: str = ""
|
||||||
|
aliases: list[str] = field(default_factory=list)
|
||||||
|
supports_json_mode: bool = False
|
||||||
|
max_thinking_tokens: int = 0
|
||||||
|
is_custom: bool = False
|
||||||
|
temperature_constraint: TemperatureConstraint = field(
|
||||||
|
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||||
|
)
|
||||||
26
providers/shared/model_response.py
Normal file
26
providers/shared/model_response.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""Dataclass used to normalise provider SDK responses."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .provider_type import ProviderType
|
||||||
|
|
||||||
|
__all__ = ["ModelResponse"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelResponse:
|
||||||
|
"""Portable representation of a provider completion."""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
usage: dict[str, int] = field(default_factory=dict)
|
||||||
|
model_name: str = ""
|
||||||
|
friendly_name: str = ""
|
||||||
|
provider: ProviderType = ProviderType.GOOGLE
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens(self) -> int:
|
||||||
|
"""Return the total token count if the provider reported usage data."""
|
||||||
|
|
||||||
|
return self.usage.get("total_tokens", 0)
|
||||||
16
providers/shared/provider_type.py
Normal file
16
providers/shared/provider_type.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Enumeration describing which backend owns a given model."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
__all__ = ["ProviderType"]
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderType(Enum):
|
||||||
|
"""Canonical identifiers for every supported provider backend."""
|
||||||
|
|
||||||
|
GOOGLE = "google"
|
||||||
|
OPENAI = "openai"
|
||||||
|
XAI = "xai"
|
||||||
|
OPENROUTER = "openrouter"
|
||||||
|
CUSTOM = "custom"
|
||||||
|
DIAL = "dial"
|
||||||
121
providers/shared/temperature.py
Normal file
121
providers/shared/temperature.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Helper types for validating model temperature parameters."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TemperatureConstraint",
|
||||||
|
"FixedTemperatureConstraint",
|
||||||
|
"RangeTemperatureConstraint",
|
||||||
|
"DiscreteTemperatureConstraint",
|
||||||
|
"create_temperature_constraint",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TemperatureConstraint(ABC):
|
||||||
|
"""Contract for temperature validation used by `ModelCapabilities`.
|
||||||
|
|
||||||
|
Concrete providers describe their temperature behaviour by creating
|
||||||
|
subclasses that expose three operations:
|
||||||
|
* `validate` – decide whether a requested temperature is acceptable.
|
||||||
|
* `get_corrected_value` – coerce out-of-range values into a safe default.
|
||||||
|
* `get_description` – provide a human readable error message for users.
|
||||||
|
|
||||||
|
Providers call these hooks before sending traffic to the underlying API so
|
||||||
|
that unsupported temperatures never reach the remote service.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate(self, temperature: float) -> bool:
|
||||||
|
"""Return ``True`` when the temperature may be sent to the backend."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
|
"""Return a valid substitute for an out-of-range temperature."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_description(self) -> str:
|
||||||
|
"""Describe the acceptable range to include in error messages."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_default(self) -> float:
|
||||||
|
"""Return the default temperature for the model."""
|
||||||
|
|
||||||
|
|
||||||
|
class FixedTemperatureConstraint(TemperatureConstraint):
|
||||||
|
"""Constraint for models that enforce an exact temperature (for example O3)."""
|
||||||
|
|
||||||
|
def __init__(self, value: float):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def validate(self, temperature: float) -> bool:
|
||||||
|
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
|
||||||
|
|
||||||
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return f"Only supports temperature={self.value}"
|
||||||
|
|
||||||
|
def get_default(self) -> float:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class RangeTemperatureConstraint(TemperatureConstraint):
|
||||||
|
"""Constraint for providers that expose a continuous min/max temperature range."""
|
||||||
|
|
||||||
|
def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None):
|
||||||
|
self.min_temp = min_temp
|
||||||
|
self.max_temp = max_temp
|
||||||
|
self.default_temp = default or (min_temp + max_temp) / 2
|
||||||
|
|
||||||
|
def validate(self, temperature: float) -> bool:
|
||||||
|
return self.min_temp <= temperature <= self.max_temp
|
||||||
|
|
||||||
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
|
return max(self.min_temp, min(self.max_temp, temperature))
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
|
||||||
|
|
||||||
|
def get_default(self) -> float:
|
||||||
|
return self.default_temp
|
||||||
|
|
||||||
|
|
||||||
|
class DiscreteTemperatureConstraint(TemperatureConstraint):
|
||||||
|
"""Constraint for models that permit a discrete list of temperature values."""
|
||||||
|
|
||||||
|
def __init__(self, allowed_values: list[float], default: Optional[float] = None):
|
||||||
|
self.allowed_values = sorted(allowed_values)
|
||||||
|
self.default_temp = default or allowed_values[len(allowed_values) // 2]
|
||||||
|
|
||||||
|
def validate(self, temperature: float) -> bool:
|
||||||
|
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
|
||||||
|
|
||||||
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
|
return min(self.allowed_values, key=lambda x: abs(x - temperature))
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return f"Supports temperatures: {self.allowed_values}"
|
||||||
|
|
||||||
|
def get_default(self) -> float:
|
||||||
|
return self.default_temp
|
||||||
|
|
||||||
|
|
||||||
|
def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
|
||||||
|
"""Factory that yields the appropriate constraint for a model configuration.
|
||||||
|
|
||||||
|
The JSON configuration stored in ``conf/custom_models.json`` references this
|
||||||
|
helper via human-readable strings. Providers feed those values into this
|
||||||
|
function so that runtime logic can rely on strongly typed constraint
|
||||||
|
objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if constraint_type == "fixed":
|
||||||
|
# Fixed temperature models (O3/O4) only support temperature=1.0
|
||||||
|
return FixedTemperatureConstraint(1.0)
|
||||||
|
if constraint_type == "discrete":
|
||||||
|
# For models with specific allowed values - using common OpenAI values as default
|
||||||
|
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
|
||||||
|
# Default range constraint (for "range" or None)
|
||||||
|
return RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||||
@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tools.models import ToolModelCategory
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
from .base import (
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .shared import (
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
create_temperature_constraint,
|
create_temperature_constraint,
|
||||||
)
|
)
|
||||||
from .openai_compatible import OpenAICompatibleProvider
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class XAIModelProvider(OpenAICompatibleProvider):
|
class XAIModelProvider(OpenAICompatibleProvider):
|
||||||
"""X.AI GROK API provider (api.x.ai)."""
|
"""Integration for X.AI's GROK models exposed over an OpenAI-style API.
|
||||||
|
|
||||||
|
Publishes capability metadata for the officially supported deployments and
|
||||||
|
maps tool-category preferences to the appropriate GROK model.
|
||||||
|
"""
|
||||||
|
|
||||||
FRIENDLY_NAME = "X.AI"
|
FRIENDLY_NAME = "X.AI"
|
||||||
|
|
||||||
|
|||||||
@@ -412,12 +412,12 @@ def configure_providers():
|
|||||||
value = os.getenv(key)
|
value = os.getenv(key)
|
||||||
logger.debug(f" {key}: {'[PRESENT]' if value else '[MISSING]'}")
|
logger.debug(f" {key}: {'[PRESENT]' if value else '[MISSING]'}")
|
||||||
from providers import ModelProviderRegistry
|
from providers import ModelProviderRegistry
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.custom import CustomProvider
|
from providers.custom import CustomProvider
|
||||||
from providers.dial import DIALModelProvider
|
from providers.dial import DIALModelProvider
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai_provider import OpenAIModelProvider
|
||||||
from providers.openrouter import OpenRouterProvider
|
from providers.openrouter import OpenRouterProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
from providers.xai import XAIModelProvider
|
from providers.xai import XAIModelProvider
|
||||||
from utils.model_restrictions import get_restriction_service
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
|||||||
@@ -34,9 +34,9 @@ if sys.platform == "win32":
|
|||||||
|
|
||||||
# Register providers for all tests
|
# Register providers for all tests
|
||||||
from providers import ModelProviderRegistry # noqa: E402
|
from providers import ModelProviderRegistry # noqa: E402
|
||||||
from providers.base import ProviderType # noqa: E402
|
|
||||||
from providers.gemini import GeminiModelProvider # noqa: E402
|
from providers.gemini import GeminiModelProvider # noqa: E402
|
||||||
from providers.openai_provider import OpenAIModelProvider # noqa: E402
|
from providers.openai_provider import OpenAIModelProvider # noqa: E402
|
||||||
|
from providers.shared import ProviderType # noqa: E402
|
||||||
from providers.xai import XAIModelProvider # noqa: E402
|
from providers.xai import XAIModelProvider # noqa: E402
|
||||||
|
|
||||||
# Register providers at test startup
|
# Register providers at test startup
|
||||||
@@ -109,7 +109,7 @@ def mock_provider_availability(request, monkeypatch):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Ensure providers are registered (in case other tests cleared the registry)
|
# Ensure providers are registered (in case other tests cleared the registry)
|
||||||
from providers.base import ProviderType
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
registry = ModelProviderRegistry()
|
registry = ModelProviderRegistry()
|
||||||
|
|
||||||
@@ -197,3 +197,19 @@ def mock_provider_availability(request, monkeypatch):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
|
monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_model_restriction_env(monkeypatch):
|
||||||
|
"""Ensure per-test isolation from user-defined model restriction env vars."""
|
||||||
|
|
||||||
|
restriction_vars = [
|
||||||
|
"OPENAI_ALLOWED_MODELS",
|
||||||
|
"GOOGLE_ALLOWED_MODELS",
|
||||||
|
"XAI_ALLOWED_MODELS",
|
||||||
|
"OPENROUTER_ALLOWED_MODELS",
|
||||||
|
"DIAL_ALLOWED_MODELS",
|
||||||
|
]
|
||||||
|
|
||||||
|
for var in restriction_vars:
|
||||||
|
monkeypatch.delenv(var, raising=False)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
|
from providers.shared import ModelCapabilities, ProviderType, RangeTemperatureConstraint
|
||||||
|
|
||||||
|
|
||||||
def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576):
|
def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576):
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ both alias names and their target models, preventing policy bypass vulnerabiliti
|
|||||||
import os
|
import os
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai_provider import OpenAIModelProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
from utils.model_restrictions import ModelRestrictionService
|
from utils.model_restrictions import ModelRestrictionService
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
from tools.analyze import AnalyzeTool
|
from tools.analyze import AnalyzeTool
|
||||||
from tools.chat import ChatTool
|
from tools.chat import ChatTool
|
||||||
from tools.debug import DebugIssueTool
|
from tools.debug import DebugIssueTool
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.no_mock_provider
|
@pytest.mark.no_mock_provider
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import os
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
from tools.models import ToolModelCategory
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,9 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai_provider import OpenAIModelProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
from utils.model_restrictions import ModelRestrictionService
|
from utils.model_restrictions import ModelRestrictionService
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ class TestCustomOpenAITemperatureParameterFix:
|
|||||||
mock_registry_class.return_value = mock_registry
|
mock_registry_class.return_value = mock_registry
|
||||||
|
|
||||||
# Mock get_model_config to return our test model
|
# Mock get_model_config to return our test model
|
||||||
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
|
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||||
|
|
||||||
test_capabilities = ModelCapabilities(
|
test_capabilities = ModelCapabilities(
|
||||||
provider=ProviderType.OPENAI,
|
provider=ProviderType.OPENAI,
|
||||||
@@ -170,7 +170,7 @@ class TestCustomOpenAITemperatureParameterFix:
|
|||||||
mock_registry_class.return_value = mock_registry
|
mock_registry_class.return_value = mock_registry
|
||||||
|
|
||||||
# Mock get_model_config to return a model that supports temperature
|
# Mock get_model_config to return a model that supports temperature
|
||||||
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
|
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||||
|
|
||||||
test_capabilities = ModelCapabilities(
|
test_capabilities = ModelCapabilities(
|
||||||
provider=ProviderType.OPENAI,
|
provider=ProviderType.OPENAI,
|
||||||
@@ -227,7 +227,7 @@ class TestCustomOpenAITemperatureParameterFix:
|
|||||||
mock_registry = Mock()
|
mock_registry = Mock()
|
||||||
mock_registry_class.return_value = mock_registry
|
mock_registry_class.return_value = mock_registry
|
||||||
|
|
||||||
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
|
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||||
|
|
||||||
test_capabilities = ModelCapabilities(
|
test_capabilities = ModelCapabilities(
|
||||||
provider=ProviderType.OPENAI,
|
provider=ProviderType.OPENAI,
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers import ModelProviderRegistry
|
from providers import ModelProviderRegistry
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.custom import CustomProvider
|
from providers.custom import CustomProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
class TestCustomProvider:
|
class TestCustomProvider:
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.dial import DIALModelProvider
|
from providers.dial import DIALModelProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
class TestDIALProvider:
|
class TestDIALProvider:
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ from unittest.mock import Mock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType
|
from providers.base import ModelProvider
|
||||||
|
from providers.shared import ModelCapabilities, ModelResponse, ProviderType
|
||||||
|
|
||||||
|
|
||||||
class MinimalTestProvider(ModelProvider):
|
class MinimalTestProvider(ModelProvider):
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ from unittest.mock import Mock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
class TestIntelligentFallback:
|
class TestIntelligentFallback:
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ def test_issue_245_custom_openai_temperature_ignored():
|
|||||||
mock_registry = Mock()
|
mock_registry = Mock()
|
||||||
mock_registry_class.return_value = mock_registry
|
mock_registry_class.return_value = mock_registry
|
||||||
|
|
||||||
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
|
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||||
|
|
||||||
# This is what the user configured in their custom_models.json
|
# This is what the user configured in their custom_models.json
|
||||||
custom_config = ModelCapabilities(
|
custom_config = ModelCapabilities(
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from providers.base import ModelProvider, ProviderType
|
from providers.base import ModelProvider
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
from tools.listmodels import ListModelsTool
|
from tools.listmodels import ListModelsTool
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ class TestModelEnumeration:
|
|||||||
|
|
||||||
# Rebuild the provider registry with OpenRouter registered
|
# Rebuild the provider registry with OpenRouter registered
|
||||||
ModelProviderRegistry._instance = None
|
ModelProviderRegistry._instance = None
|
||||||
from providers.base import ProviderType
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ This test specifically targets the bug where:
|
|||||||
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.openrouter import OpenRouterProvider
|
from providers.openrouter import OpenRouterProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
from tools.consensus import ConsensusTool
|
from tools.consensus import ConsensusTool
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai_provider import OpenAIModelProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
from utils.model_restrictions import ModelRestrictionService
|
from utils.model_restrictions import ModelRestrictionService
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ They prove that our fix was necessary and actually addresses real problems.
|
|||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from providers.base import ProviderType
|
from providers.shared import ProviderType
|
||||||
from utils.model_restrictions import ModelRestrictionService
|
from utils.model_restrictions import ModelRestrictionService
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
import os
|
import os
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai_provider import OpenAIModelProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIProvider:
|
class TestOpenAIProvider:
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ from unittest.mock import Mock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.openrouter import OpenRouterProvider
|
from providers.openrouter import OpenRouterProvider
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
class TestOpenRouterProvider:
|
class TestOpenRouterProvider:
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import tempfile
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ModelCapabilities, ProviderType
|
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||||
|
from providers.shared import ModelCapabilities, ProviderType
|
||||||
|
|
||||||
|
|
||||||
class TestOpenRouterModelRegistry:
|
class TestOpenRouterModelRegistry:
|
||||||
@@ -213,7 +213,7 @@ class TestOpenRouterModelRegistry:
|
|||||||
|
|
||||||
def test_model_with_all_capabilities(self):
|
def test_model_with_all_capabilities(self):
|
||||||
"""Test model with all capability flags."""
|
"""Test model with all capability flags."""
|
||||||
from providers.base import create_temperature_constraint
|
from providers.shared import create_temperature_constraint
|
||||||
|
|
||||||
caps = ModelCapabilities(
|
caps = ModelCapabilities(
|
||||||
provider=ProviderType.OPENROUTER,
|
provider=ProviderType.OPENROUTER,
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ from unittest.mock import Mock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
from tools.chat import ChatTool
|
from tools.chat import ChatTool
|
||||||
from tools.shared.base_models import ToolRequest
|
from tools.shared.base_models import ToolRequest
|
||||||
|
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ from unittest.mock import Mock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai_provider import OpenAIModelProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
class TestProviderUTF8Encoding(unittest.TestCase):
|
class TestProviderUTF8Encoding(unittest.TestCase):
|
||||||
@@ -177,7 +177,7 @@ class TestProviderUTF8Encoding(unittest.TestCase):
|
|||||||
|
|
||||||
def test_model_response_utf8_serialization(self):
|
def test_model_response_utf8_serialization(self):
|
||||||
"""Test UTF-8 serialization of model responses."""
|
"""Test UTF-8 serialization of model responses."""
|
||||||
from providers.base import ModelResponse
|
from providers.shared import ModelResponse
|
||||||
|
|
||||||
response = ModelResponse(
|
response = ModelResponse(
|
||||||
content="Development successful! Code generated successfully. 🎉✅",
|
content="Development successful! Code generated successfully. 🎉✅",
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ from unittest.mock import Mock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers import ModelProviderRegistry, ModelResponse
|
from providers import ModelProviderRegistry, ModelResponse
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.gemini import GeminiModelProvider
|
from providers.gemini import GeminiModelProvider
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai_provider import OpenAIModelProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
|
||||||
class TestModelProviderRegistry:
|
class TestModelProviderRegistry:
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ class TestSupportedModelsAliases:
|
|||||||
for provider in providers:
|
for provider in providers:
|
||||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||||
# All values must be ModelCapabilities objects, not strings or dicts
|
# All values must be ModelCapabilities objects, not strings or dicts
|
||||||
from providers.base import ModelCapabilities
|
from providers.shared import ModelCapabilities
|
||||||
|
|
||||||
assert isinstance(config, ModelCapabilities), (
|
assert isinstance(config, ModelCapabilities), (
|
||||||
f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
|
f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import os
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
from tools.debug import DebugIssueTool
|
from tools.debug import DebugIssueTool
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from providers.base import ProviderType
|
from providers.shared import ProviderType
|
||||||
from providers.xai import XAIModelProvider
|
from providers.xai import XAIModelProvider
|
||||||
|
|
||||||
|
|
||||||
@@ -265,7 +265,7 @@ class TestXAIProvider:
|
|||||||
assert "grok-3-fast" in provider.SUPPORTED_MODELS
|
assert "grok-3-fast" in provider.SUPPORTED_MODELS
|
||||||
|
|
||||||
# Check model configs have required fields
|
# Check model configs have required fields
|
||||||
from providers.base import ModelCapabilities
|
from providers.shared import ModelCapabilities
|
||||||
|
|
||||||
grok4_config = provider.SUPPORTED_MODELS["grok-4"]
|
grok4_config = provider.SUPPORTED_MODELS["grok-4"]
|
||||||
assert isinstance(grok4_config, ModelCapabilities)
|
assert isinstance(grok4_config, ModelCapabilities)
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ def inject_transport(monkeypatch, cassette_path: str):
|
|||||||
transport = inject_transport(monkeypatch, "path/to/cassette.json")
|
transport = inject_transport(monkeypatch, "path/to/cassette.json")
|
||||||
"""
|
"""
|
||||||
# Ensure OpenAI provider is registered - always needed for transport injection
|
# Ensure OpenAI provider is registered - always needed for transport injection
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.openai_provider import OpenAIModelProvider
|
from providers.openai_provider import OpenAIModelProvider
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
# Always register OpenAI provider for transport tests (API key might be dummy)
|
# Always register OpenAI provider for transport tests (API key might be dummy)
|
||||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||||
|
|||||||
@@ -79,9 +79,9 @@ class ListModelsTool(BaseTool):
|
|||||||
Returns:
|
Returns:
|
||||||
Formatted list of models by provider
|
Formatted list of models by provider
|
||||||
"""
|
"""
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
output_lines = ["# Available AI Models\n"]
|
output_lines = ["# Available AI Models\n"]
|
||||||
|
|
||||||
@@ -162,8 +162,8 @@ class ListModelsTool(BaseTool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get OpenRouter provider from registry to properly apply restrictions
|
# Get OpenRouter provider from registry to properly apply restrictions
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||||
if provider:
|
if provider:
|
||||||
|
|||||||
@@ -1341,7 +1341,7 @@ When recommending searches, be specific about what information you need and why
|
|||||||
# Apply 40MB cap for custom models if needed
|
# Apply 40MB cap for custom models if needed
|
||||||
effective_limit_mb = max_size_mb
|
effective_limit_mb = max_size_mb
|
||||||
try:
|
try:
|
||||||
from providers.base import ProviderType
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
# ModelCapabilities dataclass has provider field defined
|
# ModelCapabilities dataclass has provider field defined
|
||||||
if capabilities.provider == ProviderType.CUSTOM:
|
if capabilities.provider == ProviderType.CUSTOM:
|
||||||
|
|||||||
@@ -306,8 +306,8 @@ class VersionTool(BaseTool):
|
|||||||
|
|
||||||
# Check for configured providers
|
# Check for configured providers
|
||||||
try:
|
try:
|
||||||
from providers.base import ProviderType
|
|
||||||
from providers.registry import ModelProviderRegistry
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
provider_status = []
|
provider_status = []
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from providers.base import ProviderType
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user