diff --git a/providers/__init__.py b/providers/__init__.py index ffeecb6..311fafa 100644 --- a/providers/__init__.py +++ b/providers/__init__.py @@ -1,11 +1,12 @@ """Model provider abstractions for supporting multiple AI providers.""" -from .base import ModelCapabilities, ModelProvider, ModelResponse +from .base import ModelProvider from .gemini import GeminiModelProvider from .openai_compatible import OpenAICompatibleProvider from .openai_provider import OpenAIModelProvider from .openrouter import OpenRouterProvider from .registry import ModelProviderRegistry +from .shared import ModelCapabilities, ModelResponse __all__ = [ "ModelProvider", diff --git a/providers/base.py b/providers/base.py index b0dcdce..ff290aa 100644 --- a/providers/base.py +++ b/providers/base.py @@ -1,12 +1,10 @@ -"""Base model provider interface and data classes.""" +"""Base interfaces and common behaviour for model providers.""" import base64 import binascii import logging import os from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from enum import Enum from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: @@ -14,179 +12,20 @@ if TYPE_CHECKING: from utils.file_types import IMAGES, get_image_mime_type +from .shared import ModelCapabilities, ModelResponse, ProviderType + logger = logging.getLogger(__name__) -class ProviderType(Enum): - """Supported model provider types.""" - - GOOGLE = "google" - OPENAI = "openai" - XAI = "xai" - OPENROUTER = "openrouter" - CUSTOM = "custom" - DIAL = "dial" - - -class TemperatureConstraint(ABC): - """Abstract base class for temperature constraints.""" - - @abstractmethod - def validate(self, temperature: float) -> bool: - """Check if temperature is valid.""" - pass - - @abstractmethod - def get_corrected_value(self, temperature: float) -> float: - """Get nearest valid temperature.""" - pass - - @abstractmethod - def get_description(self) -> str: - """Get human-readable description of constraint.""" - pass - - @abstractmethod - def get_default(self) -> float: - """Get model's default temperature.""" - pass - - -class FixedTemperatureConstraint(TemperatureConstraint): - """For models that only support one temperature value (e.g., O3).""" - - def __init__(self, value: float): - self.value = value - - def validate(self, temperature: float) -> bool: - return abs(temperature - self.value) < 1e-6 # Handle floating point precision - - def get_corrected_value(self, temperature: float) -> float: - return self.value - - def get_description(self) -> str: - return f"Only supports temperature={self.value}" - - def get_default(self) -> float: - return self.value - - -class RangeTemperatureConstraint(TemperatureConstraint): - """For models supporting continuous temperature ranges.""" - - def __init__(self, min_temp: float, max_temp: float, default: float = None): - self.min_temp = min_temp - self.max_temp = max_temp - self.default_temp = default or (min_temp + max_temp) / 2 - - def validate(self, temperature: float) -> bool: - return self.min_temp <= temperature <= self.max_temp - - def get_corrected_value(self, temperature: float) -> float: - return max(self.min_temp, min(self.max_temp, temperature)) - - def get_description(self) -> str: - return f"Supports temperature range [{self.min_temp}, {self.max_temp}]" - - def get_default(self) -> float: - return self.default_temp - - -class DiscreteTemperatureConstraint(TemperatureConstraint): - """For models supporting only specific temperature values.""" - - def __init__(self, allowed_values: list[float], default: float = None): - self.allowed_values = sorted(allowed_values) - self.default_temp = default or allowed_values[len(allowed_values) // 2] - - def validate(self, temperature: float) -> bool: - return any(abs(temperature - val) < 1e-6 for val in self.allowed_values) - - def get_corrected_value(self, temperature: float) -> float: - return min(self.allowed_values, key=lambda x: abs(x - temperature)) - - def get_description(self) -> str: - return f"Supports temperatures: {self.allowed_values}" - - def get_default(self) -> float: - return self.default_temp - - -def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint: - """Create temperature constraint object from configuration string. - - Args: - constraint_type: Type of constraint ("fixed", "range", "discrete") - - Returns: - TemperatureConstraint object based on configuration - """ - if constraint_type == "fixed": - # Fixed temperature models (O3/O4) only support temperature=1.0 - return FixedTemperatureConstraint(1.0) - elif constraint_type == "discrete": - # For models with specific allowed values - using common OpenAI values as default - return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3) - else: - # Default range constraint (for "range" or None) - return RangeTemperatureConstraint(0.0, 2.0, 0.3) - - -@dataclass -class ModelCapabilities: - """Capabilities and constraints for a specific model.""" - - provider: ProviderType - model_name: str - friendly_name: str # Human-friendly name like "Gemini" or "OpenAI" - context_window: int # Total context window size in tokens - max_output_tokens: int # Maximum output tokens per request - supports_extended_thinking: bool = False - supports_system_prompts: bool = True - supports_streaming: bool = True - supports_function_calling: bool = False - supports_images: bool = False # Whether model can process images - max_image_size_mb: float = 0.0 # Maximum total size for all images in MB - supports_temperature: bool = True # Whether model accepts temperature parameter in API calls - - # Additional fields for comprehensive model information - description: str = "" # Human-readable description of the model - aliases: list[str] = field(default_factory=list) # Alternative names/shortcuts for the model - - # JSON mode support (for providers that support structured output) - supports_json_mode: bool = False - - # Thinking mode support (for models with thinking capabilities) - max_thinking_tokens: int = 0 # Maximum thinking tokens for extended reasoning models - - # Custom model flag (for models that only work with custom endpoints) - is_custom: bool = False # Whether this model requires custom API endpoints - - # Temperature constraint object - defines temperature limits and behavior - temperature_constraint: TemperatureConstraint = field( - default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3) - ) - - -@dataclass -class ModelResponse: - """Response from a model provider.""" - - content: str - usage: dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens - model_name: str = "" - friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI" - provider: ProviderType = ProviderType.GOOGLE - metadata: dict[str, Any] = field(default_factory=dict) # Provider-specific metadata - - @property - def total_tokens(self) -> int: - """Get total tokens used.""" - return self.usage.get("total_tokens", 0) - - class ModelProvider(ABC): - """Abstract base class for model providers.""" + """Defines the contract implemented by every model provider backend. + + Subclasses adapt third-party SDKs into the MCP server by exposing + capability metadata, request execution, and token counting through a + consistent interface. Shared helper methods (temperature validation, + alias resolution, image handling, etc.) live here so individual providers + only need to focus on provider-specific details. + """ # All concrete providers must define their supported models SUPPORTED_MODELS: dict[str, Any] = {} diff --git a/providers/custom.py b/providers/custom.py index 64c5d68..d7c4f37 100644 --- a/providers/custom.py +++ b/providers/custom.py @@ -4,15 +4,15 @@ import logging import os from typing import Optional -from .base import ( +from .openai_compatible import OpenAICompatibleProvider +from .openrouter_registry import OpenRouterModelRegistry +from .shared import ( FixedTemperatureConstraint, ModelCapabilities, ModelResponse, ProviderType, RangeTemperatureConstraint, ) -from .openai_compatible import OpenAICompatibleProvider -from .openrouter_registry import OpenRouterModelRegistry # Temperature inference patterns _TEMP_UNSUPPORTED_PATTERNS = [ @@ -30,10 +30,13 @@ _TEMP_UNSUPPORTED_KEYWORDS = [ class CustomProvider(OpenAICompatibleProvider): - """Custom API provider for local models. + """Adapter for self-hosted or local OpenAI-compatible endpoints. - Supports local inference servers like Ollama, vLLM, LM Studio, - and any OpenAI-compatible API endpoint. + The provider reuses the :mod:`providers.shared` registry to surface + user-defined aliases and capability metadata. It also normalises + Ollama-style version tags (``model:latest``) and enforces the same + restriction policies used by cloud providers, ensuring consistent + behaviour regardless of where the model is hosted. """ FRIENDLY_NAME = "Custom API" diff --git a/providers/dial.py b/providers/dial.py index dffbfba..8ca5b9c 100644 --- a/providers/dial.py +++ b/providers/dial.py @@ -6,22 +6,24 @@ import threading import time from typing import Optional -from .base import ( +from .openai_compatible import OpenAICompatibleProvider +from .shared import ( ModelCapabilities, ModelResponse, ProviderType, create_temperature_constraint, ) -from .openai_compatible import OpenAICompatibleProvider logger = logging.getLogger(__name__) class DIALModelProvider(OpenAICompatibleProvider): - """DIAL provider using OpenAI-compatible API. + """Client for the DIAL (Data & AI Layer) aggregation service. - DIAL provides access to various AI models through a unified API interface. - Supports GPT, Claude, Gemini, and other models via DIAL deployments. + DIAL exposes several third-party models behind a single OpenAI-compatible + endpoint. This provider wraps the service, publishes capability metadata + for the known deployments, and centralises retry/backoff settings tailored + to DIAL's latency characteristics. """ FRIENDLY_NAME = "DIAL" diff --git a/providers/gemini.py b/providers/gemini.py index 0cab004..9f2bc26 100644 --- a/providers/gemini.py +++ b/providers/gemini.py @@ -11,13 +11,24 @@ if TYPE_CHECKING: from google import genai from google.genai import types -from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint +from .base import ModelProvider +from .shared import ( + ModelCapabilities, + ModelResponse, + ProviderType, + create_temperature_constraint, +) logger = logging.getLogger(__name__) class GeminiModelProvider(ModelProvider): - """Google Gemini model provider implementation.""" + """First-party Gemini integration built on the official Google SDK. + + The provider advertises detailed thinking-mode budgets, handles optional + custom endpoints, and performs image pre-processing before forwarding a + request to the Gemini APIs. + """ # Model configurations using ModelCapabilities objects SUPPORTED_MODELS = { diff --git a/providers/openai_compatible.py b/providers/openai_compatible.py index e76f727..701c84f 100644 --- a/providers/openai_compatible.py +++ b/providers/openai_compatible.py @@ -11,21 +11,21 @@ from urllib.parse import urlparse from openai import OpenAI -from .base import ( +from .base import ModelProvider +from .shared import ( ModelCapabilities, - ModelProvider, ModelResponse, ProviderType, ) class OpenAICompatibleProvider(ModelProvider): - """Base class for any provider using an OpenAI-compatible API. + """Shared implementation for OpenAI API lookalikes. - This includes: - - Direct OpenAI API - - OpenRouter - - Any other OpenAI-compatible endpoint + The class owns HTTP client configuration (timeouts, proxy hardening, + custom headers) and normalises the OpenAI SDK responses into + :class:`~providers.shared.ModelResponse`. Concrete subclasses only need to + provide capability metadata and any provider-specific request tweaks. """ DEFAULT_HEADERS = {} diff --git a/providers/openai_provider.py b/providers/openai_provider.py index 81bb067..55cb657 100644 --- a/providers/openai_provider.py +++ b/providers/openai_provider.py @@ -6,19 +6,24 @@ from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from tools.models import ToolModelCategory -from .base import ( +from .openai_compatible import OpenAICompatibleProvider +from .shared import ( ModelCapabilities, ModelResponse, ProviderType, create_temperature_constraint, ) -from .openai_compatible import OpenAICompatibleProvider logger = logging.getLogger(__name__) class OpenAIModelProvider(OpenAICompatibleProvider): - """Official OpenAI API provider (api.openai.com).""" + """Implementation that talks to api.openai.com using rich model metadata. + + In addition to the built-in catalogue, the provider can surface models + defined in ``conf/custom_models.json`` (for organisations running their own + OpenAI-compatible gateways) while still respecting restriction policies. + """ # Model configurations using ModelCapabilities objects SUPPORTED_MODELS = { diff --git a/providers/openrouter.py b/providers/openrouter.py index c0ed58d..5360b87 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -4,21 +4,22 @@ import logging import os from typing import Optional -from .base import ( +from .openai_compatible import OpenAICompatibleProvider +from .openrouter_registry import OpenRouterModelRegistry +from .shared import ( ModelCapabilities, ModelResponse, ProviderType, RangeTemperatureConstraint, ) -from .openai_compatible import OpenAICompatibleProvider -from .openrouter_registry import OpenRouterModelRegistry class OpenRouterProvider(OpenAICompatibleProvider): - """OpenRouter unified API provider. + """Client for OpenRouter's multi-model aggregation service. - OpenRouter provides access to multiple AI models through a single API endpoint. - See https://openrouter.ai for available models and pricing. + OpenRouter surfaces dozens of upstream vendors. This provider layers alias + resolution, restriction-aware filtering, and sensible capability defaults + on top of the generic OpenAI-compatible plumbing. """ FRIENDLY_NAME = "OpenRouter" diff --git a/providers/openrouter_registry.py b/providers/openrouter_registry.py index 949e2c8..9e1dbf1 100644 --- a/providers/openrouter_registry.py +++ b/providers/openrouter_registry.py @@ -9,7 +9,7 @@ from typing import Optional # Import handled via importlib.resources.files() calls directly from utils.file_utils import read_json_file -from .base import ( +from .shared import ( ModelCapabilities, ProviderType, create_temperature_constraint, @@ -17,7 +17,13 @@ from .base import ( class OpenRouterModelRegistry: - """Registry for managing OpenRouter model configurations and aliases.""" + """Loads and validates the OpenRouter/custom model catalogue. + + The registry parses ``conf/custom_models.json`` (or an override supplied via + environment variable), builds case-insensitive alias maps, and exposes + :class:`~providers.shared.ModelCapabilities` objects used by several + providers. + """ def __init__(self, config_path: Optional[str] = None): """Initialize the registry. @@ -263,6 +269,11 @@ class OpenRouterModelRegistry: # Registry now returns ModelCapabilities directly return self.resolve(name_or_alias) + def get_model_config(self, name_or_alias: str) -> Optional[ModelCapabilities]: + """Backward-compatible wrapper used by providers and older tests.""" + + return self.resolve(name_or_alias) + def list_models(self) -> list[str]: """List all available model names.""" return list(self.model_map.keys()) diff --git a/providers/registry.py b/providers/registry.py index 7a1b94e..c22cfcf 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -4,14 +4,20 @@ import logging import os from typing import TYPE_CHECKING, Optional -from .base import ModelProvider, ProviderType +from .base import ModelProvider +from .shared import ProviderType if TYPE_CHECKING: from tools.models import ToolModelCategory class ModelProviderRegistry: - """Registry for managing model providers.""" + """Singleton that caches provider instances and coordinates priority order. + + Responsibilities include resolving API keys from the environment, lazily + instantiating providers, and choosing the best provider for a model based + on restriction policies and provider priority. + """ _instance = None diff --git a/providers/shared/__init__.py b/providers/shared/__init__.py new file mode 100644 index 0000000..2ed6072 --- /dev/null +++ b/providers/shared/__init__.py @@ -0,0 +1,23 @@ +"""Shared data structures and helpers for model providers.""" + +from .model_capabilities import ModelCapabilities +from .model_response import ModelResponse +from .provider_type import ProviderType +from .temperature import ( + DiscreteTemperatureConstraint, + FixedTemperatureConstraint, + RangeTemperatureConstraint, + TemperatureConstraint, + create_temperature_constraint, +) + +__all__ = [ + "ModelCapabilities", + "ModelResponse", + "ProviderType", + "TemperatureConstraint", + "FixedTemperatureConstraint", + "RangeTemperatureConstraint", + "DiscreteTemperatureConstraint", + "create_temperature_constraint", +] diff --git a/providers/shared/model_capabilities.py b/providers/shared/model_capabilities.py new file mode 100644 index 0000000..f68d304 --- /dev/null +++ b/providers/shared/model_capabilities.py @@ -0,0 +1,34 @@ +"""Dataclass describing the feature set of a model exposed by a provider.""" + +from dataclasses import dataclass, field + +from .provider_type import ProviderType +from .temperature import RangeTemperatureConstraint, TemperatureConstraint + +__all__ = ["ModelCapabilities"] + + +@dataclass +class ModelCapabilities: + """Static capabilities and constraints for a provider-managed model.""" + + provider: ProviderType + model_name: str + friendly_name: str + context_window: int + max_output_tokens: int + supports_extended_thinking: bool = False + supports_system_prompts: bool = True + supports_streaming: bool = True + supports_function_calling: bool = False + supports_images: bool = False + max_image_size_mb: float = 0.0 + supports_temperature: bool = True + description: str = "" + aliases: list[str] = field(default_factory=list) + supports_json_mode: bool = False + max_thinking_tokens: int = 0 + is_custom: bool = False + temperature_constraint: TemperatureConstraint = field( + default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3) + ) diff --git a/providers/shared/model_response.py b/providers/shared/model_response.py new file mode 100644 index 0000000..cccff48 --- /dev/null +++ b/providers/shared/model_response.py @@ -0,0 +1,26 @@ +"""Dataclass used to normalise provider SDK responses.""" + +from dataclasses import dataclass, field +from typing import Any + +from .provider_type import ProviderType + +__all__ = ["ModelResponse"] + + +@dataclass +class ModelResponse: + """Portable representation of a provider completion.""" + + content: str + usage: dict[str, int] = field(default_factory=dict) + model_name: str = "" + friendly_name: str = "" + provider: ProviderType = ProviderType.GOOGLE + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def total_tokens(self) -> int: + """Return the total token count if the provider reported usage data.""" + + return self.usage.get("total_tokens", 0) diff --git a/providers/shared/provider_type.py b/providers/shared/provider_type.py new file mode 100644 index 0000000..44153f0 --- /dev/null +++ b/providers/shared/provider_type.py @@ -0,0 +1,16 @@ +"""Enumeration describing which backend owns a given model.""" + +from enum import Enum + +__all__ = ["ProviderType"] + + +class ProviderType(Enum): + """Canonical identifiers for every supported provider backend.""" + + GOOGLE = "google" + OPENAI = "openai" + XAI = "xai" + OPENROUTER = "openrouter" + CUSTOM = "custom" + DIAL = "dial" diff --git a/providers/shared/temperature.py b/providers/shared/temperature.py new file mode 100644 index 0000000..6c6c9af --- /dev/null +++ b/providers/shared/temperature.py @@ -0,0 +1,121 @@ +"""Helper types for validating model temperature parameters.""" + +from abc import ABC, abstractmethod +from typing import Optional + +__all__ = [ + "TemperatureConstraint", + "FixedTemperatureConstraint", + "RangeTemperatureConstraint", + "DiscreteTemperatureConstraint", + "create_temperature_constraint", +] + + +class TemperatureConstraint(ABC): + """Contract for temperature validation used by `ModelCapabilities`. + + Concrete providers describe their temperature behaviour by creating + subclasses that expose three operations: + * `validate` – decide whether a requested temperature is acceptable. + * `get_corrected_value` – coerce out-of-range values into a safe default. + * `get_description` – provide a human readable error message for users. + + Providers call these hooks before sending traffic to the underlying API so + that unsupported temperatures never reach the remote service. + """ + + @abstractmethod + def validate(self, temperature: float) -> bool: + """Return ``True`` when the temperature may be sent to the backend.""" + + @abstractmethod + def get_corrected_value(self, temperature: float) -> float: + """Return a valid substitute for an out-of-range temperature.""" + + @abstractmethod + def get_description(self) -> str: + """Describe the acceptable range to include in error messages.""" + + @abstractmethod + def get_default(self) -> float: + """Return the default temperature for the model.""" + + +class FixedTemperatureConstraint(TemperatureConstraint): + """Constraint for models that enforce an exact temperature (for example O3).""" + + def __init__(self, value: float): + self.value = value + + def validate(self, temperature: float) -> bool: + return abs(temperature - self.value) < 1e-6 # Handle floating point precision + + def get_corrected_value(self, temperature: float) -> float: + return self.value + + def get_description(self) -> str: + return f"Only supports temperature={self.value}" + + def get_default(self) -> float: + return self.value + + +class RangeTemperatureConstraint(TemperatureConstraint): + """Constraint for providers that expose a continuous min/max temperature range.""" + + def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None): + self.min_temp = min_temp + self.max_temp = max_temp + self.default_temp = default or (min_temp + max_temp) / 2 + + def validate(self, temperature: float) -> bool: + return self.min_temp <= temperature <= self.max_temp + + def get_corrected_value(self, temperature: float) -> float: + return max(self.min_temp, min(self.max_temp, temperature)) + + def get_description(self) -> str: + return f"Supports temperature range [{self.min_temp}, {self.max_temp}]" + + def get_default(self) -> float: + return self.default_temp + + +class DiscreteTemperatureConstraint(TemperatureConstraint): + """Constraint for models that permit a discrete list of temperature values.""" + + def __init__(self, allowed_values: list[float], default: Optional[float] = None): + self.allowed_values = sorted(allowed_values) + self.default_temp = default or allowed_values[len(allowed_values) // 2] + + def validate(self, temperature: float) -> bool: + return any(abs(temperature - val) < 1e-6 for val in self.allowed_values) + + def get_corrected_value(self, temperature: float) -> float: + return min(self.allowed_values, key=lambda x: abs(x - temperature)) + + def get_description(self) -> str: + return f"Supports temperatures: {self.allowed_values}" + + def get_default(self) -> float: + return self.default_temp + + +def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint: + """Factory that yields the appropriate constraint for a model configuration. + + The JSON configuration stored in ``conf/custom_models.json`` references this + helper via human-readable strings. Providers feed those values into this + function so that runtime logic can rely on strongly typed constraint + objects. + """ + + if constraint_type == "fixed": + # Fixed temperature models (O3/O4) only support temperature=1.0 + return FixedTemperatureConstraint(1.0) + if constraint_type == "discrete": + # For models with specific allowed values - using common OpenAI values as default + return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3) + # Default range constraint (for "range" or None) + return RangeTemperatureConstraint(0.0, 2.0, 0.3) diff --git a/providers/xai.py b/providers/xai.py index f2b8242..1d3e5db 100644 --- a/providers/xai.py +++ b/providers/xai.py @@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from tools.models import ToolModelCategory -from .base import ( +from .openai_compatible import OpenAICompatibleProvider +from .shared import ( ModelCapabilities, ModelResponse, ProviderType, create_temperature_constraint, ) -from .openai_compatible import OpenAICompatibleProvider logger = logging.getLogger(__name__) class XAIModelProvider(OpenAICompatibleProvider): - """X.AI GROK API provider (api.x.ai).""" + """Integration for X.AI's GROK models exposed over an OpenAI-style API. + + Publishes capability metadata for the officially supported deployments and + maps tool-category preferences to the appropriate GROK model. + """ FRIENDLY_NAME = "X.AI" diff --git a/server.py b/server.py index ade46df..12a5f65 100644 --- a/server.py +++ b/server.py @@ -412,12 +412,12 @@ def configure_providers(): value = os.getenv(key) logger.debug(f" {key}: {'[PRESENT]' if value else '[MISSING]'}") from providers import ModelProviderRegistry - from providers.base import ProviderType from providers.custom import CustomProvider from providers.dial import DIALModelProvider from providers.gemini import GeminiModelProvider from providers.openai_provider import OpenAIModelProvider from providers.openrouter import OpenRouterProvider + from providers.shared import ProviderType from providers.xai import XAIModelProvider from utils.model_restrictions import get_restriction_service diff --git a/tests/conftest.py b/tests/conftest.py index 77af58a..d7e7768 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,9 +34,9 @@ if sys.platform == "win32": # Register providers for all tests from providers import ModelProviderRegistry # noqa: E402 -from providers.base import ProviderType # noqa: E402 from providers.gemini import GeminiModelProvider # noqa: E402 from providers.openai_provider import OpenAIModelProvider # noqa: E402 +from providers.shared import ProviderType # noqa: E402 from providers.xai import XAIModelProvider # noqa: E402 # Register providers at test startup @@ -109,7 +109,7 @@ def mock_provider_availability(request, monkeypatch): return # Ensure providers are registered (in case other tests cleared the registry) - from providers.base import ProviderType + from providers.shared import ProviderType registry = ModelProviderRegistry() @@ -197,3 +197,19 @@ def mock_provider_availability(request, monkeypatch): return False monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode) + + +@pytest.fixture(autouse=True) +def clear_model_restriction_env(monkeypatch): + """Ensure per-test isolation from user-defined model restriction env vars.""" + + restriction_vars = [ + "OPENAI_ALLOWED_MODELS", + "GOOGLE_ALLOWED_MODELS", + "XAI_ALLOWED_MODELS", + "OPENROUTER_ALLOWED_MODELS", + "DIAL_ALLOWED_MODELS", + ] + + for var in restriction_vars: + monkeypatch.delenv(var, raising=False) diff --git a/tests/mock_helpers.py b/tests/mock_helpers.py index 1122af1..6ecf90f 100644 --- a/tests/mock_helpers.py +++ b/tests/mock_helpers.py @@ -2,7 +2,7 @@ from unittest.mock import Mock -from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint +from providers.shared import ModelCapabilities, ProviderType, RangeTemperatureConstraint def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576): diff --git a/tests/test_alias_target_restrictions.py b/tests/test_alias_target_restrictions.py index 3f417b8..83ebeff 100644 --- a/tests/test_alias_target_restrictions.py +++ b/tests/test_alias_target_restrictions.py @@ -8,9 +8,9 @@ both alias names and their target models, preventing policy bypass vulnerabiliti import os from unittest.mock import patch -from providers.base import ProviderType from providers.gemini import GeminiModelProvider from providers.openai_provider import OpenAIModelProvider +from providers.shared import ProviderType from utils.model_restrictions import ModelRestrictionService diff --git a/tests/test_auto_mode_comprehensive.py b/tests/test_auto_mode_comprehensive.py index a68db41..6134ba8 100644 --- a/tests/test_auto_mode_comprehensive.py +++ b/tests/test_auto_mode_comprehensive.py @@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch import pytest -from providers.base import ProviderType from providers.registry import ModelProviderRegistry +from providers.shared import ProviderType from tools.analyze import AnalyzeTool from tools.chat import ChatTool from tools.debug import DebugIssueTool diff --git a/tests/test_auto_mode_custom_provider_only.py b/tests/test_auto_mode_custom_provider_only.py index c97e649..1ee53d2 100644 --- a/tests/test_auto_mode_custom_provider_only.py +++ b/tests/test_auto_mode_custom_provider_only.py @@ -6,8 +6,8 @@ from unittest.mock import patch import pytest -from providers.base import ProviderType from providers.registry import ModelProviderRegistry +from providers.shared import ProviderType @pytest.mark.no_mock_provider diff --git a/tests/test_auto_mode_provider_selection.py b/tests/test_auto_mode_provider_selection.py index 9c47815..d59e71c 100644 --- a/tests/test_auto_mode_provider_selection.py +++ b/tests/test_auto_mode_provider_selection.py @@ -4,8 +4,8 @@ import os import pytest -from providers.base import ProviderType from providers.registry import ModelProviderRegistry +from providers.shared import ProviderType from tools.models import ToolModelCategory diff --git a/tests/test_buggy_behavior_prevention.py b/tests/test_buggy_behavior_prevention.py index 1d07d2e..57cf204 100644 --- a/tests/test_buggy_behavior_prevention.py +++ b/tests/test_buggy_behavior_prevention.py @@ -14,9 +14,9 @@ from unittest.mock import MagicMock, patch import pytest -from providers.base import ProviderType from providers.gemini import GeminiModelProvider from providers.openai_provider import OpenAIModelProvider +from providers.shared import ProviderType from utils.model_restrictions import ModelRestrictionService diff --git a/tests/test_custom_openai_temperature_fix.py b/tests/test_custom_openai_temperature_fix.py index b634a17..fe168f6 100644 --- a/tests/test_custom_openai_temperature_fix.py +++ b/tests/test_custom_openai_temperature_fix.py @@ -86,7 +86,7 @@ class TestCustomOpenAITemperatureParameterFix: mock_registry_class.return_value = mock_registry # Mock get_model_config to return our test model - from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint + from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint test_capabilities = ModelCapabilities( provider=ProviderType.OPENAI, @@ -170,7 +170,7 @@ class TestCustomOpenAITemperatureParameterFix: mock_registry_class.return_value = mock_registry # Mock get_model_config to return a model that supports temperature - from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint + from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint test_capabilities = ModelCapabilities( provider=ProviderType.OPENAI, @@ -227,7 +227,7 @@ class TestCustomOpenAITemperatureParameterFix: mock_registry = Mock() mock_registry_class.return_value = mock_registry - from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint + from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint test_capabilities = ModelCapabilities( provider=ProviderType.OPENAI, diff --git a/tests/test_custom_provider.py b/tests/test_custom_provider.py index 125417d..19b67aa 100644 --- a/tests/test_custom_provider.py +++ b/tests/test_custom_provider.py @@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch import pytest from providers import ModelProviderRegistry -from providers.base import ProviderType from providers.custom import CustomProvider +from providers.shared import ProviderType class TestCustomProvider: diff --git a/tests/test_dial_provider.py b/tests/test_dial_provider.py index 0d8d70f..3423c7c 100644 --- a/tests/test_dial_provider.py +++ b/tests/test_dial_provider.py @@ -5,8 +5,8 @@ from unittest.mock import MagicMock, patch import pytest -from providers.base import ProviderType from providers.dial import DIALModelProvider +from providers.shared import ProviderType class TestDIALProvider: diff --git a/tests/test_image_validation.py b/tests/test_image_validation.py index e1fb36d..12dbcd5 100644 --- a/tests/test_image_validation.py +++ b/tests/test_image_validation.py @@ -8,7 +8,8 @@ from unittest.mock import Mock, patch import pytest -from providers.base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType +from providers.base import ModelProvider +from providers.shared import ModelCapabilities, ModelResponse, ProviderType class MinimalTestProvider(ModelProvider): diff --git a/tests/test_intelligent_fallback.py b/tests/test_intelligent_fallback.py index 8ad3b17..20aed4b 100644 --- a/tests/test_intelligent_fallback.py +++ b/tests/test_intelligent_fallback.py @@ -9,8 +9,8 @@ from unittest.mock import Mock, patch import pytest -from providers.base import ProviderType from providers.registry import ModelProviderRegistry +from providers.shared import ProviderType class TestIntelligentFallback: diff --git a/tests/test_issue_245_simple.py b/tests/test_issue_245_simple.py index bd58ce8..647de02 100644 --- a/tests/test_issue_245_simple.py +++ b/tests/test_issue_245_simple.py @@ -41,7 +41,7 @@ def test_issue_245_custom_openai_temperature_ignored(): mock_registry = Mock() mock_registry_class.return_value = mock_registry - from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint + from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint # This is what the user configured in their custom_models.json custom_config = ModelCapabilities( diff --git a/tests/test_listmodels_restrictions.py b/tests/test_listmodels_restrictions.py index 5d9f06d..82fe506 100644 --- a/tests/test_listmodels_restrictions.py +++ b/tests/test_listmodels_restrictions.py @@ -5,8 +5,9 @@ import os import unittest from unittest.mock import MagicMock, patch -from providers.base import ModelProvider, ProviderType +from providers.base import ModelProvider from providers.registry import ModelProviderRegistry +from providers.shared import ProviderType from tools.listmodels import ListModelsTool diff --git a/tests/test_model_enumeration.py b/tests/test_model_enumeration.py index 6dc390b..0b95154 100644 --- a/tests/test_model_enumeration.py +++ b/tests/test_model_enumeration.py @@ -214,7 +214,7 @@ class TestModelEnumeration: # Rebuild the provider registry with OpenRouter registered ModelProviderRegistry._instance = None - from providers.base import ProviderType + from providers.shared import ProviderType ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider) diff --git a/tests/test_model_resolution_bug.py b/tests/test_model_resolution_bug.py index ab92624..5db3bd8 100644 --- a/tests/test_model_resolution_bug.py +++ b/tests/test_model_resolution_bug.py @@ -9,8 +9,8 @@ This test specifically targets the bug where: from unittest.mock import Mock, patch -from providers.base import ProviderType from providers.openrouter import OpenRouterProvider +from providers.shared import ProviderType from tools.consensus import ConsensusTool diff --git a/tests/test_model_restrictions.py b/tests/test_model_restrictions.py index bf83f61..417ba07 100644 --- a/tests/test_model_restrictions.py +++ b/tests/test_model_restrictions.py @@ -5,9 +5,9 @@ from unittest.mock import MagicMock, patch import pytest -from providers.base import ProviderType from providers.gemini import GeminiModelProvider from providers.openai_provider import OpenAIModelProvider +from providers.shared import ProviderType from utils.model_restrictions import ModelRestrictionService diff --git a/tests/test_old_behavior_simulation.py b/tests/test_old_behavior_simulation.py index d14c8ea..2918183 100644 --- a/tests/test_old_behavior_simulation.py +++ b/tests/test_old_behavior_simulation.py @@ -10,7 +10,7 @@ They prove that our fix was necessary and actually addresses real problems. from unittest.mock import MagicMock, patch -from providers.base import ProviderType +from providers.shared import ProviderType from utils.model_restrictions import ModelRestrictionService diff --git a/tests/test_openai_provider.py b/tests/test_openai_provider.py index 5278ff5..752935e 100644 --- a/tests/test_openai_provider.py +++ b/tests/test_openai_provider.py @@ -3,8 +3,8 @@ import os from unittest.mock import MagicMock, patch -from providers.base import ProviderType from providers.openai_provider import OpenAIModelProvider +from providers.shared import ProviderType class TestOpenAIProvider: diff --git a/tests/test_openrouter_provider.py b/tests/test_openrouter_provider.py index ddbfdde..057bfde 100644 --- a/tests/test_openrouter_provider.py +++ b/tests/test_openrouter_provider.py @@ -5,9 +5,9 @@ from unittest.mock import Mock, patch import pytest -from providers.base import ProviderType from providers.openrouter import OpenRouterProvider from providers.registry import ModelProviderRegistry +from providers.shared import ProviderType class TestOpenRouterProvider: diff --git a/tests/test_openrouter_registry.py b/tests/test_openrouter_registry.py index 60ec491..2a172b2 100644 --- a/tests/test_openrouter_registry.py +++ b/tests/test_openrouter_registry.py @@ -6,8 +6,8 @@ import tempfile import pytest -from providers.base import ModelCapabilities, ProviderType from providers.openrouter_registry import OpenRouterModelRegistry +from providers.shared import ModelCapabilities, ProviderType class TestOpenRouterModelRegistry: @@ -213,7 +213,7 @@ class TestOpenRouterModelRegistry: def test_model_with_all_capabilities(self): """Test model with all capability flags.""" - from providers.base import create_temperature_constraint + from providers.shared import create_temperature_constraint caps = ModelCapabilities( provider=ProviderType.OPENROUTER, diff --git a/tests/test_provider_routing_bugs.py b/tests/test_provider_routing_bugs.py index 1e1363c..dec2f83 100644 --- a/tests/test_provider_routing_bugs.py +++ b/tests/test_provider_routing_bugs.py @@ -13,8 +13,8 @@ from unittest.mock import Mock import pytest -from providers.base import ProviderType from providers.registry import ModelProviderRegistry +from providers.shared import ProviderType from tools.chat import ChatTool from tools.shared.base_models import ToolRequest diff --git a/tests/test_provider_utf8.py b/tests/test_provider_utf8.py index fc630e9..32d7571 100644 --- a/tests/test_provider_utf8.py +++ b/tests/test_provider_utf8.py @@ -10,9 +10,9 @@ from unittest.mock import Mock, patch import pytest -from providers.base import ProviderType from providers.gemini import GeminiModelProvider from providers.openai_provider import OpenAIModelProvider +from providers.shared import ProviderType class TestProviderUTF8Encoding(unittest.TestCase): @@ -177,7 +177,7 @@ class TestProviderUTF8Encoding(unittest.TestCase): def test_model_response_utf8_serialization(self): """Test UTF-8 serialization of model responses.""" - from providers.base import ModelResponse + from providers.shared import ModelResponse response = ModelResponse( content="Development successful! Code generated successfully. 🎉✅", diff --git a/tests/test_providers.py b/tests/test_providers.py index 036ae9b..c9534b5 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -6,9 +6,9 @@ from unittest.mock import Mock, patch import pytest from providers import ModelProviderRegistry, ModelResponse -from providers.base import ProviderType from providers.gemini import GeminiModelProvider from providers.openai_provider import OpenAIModelProvider +from providers.shared import ProviderType class TestModelProviderRegistry: diff --git a/tests/test_supported_models_aliases.py b/tests/test_supported_models_aliases.py index 336368b..1dea8d3 100644 --- a/tests/test_supported_models_aliases.py +++ b/tests/test_supported_models_aliases.py @@ -185,7 +185,7 @@ class TestSupportedModelsAliases: for provider in providers: for model_name, config in provider.SUPPORTED_MODELS.items(): # All values must be ModelCapabilities objects, not strings or dicts - from providers.base import ModelCapabilities + from providers.shared import ModelCapabilities assert isinstance(config, ModelCapabilities), ( f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] " diff --git a/tests/test_workflow_metadata.py b/tests/test_workflow_metadata.py index d0a9693..0d0e870 100644 --- a/tests/test_workflow_metadata.py +++ b/tests/test_workflow_metadata.py @@ -10,8 +10,8 @@ import os import pytest -from providers.base import ProviderType from providers.registry import ModelProviderRegistry +from providers.shared import ProviderType from tools.debug import DebugIssueTool diff --git a/tests/test_xai_provider.py b/tests/test_xai_provider.py index 0b8eb1b..5bdc4a0 100644 --- a/tests/test_xai_provider.py +++ b/tests/test_xai_provider.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest -from providers.base import ProviderType +from providers.shared import ProviderType from providers.xai import XAIModelProvider @@ -265,7 +265,7 @@ class TestXAIProvider: assert "grok-3-fast" in provider.SUPPORTED_MODELS # Check model configs have required fields - from providers.base import ModelCapabilities + from providers.shared import ModelCapabilities grok4_config = provider.SUPPORTED_MODELS["grok-4"] assert isinstance(grok4_config, ModelCapabilities) diff --git a/tests/transport_helpers.py b/tests/transport_helpers.py index 6c0a889..07fbfe9 100644 --- a/tests/transport_helpers.py +++ b/tests/transport_helpers.py @@ -22,9 +22,9 @@ def inject_transport(monkeypatch, cassette_path: str): transport = inject_transport(monkeypatch, "path/to/cassette.json") """ # Ensure OpenAI provider is registered - always needed for transport injection - from providers.base import ProviderType from providers.openai_provider import OpenAIModelProvider from providers.registry import ModelProviderRegistry + from providers.shared import ProviderType # Always register OpenAI provider for transport tests (API key might be dummy) ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider) diff --git a/tools/listmodels.py b/tools/listmodels.py index 4d17062..7bde7f2 100644 --- a/tools/listmodels.py +++ b/tools/listmodels.py @@ -79,9 +79,9 @@ class ListModelsTool(BaseTool): Returns: Formatted list of models by provider """ - from providers.base import ProviderType from providers.openrouter_registry import OpenRouterModelRegistry from providers.registry import ModelProviderRegistry + from providers.shared import ProviderType output_lines = ["# Available AI Models\n"] @@ -162,8 +162,8 @@ class ListModelsTool(BaseTool): try: # Get OpenRouter provider from registry to properly apply restrictions - from providers.base import ProviderType from providers.registry import ModelProviderRegistry + from providers.shared import ProviderType provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER) if provider: diff --git a/tools/shared/base_tool.py b/tools/shared/base_tool.py index adb77f1..eb3995b 100644 --- a/tools/shared/base_tool.py +++ b/tools/shared/base_tool.py @@ -1341,7 +1341,7 @@ When recommending searches, be specific about what information you need and why # Apply 40MB cap for custom models if needed effective_limit_mb = max_size_mb try: - from providers.base import ProviderType + from providers.shared import ProviderType # ModelCapabilities dataclass has provider field defined if capabilities.provider == ProviderType.CUSTOM: diff --git a/tools/version.py b/tools/version.py index 030be17..3acaf7b 100644 --- a/tools/version.py +++ b/tools/version.py @@ -306,8 +306,8 @@ class VersionTool(BaseTool): # Check for configured providers try: - from providers.base import ProviderType from providers.registry import ModelProviderRegistry + from providers.shared import ProviderType provider_status = [] diff --git a/utils/model_restrictions.py b/utils/model_restrictions.py index b10544a..2e3a7f3 100644 --- a/utils/model_restrictions.py +++ b/utils/model_restrictions.py @@ -24,7 +24,7 @@ import logging import os from typing import Optional -from providers.base import ProviderType +from providers.shared import ProviderType logger = logging.getLogger(__name__)