refactor: code cleanup

This commit is contained in:
Fahad
2025-10-02 08:09:44 +04:00
parent 218fbdf49c
commit 182aa627df
49 changed files with 370 additions and 249 deletions

View File

@@ -1,11 +1,12 @@
"""Model provider abstractions for supporting multiple AI providers.""" """Model provider abstractions for supporting multiple AI providers."""
from .base import ModelCapabilities, ModelProvider, ModelResponse from .base import ModelProvider
from .gemini import GeminiModelProvider from .gemini import GeminiModelProvider
from .openai_compatible import OpenAICompatibleProvider from .openai_compatible import OpenAICompatibleProvider
from .openai_provider import OpenAIModelProvider from .openai_provider import OpenAIModelProvider
from .openrouter import OpenRouterProvider from .openrouter import OpenRouterProvider
from .registry import ModelProviderRegistry from .registry import ModelProviderRegistry
from .shared import ModelCapabilities, ModelResponse
__all__ = [ __all__ = [
"ModelProvider", "ModelProvider",

View File

@@ -1,12 +1,10 @@
"""Base model provider interface and data classes.""" """Base interfaces and common behaviour for model providers."""
import base64 import base64
import binascii import binascii
import logging import logging
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -14,179 +12,20 @@ if TYPE_CHECKING:
from utils.file_types import IMAGES, get_image_mime_type from utils.file_types import IMAGES, get_image_mime_type
from .shared import ModelCapabilities, ModelResponse, ProviderType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ProviderType(Enum):
"""Supported model provider types."""
GOOGLE = "google"
OPENAI = "openai"
XAI = "xai"
OPENROUTER = "openrouter"
CUSTOM = "custom"
DIAL = "dial"
class TemperatureConstraint(ABC):
"""Abstract base class for temperature constraints."""
@abstractmethod
def validate(self, temperature: float) -> bool:
"""Check if temperature is valid."""
pass
@abstractmethod
def get_corrected_value(self, temperature: float) -> float:
"""Get nearest valid temperature."""
pass
@abstractmethod
def get_description(self) -> str:
"""Get human-readable description of constraint."""
pass
@abstractmethod
def get_default(self) -> float:
"""Get model's default temperature."""
pass
class FixedTemperatureConstraint(TemperatureConstraint):
"""For models that only support one temperature value (e.g., O3)."""
def __init__(self, value: float):
self.value = value
def validate(self, temperature: float) -> bool:
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
def get_corrected_value(self, temperature: float) -> float:
return self.value
def get_description(self) -> str:
return f"Only supports temperature={self.value}"
def get_default(self) -> float:
return self.value
class RangeTemperatureConstraint(TemperatureConstraint):
"""For models supporting continuous temperature ranges."""
def __init__(self, min_temp: float, max_temp: float, default: float = None):
self.min_temp = min_temp
self.max_temp = max_temp
self.default_temp = default or (min_temp + max_temp) / 2
def validate(self, temperature: float) -> bool:
return self.min_temp <= temperature <= self.max_temp
def get_corrected_value(self, temperature: float) -> float:
return max(self.min_temp, min(self.max_temp, temperature))
def get_description(self) -> str:
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
def get_default(self) -> float:
return self.default_temp
class DiscreteTemperatureConstraint(TemperatureConstraint):
"""For models supporting only specific temperature values."""
def __init__(self, allowed_values: list[float], default: float = None):
self.allowed_values = sorted(allowed_values)
self.default_temp = default or allowed_values[len(allowed_values) // 2]
def validate(self, temperature: float) -> bool:
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
def get_corrected_value(self, temperature: float) -> float:
return min(self.allowed_values, key=lambda x: abs(x - temperature))
def get_description(self) -> str:
return f"Supports temperatures: {self.allowed_values}"
def get_default(self) -> float:
return self.default_temp
def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
"""Create temperature constraint object from configuration string.
Args:
constraint_type: Type of constraint ("fixed", "range", "discrete")
Returns:
TemperatureConstraint object based on configuration
"""
if constraint_type == "fixed":
# Fixed temperature models (O3/O4) only support temperature=1.0
return FixedTemperatureConstraint(1.0)
elif constraint_type == "discrete":
# For models with specific allowed values - using common OpenAI values as default
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
else:
# Default range constraint (for "range" or None)
return RangeTemperatureConstraint(0.0, 2.0, 0.3)
@dataclass
class ModelCapabilities:
"""Capabilities and constraints for a specific model."""
provider: ProviderType
model_name: str
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
context_window: int # Total context window size in tokens
max_output_tokens: int # Maximum output tokens per request
supports_extended_thinking: bool = False
supports_system_prompts: bool = True
supports_streaming: bool = True
supports_function_calling: bool = False
supports_images: bool = False # Whether model can process images
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
# Additional fields for comprehensive model information
description: str = "" # Human-readable description of the model
aliases: list[str] = field(default_factory=list) # Alternative names/shortcuts for the model
# JSON mode support (for providers that support structured output)
supports_json_mode: bool = False
# Thinking mode support (for models with thinking capabilities)
max_thinking_tokens: int = 0 # Maximum thinking tokens for extended reasoning models
# Custom model flag (for models that only work with custom endpoints)
is_custom: bool = False # Whether this model requires custom API endpoints
# Temperature constraint object - defines temperature limits and behavior
temperature_constraint: TemperatureConstraint = field(
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
)
@dataclass
class ModelResponse:
"""Response from a model provider."""
content: str
usage: dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
model_name: str = ""
friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI"
provider: ProviderType = ProviderType.GOOGLE
metadata: dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
@property
def total_tokens(self) -> int:
"""Get total tokens used."""
return self.usage.get("total_tokens", 0)
class ModelProvider(ABC): class ModelProvider(ABC):
"""Abstract base class for model providers.""" """Defines the contract implemented by every model provider backend.
Subclasses adapt third-party SDKs into the MCP server by exposing
capability metadata, request execution, and token counting through a
consistent interface. Shared helper methods (temperature validation,
alias resolution, image handling, etc.) live here so individual providers
only need to focus on provider-specific details.
"""
# All concrete providers must define their supported models # All concrete providers must define their supported models
SUPPORTED_MODELS: dict[str, Any] = {} SUPPORTED_MODELS: dict[str, Any] = {}

View File

@@ -4,15 +4,15 @@ import logging
import os import os
from typing import Optional from typing import Optional
from .base import ( from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import (
FixedTemperatureConstraint, FixedTemperatureConstraint,
ModelCapabilities, ModelCapabilities,
ModelResponse, ModelResponse,
ProviderType, ProviderType,
RangeTemperatureConstraint, RangeTemperatureConstraint,
) )
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
# Temperature inference patterns # Temperature inference patterns
_TEMP_UNSUPPORTED_PATTERNS = [ _TEMP_UNSUPPORTED_PATTERNS = [
@@ -30,10 +30,13 @@ _TEMP_UNSUPPORTED_KEYWORDS = [
class CustomProvider(OpenAICompatibleProvider): class CustomProvider(OpenAICompatibleProvider):
"""Custom API provider for local models. """Adapter for self-hosted or local OpenAI-compatible endpoints.
Supports local inference servers like Ollama, vLLM, LM Studio, The provider reuses the :mod:`providers.shared` registry to surface
and any OpenAI-compatible API endpoint. user-defined aliases and capability metadata. It also normalises
Ollama-style version tags (``model:latest``) and enforces the same
restriction policies used by cloud providers, ensuring consistent
behaviour regardless of where the model is hosted.
""" """
FRIENDLY_NAME = "Custom API" FRIENDLY_NAME = "Custom API"

View File

@@ -6,22 +6,24 @@ import threading
import time import time
from typing import Optional from typing import Optional
from .base import ( from .openai_compatible import OpenAICompatibleProvider
from .shared import (
ModelCapabilities, ModelCapabilities,
ModelResponse, ModelResponse,
ProviderType, ProviderType,
create_temperature_constraint, create_temperature_constraint,
) )
from .openai_compatible import OpenAICompatibleProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DIALModelProvider(OpenAICompatibleProvider): class DIALModelProvider(OpenAICompatibleProvider):
"""DIAL provider using OpenAI-compatible API. """Client for the DIAL (Data & AI Layer) aggregation service.
DIAL provides access to various AI models through a unified API interface. DIAL exposes several third-party models behind a single OpenAI-compatible
Supports GPT, Claude, Gemini, and other models via DIAL deployments. endpoint. This provider wraps the service, publishes capability metadata
for the known deployments, and centralises retry/backoff settings tailored
to DIAL's latency characteristics.
""" """
FRIENDLY_NAME = "DIAL" FRIENDLY_NAME = "DIAL"

View File

@@ -11,13 +11,24 @@ if TYPE_CHECKING:
from google import genai from google import genai
from google.genai import types from google.genai import types
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint from .base import ModelProvider
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
create_temperature_constraint,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GeminiModelProvider(ModelProvider): class GeminiModelProvider(ModelProvider):
"""Google Gemini model provider implementation.""" """First-party Gemini integration built on the official Google SDK.
The provider advertises detailed thinking-mode budgets, handles optional
custom endpoints, and performs image pre-processing before forwarding a
request to the Gemini APIs.
"""
# Model configurations using ModelCapabilities objects # Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = { SUPPORTED_MODELS = {

View File

@@ -11,21 +11,21 @@ from urllib.parse import urlparse
from openai import OpenAI from openai import OpenAI
from .base import ( from .base import ModelProvider
from .shared import (
ModelCapabilities, ModelCapabilities,
ModelProvider,
ModelResponse, ModelResponse,
ProviderType, ProviderType,
) )
class OpenAICompatibleProvider(ModelProvider): class OpenAICompatibleProvider(ModelProvider):
"""Base class for any provider using an OpenAI-compatible API. """Shared implementation for OpenAI API lookalikes.
This includes: The class owns HTTP client configuration (timeouts, proxy hardening,
- Direct OpenAI API custom headers) and normalises the OpenAI SDK responses into
- OpenRouter :class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
- Any other OpenAI-compatible endpoint provide capability metadata and any provider-specific request tweaks.
""" """
DEFAULT_HEADERS = {} DEFAULT_HEADERS = {}

View File

@@ -6,19 +6,24 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from tools.models import ToolModelCategory from tools.models import ToolModelCategory
from .base import ( from .openai_compatible import OpenAICompatibleProvider
from .shared import (
ModelCapabilities, ModelCapabilities,
ModelResponse, ModelResponse,
ProviderType, ProviderType,
create_temperature_constraint, create_temperature_constraint,
) )
from .openai_compatible import OpenAICompatibleProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OpenAIModelProvider(OpenAICompatibleProvider): class OpenAIModelProvider(OpenAICompatibleProvider):
"""Official OpenAI API provider (api.openai.com).""" """Implementation that talks to api.openai.com using rich model metadata.
In addition to the built-in catalogue, the provider can surface models
defined in ``conf/custom_models.json`` (for organisations running their own
OpenAI-compatible gateways) while still respecting restriction policies.
"""
# Model configurations using ModelCapabilities objects # Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = { SUPPORTED_MODELS = {

View File

@@ -4,21 +4,22 @@ import logging
import os import os
from typing import Optional from typing import Optional
from .base import ( from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import (
ModelCapabilities, ModelCapabilities,
ModelResponse, ModelResponse,
ProviderType, ProviderType,
RangeTemperatureConstraint, RangeTemperatureConstraint,
) )
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
class OpenRouterProvider(OpenAICompatibleProvider): class OpenRouterProvider(OpenAICompatibleProvider):
"""OpenRouter unified API provider. """Client for OpenRouter's multi-model aggregation service.
OpenRouter provides access to multiple AI models through a single API endpoint. OpenRouter surfaces dozens of upstream vendors. This provider layers alias
See https://openrouter.ai for available models and pricing. resolution, restriction-aware filtering, and sensible capability defaults
on top of the generic OpenAI-compatible plumbing.
""" """
FRIENDLY_NAME = "OpenRouter" FRIENDLY_NAME = "OpenRouter"

View File

@@ -9,7 +9,7 @@ from typing import Optional
# Import handled via importlib.resources.files() calls directly # Import handled via importlib.resources.files() calls directly
from utils.file_utils import read_json_file from utils.file_utils import read_json_file
from .base import ( from .shared import (
ModelCapabilities, ModelCapabilities,
ProviderType, ProviderType,
create_temperature_constraint, create_temperature_constraint,
@@ -17,7 +17,13 @@ from .base import (
class OpenRouterModelRegistry: class OpenRouterModelRegistry:
"""Registry for managing OpenRouter model configurations and aliases.""" """Loads and validates the OpenRouter/custom model catalogue.
The registry parses ``conf/custom_models.json`` (or an override supplied via
environment variable), builds case-insensitive alias maps, and exposes
:class:`~providers.shared.ModelCapabilities` objects used by several
providers.
"""
def __init__(self, config_path: Optional[str] = None): def __init__(self, config_path: Optional[str] = None):
"""Initialize the registry. """Initialize the registry.
@@ -263,6 +269,11 @@ class OpenRouterModelRegistry:
# Registry now returns ModelCapabilities directly # Registry now returns ModelCapabilities directly
return self.resolve(name_or_alias) return self.resolve(name_or_alias)
def get_model_config(self, name_or_alias: str) -> Optional[ModelCapabilities]:
"""Backward-compatible wrapper used by providers and older tests."""
return self.resolve(name_or_alias)
def list_models(self) -> list[str]: def list_models(self) -> list[str]:
"""List all available model names.""" """List all available model names."""
return list(self.model_map.keys()) return list(self.model_map.keys())

View File

@@ -4,14 +4,20 @@ import logging
import os import os
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from .base import ModelProvider, ProviderType from .base import ModelProvider
from .shared import ProviderType
if TYPE_CHECKING: if TYPE_CHECKING:
from tools.models import ToolModelCategory from tools.models import ToolModelCategory
class ModelProviderRegistry: class ModelProviderRegistry:
"""Registry for managing model providers.""" """Singleton that caches provider instances and coordinates priority order.
Responsibilities include resolving API keys from the environment, lazily
instantiating providers, and choosing the best provider for a model based
on restriction policies and provider priority.
"""
_instance = None _instance = None

View File

@@ -0,0 +1,23 @@
"""Shared data structures and helpers for model providers."""
from .model_capabilities import ModelCapabilities
from .model_response import ModelResponse
from .provider_type import ProviderType
from .temperature import (
DiscreteTemperatureConstraint,
FixedTemperatureConstraint,
RangeTemperatureConstraint,
TemperatureConstraint,
create_temperature_constraint,
)
__all__ = [
"ModelCapabilities",
"ModelResponse",
"ProviderType",
"TemperatureConstraint",
"FixedTemperatureConstraint",
"RangeTemperatureConstraint",
"DiscreteTemperatureConstraint",
"create_temperature_constraint",
]

View File

@@ -0,0 +1,34 @@
"""Dataclass describing the feature set of a model exposed by a provider."""
from dataclasses import dataclass, field
from .provider_type import ProviderType
from .temperature import RangeTemperatureConstraint, TemperatureConstraint
__all__ = ["ModelCapabilities"]
@dataclass
class ModelCapabilities:
"""Static capabilities and constraints for a provider-managed model."""
provider: ProviderType
model_name: str
friendly_name: str
context_window: int
max_output_tokens: int
supports_extended_thinking: bool = False
supports_system_prompts: bool = True
supports_streaming: bool = True
supports_function_calling: bool = False
supports_images: bool = False
max_image_size_mb: float = 0.0
supports_temperature: bool = True
description: str = ""
aliases: list[str] = field(default_factory=list)
supports_json_mode: bool = False
max_thinking_tokens: int = 0
is_custom: bool = False
temperature_constraint: TemperatureConstraint = field(
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
)

View File

@@ -0,0 +1,26 @@
"""Dataclass used to normalise provider SDK responses."""
from dataclasses import dataclass, field
from typing import Any
from .provider_type import ProviderType
__all__ = ["ModelResponse"]
@dataclass
class ModelResponse:
"""Portable representation of a provider completion."""
content: str
usage: dict[str, int] = field(default_factory=dict)
model_name: str = ""
friendly_name: str = ""
provider: ProviderType = ProviderType.GOOGLE
metadata: dict[str, Any] = field(default_factory=dict)
@property
def total_tokens(self) -> int:
"""Return the total token count if the provider reported usage data."""
return self.usage.get("total_tokens", 0)

View File

@@ -0,0 +1,16 @@
"""Enumeration describing which backend owns a given model."""
from enum import Enum
__all__ = ["ProviderType"]
class ProviderType(Enum):
"""Canonical identifiers for every supported provider backend."""
GOOGLE = "google"
OPENAI = "openai"
XAI = "xai"
OPENROUTER = "openrouter"
CUSTOM = "custom"
DIAL = "dial"

View File

@@ -0,0 +1,121 @@
"""Helper types for validating model temperature parameters."""
from abc import ABC, abstractmethod
from typing import Optional
__all__ = [
"TemperatureConstraint",
"FixedTemperatureConstraint",
"RangeTemperatureConstraint",
"DiscreteTemperatureConstraint",
"create_temperature_constraint",
]
class TemperatureConstraint(ABC):
"""Contract for temperature validation used by `ModelCapabilities`.
Concrete providers describe their temperature behaviour by creating
subclasses that expose three operations:
* `validate` decide whether a requested temperature is acceptable.
* `get_corrected_value` coerce out-of-range values into a safe default.
* `get_description` provide a human readable error message for users.
Providers call these hooks before sending traffic to the underlying API so
that unsupported temperatures never reach the remote service.
"""
@abstractmethod
def validate(self, temperature: float) -> bool:
"""Return ``True`` when the temperature may be sent to the backend."""
@abstractmethod
def get_corrected_value(self, temperature: float) -> float:
"""Return a valid substitute for an out-of-range temperature."""
@abstractmethod
def get_description(self) -> str:
"""Describe the acceptable range to include in error messages."""
@abstractmethod
def get_default(self) -> float:
"""Return the default temperature for the model."""
class FixedTemperatureConstraint(TemperatureConstraint):
"""Constraint for models that enforce an exact temperature (for example O3)."""
def __init__(self, value: float):
self.value = value
def validate(self, temperature: float) -> bool:
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
def get_corrected_value(self, temperature: float) -> float:
return self.value
def get_description(self) -> str:
return f"Only supports temperature={self.value}"
def get_default(self) -> float:
return self.value
class RangeTemperatureConstraint(TemperatureConstraint):
"""Constraint for providers that expose a continuous min/max temperature range."""
def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None):
self.min_temp = min_temp
self.max_temp = max_temp
self.default_temp = default or (min_temp + max_temp) / 2
def validate(self, temperature: float) -> bool:
return self.min_temp <= temperature <= self.max_temp
def get_corrected_value(self, temperature: float) -> float:
return max(self.min_temp, min(self.max_temp, temperature))
def get_description(self) -> str:
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
def get_default(self) -> float:
return self.default_temp
class DiscreteTemperatureConstraint(TemperatureConstraint):
"""Constraint for models that permit a discrete list of temperature values."""
def __init__(self, allowed_values: list[float], default: Optional[float] = None):
self.allowed_values = sorted(allowed_values)
self.default_temp = default or allowed_values[len(allowed_values) // 2]
def validate(self, temperature: float) -> bool:
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
def get_corrected_value(self, temperature: float) -> float:
return min(self.allowed_values, key=lambda x: abs(x - temperature))
def get_description(self) -> str:
return f"Supports temperatures: {self.allowed_values}"
def get_default(self) -> float:
return self.default_temp
def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
"""Factory that yields the appropriate constraint for a model configuration.
The JSON configuration stored in ``conf/custom_models.json`` references this
helper via human-readable strings. Providers feed those values into this
function so that runtime logic can rely on strongly typed constraint
objects.
"""
if constraint_type == "fixed":
# Fixed temperature models (O3/O4) only support temperature=1.0
return FixedTemperatureConstraint(1.0)
if constraint_type == "discrete":
# For models with specific allowed values - using common OpenAI values as default
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
# Default range constraint (for "range" or None)
return RangeTemperatureConstraint(0.0, 2.0, 0.3)

View File

@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from tools.models import ToolModelCategory from tools.models import ToolModelCategory
from .base import ( from .openai_compatible import OpenAICompatibleProvider
from .shared import (
ModelCapabilities, ModelCapabilities,
ModelResponse, ModelResponse,
ProviderType, ProviderType,
create_temperature_constraint, create_temperature_constraint,
) )
from .openai_compatible import OpenAICompatibleProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class XAIModelProvider(OpenAICompatibleProvider): class XAIModelProvider(OpenAICompatibleProvider):
"""X.AI GROK API provider (api.x.ai).""" """Integration for X.AI's GROK models exposed over an OpenAI-style API.
Publishes capability metadata for the officially supported deployments and
maps tool-category preferences to the appropriate GROK model.
"""
FRIENDLY_NAME = "X.AI" FRIENDLY_NAME = "X.AI"

View File

@@ -412,12 +412,12 @@ def configure_providers():
value = os.getenv(key) value = os.getenv(key)
logger.debug(f" {key}: {'[PRESENT]' if value else '[MISSING]'}") logger.debug(f" {key}: {'[PRESENT]' if value else '[MISSING]'}")
from providers import ModelProviderRegistry from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.custom import CustomProvider from providers.custom import CustomProvider
from providers.dial import DIALModelProvider from providers.dial import DIALModelProvider
from providers.gemini import GeminiModelProvider from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider from providers.openai_provider import OpenAIModelProvider
from providers.openrouter import OpenRouterProvider from providers.openrouter import OpenRouterProvider
from providers.shared import ProviderType
from providers.xai import XAIModelProvider from providers.xai import XAIModelProvider
from utils.model_restrictions import get_restriction_service from utils.model_restrictions import get_restriction_service

View File

@@ -34,9 +34,9 @@ if sys.platform == "win32":
# Register providers for all tests # Register providers for all tests
from providers import ModelProviderRegistry # noqa: E402 from providers import ModelProviderRegistry # noqa: E402
from providers.base import ProviderType # noqa: E402
from providers.gemini import GeminiModelProvider # noqa: E402 from providers.gemini import GeminiModelProvider # noqa: E402
from providers.openai_provider import OpenAIModelProvider # noqa: E402 from providers.openai_provider import OpenAIModelProvider # noqa: E402
from providers.shared import ProviderType # noqa: E402
from providers.xai import XAIModelProvider # noqa: E402 from providers.xai import XAIModelProvider # noqa: E402
# Register providers at test startup # Register providers at test startup
@@ -109,7 +109,7 @@ def mock_provider_availability(request, monkeypatch):
return return
# Ensure providers are registered (in case other tests cleared the registry) # Ensure providers are registered (in case other tests cleared the registry)
from providers.base import ProviderType from providers.shared import ProviderType
registry = ModelProviderRegistry() registry = ModelProviderRegistry()
@@ -197,3 +197,19 @@ def mock_provider_availability(request, monkeypatch):
return False return False
monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode) monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
@pytest.fixture(autouse=True)
def clear_model_restriction_env(monkeypatch):
"""Ensure per-test isolation from user-defined model restriction env vars."""
restriction_vars = [
"OPENAI_ALLOWED_MODELS",
"GOOGLE_ALLOWED_MODELS",
"XAI_ALLOWED_MODELS",
"OPENROUTER_ALLOWED_MODELS",
"DIAL_ALLOWED_MODELS",
]
for var in restriction_vars:
monkeypatch.delenv(var, raising=False)

View File

@@ -2,7 +2,7 @@
from unittest.mock import Mock from unittest.mock import Mock
from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint from providers.shared import ModelCapabilities, ProviderType, RangeTemperatureConstraint
def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576): def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576):

View File

@@ -8,9 +8,9 @@ both alias names and their target models, preventing policy bypass vulnerabiliti
import os import os
from unittest.mock import patch from unittest.mock import patch
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService from utils.model_restrictions import ModelRestrictionService

View File

@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tools.analyze import AnalyzeTool from tools.analyze import AnalyzeTool
from tools.chat import ChatTool from tools.chat import ChatTool
from tools.debug import DebugIssueTool from tools.debug import DebugIssueTool

View File

@@ -6,8 +6,8 @@ from unittest.mock import patch
import pytest import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
@pytest.mark.no_mock_provider @pytest.mark.no_mock_provider

View File

@@ -4,8 +4,8 @@ import os
import pytest import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tools.models import ToolModelCategory from tools.models import ToolModelCategory

View File

@@ -14,9 +14,9 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService from utils.model_restrictions import ModelRestrictionService

View File

@@ -86,7 +86,7 @@ class TestCustomOpenAITemperatureParameterFix:
mock_registry_class.return_value = mock_registry mock_registry_class.return_value = mock_registry
# Mock get_model_config to return our test model # Mock get_model_config to return our test model
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
test_capabilities = ModelCapabilities( test_capabilities = ModelCapabilities(
provider=ProviderType.OPENAI, provider=ProviderType.OPENAI,
@@ -170,7 +170,7 @@ class TestCustomOpenAITemperatureParameterFix:
mock_registry_class.return_value = mock_registry mock_registry_class.return_value = mock_registry
# Mock get_model_config to return a model that supports temperature # Mock get_model_config to return a model that supports temperature
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
test_capabilities = ModelCapabilities( test_capabilities = ModelCapabilities(
provider=ProviderType.OPENAI, provider=ProviderType.OPENAI,
@@ -227,7 +227,7 @@ class TestCustomOpenAITemperatureParameterFix:
mock_registry = Mock() mock_registry = Mock()
mock_registry_class.return_value = mock_registry mock_registry_class.return_value = mock_registry
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
test_capabilities = ModelCapabilities( test_capabilities = ModelCapabilities(
provider=ProviderType.OPENAI, provider=ProviderType.OPENAI,

View File

@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from providers import ModelProviderRegistry from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.custom import CustomProvider from providers.custom import CustomProvider
from providers.shared import ProviderType
class TestCustomProvider: class TestCustomProvider:

View File

@@ -5,8 +5,8 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from providers.base import ProviderType
from providers.dial import DIALModelProvider from providers.dial import DIALModelProvider
from providers.shared import ProviderType
class TestDIALProvider: class TestDIALProvider:

View File

@@ -8,7 +8,8 @@ from unittest.mock import Mock, patch
import pytest import pytest
from providers.base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType from providers.base import ModelProvider
from providers.shared import ModelCapabilities, ModelResponse, ProviderType
class MinimalTestProvider(ModelProvider): class MinimalTestProvider(ModelProvider):

View File

@@ -9,8 +9,8 @@ from unittest.mock import Mock, patch
import pytest import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
class TestIntelligentFallback: class TestIntelligentFallback:

View File

@@ -41,7 +41,7 @@ def test_issue_245_custom_openai_temperature_ignored():
mock_registry = Mock() mock_registry = Mock()
mock_registry_class.return_value = mock_registry mock_registry_class.return_value = mock_registry
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
# This is what the user configured in their custom_models.json # This is what the user configured in their custom_models.json
custom_config = ModelCapabilities( custom_config = ModelCapabilities(

View File

@@ -5,8 +5,9 @@ import os
import unittest import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from providers.base import ModelProvider, ProviderType from providers.base import ModelProvider
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tools.listmodels import ListModelsTool from tools.listmodels import ListModelsTool

View File

@@ -214,7 +214,7 @@ class TestModelEnumeration:
# Rebuild the provider registry with OpenRouter registered # Rebuild the provider registry with OpenRouter registered
ModelProviderRegistry._instance = None ModelProviderRegistry._instance = None
from providers.base import ProviderType from providers.shared import ProviderType
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider) ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

View File

@@ -9,8 +9,8 @@ This test specifically targets the bug where:
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from providers.base import ProviderType
from providers.openrouter import OpenRouterProvider from providers.openrouter import OpenRouterProvider
from providers.shared import ProviderType
from tools.consensus import ConsensusTool from tools.consensus import ConsensusTool

View File

@@ -5,9 +5,9 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService from utils.model_restrictions import ModelRestrictionService

View File

@@ -10,7 +10,7 @@ They prove that our fix was necessary and actually addresses real problems.
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from providers.base import ProviderType from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService from utils.model_restrictions import ModelRestrictionService

View File

@@ -3,8 +3,8 @@
import os import os
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from providers.base import ProviderType
from providers.openai_provider import OpenAIModelProvider from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
class TestOpenAIProvider: class TestOpenAIProvider:

View File

@@ -5,9 +5,9 @@ from unittest.mock import Mock, patch
import pytest import pytest
from providers.base import ProviderType
from providers.openrouter import OpenRouterProvider from providers.openrouter import OpenRouterProvider
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
class TestOpenRouterProvider: class TestOpenRouterProvider:

View File

@@ -6,8 +6,8 @@ import tempfile
import pytest import pytest
from providers.base import ModelCapabilities, ProviderType
from providers.openrouter_registry import OpenRouterModelRegistry from providers.openrouter_registry import OpenRouterModelRegistry
from providers.shared import ModelCapabilities, ProviderType
class TestOpenRouterModelRegistry: class TestOpenRouterModelRegistry:
@@ -213,7 +213,7 @@ class TestOpenRouterModelRegistry:
def test_model_with_all_capabilities(self): def test_model_with_all_capabilities(self):
"""Test model with all capability flags.""" """Test model with all capability flags."""
from providers.base import create_temperature_constraint from providers.shared import create_temperature_constraint
caps = ModelCapabilities( caps = ModelCapabilities(
provider=ProviderType.OPENROUTER, provider=ProviderType.OPENROUTER,

View File

@@ -13,8 +13,8 @@ from unittest.mock import Mock
import pytest import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tools.chat import ChatTool from tools.chat import ChatTool
from tools.shared.base_models import ToolRequest from tools.shared.base_models import ToolRequest

View File

@@ -10,9 +10,9 @@ from unittest.mock import Mock, patch
import pytest import pytest
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
class TestProviderUTF8Encoding(unittest.TestCase): class TestProviderUTF8Encoding(unittest.TestCase):
@@ -177,7 +177,7 @@ class TestProviderUTF8Encoding(unittest.TestCase):
def test_model_response_utf8_serialization(self): def test_model_response_utf8_serialization(self):
"""Test UTF-8 serialization of model responses.""" """Test UTF-8 serialization of model responses."""
from providers.base import ModelResponse from providers.shared import ModelResponse
response = ModelResponse( response = ModelResponse(
content="Development successful! Code generated successfully. 🎉✅", content="Development successful! Code generated successfully. 🎉✅",

View File

@@ -6,9 +6,9 @@ from unittest.mock import Mock, patch
import pytest import pytest
from providers import ModelProviderRegistry, ModelResponse from providers import ModelProviderRegistry, ModelResponse
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
class TestModelProviderRegistry: class TestModelProviderRegistry:

View File

@@ -185,7 +185,7 @@ class TestSupportedModelsAliases:
for provider in providers: for provider in providers:
for model_name, config in provider.SUPPORTED_MODELS.items(): for model_name, config in provider.SUPPORTED_MODELS.items():
# All values must be ModelCapabilities objects, not strings or dicts # All values must be ModelCapabilities objects, not strings or dicts
from providers.base import ModelCapabilities from providers.shared import ModelCapabilities
assert isinstance(config, ModelCapabilities), ( assert isinstance(config, ModelCapabilities), (
f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] " f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "

View File

@@ -10,8 +10,8 @@ import os
import pytest import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tools.debug import DebugIssueTool from tools.debug import DebugIssueTool

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from providers.base import ProviderType from providers.shared import ProviderType
from providers.xai import XAIModelProvider from providers.xai import XAIModelProvider
@@ -265,7 +265,7 @@ class TestXAIProvider:
assert "grok-3-fast" in provider.SUPPORTED_MODELS assert "grok-3-fast" in provider.SUPPORTED_MODELS
# Check model configs have required fields # Check model configs have required fields
from providers.base import ModelCapabilities from providers.shared import ModelCapabilities
grok4_config = provider.SUPPORTED_MODELS["grok-4"] grok4_config = provider.SUPPORTED_MODELS["grok-4"]
assert isinstance(grok4_config, ModelCapabilities) assert isinstance(grok4_config, ModelCapabilities)

View File

@@ -22,9 +22,9 @@ def inject_transport(monkeypatch, cassette_path: str):
transport = inject_transport(monkeypatch, "path/to/cassette.json") transport = inject_transport(monkeypatch, "path/to/cassette.json")
""" """
# Ensure OpenAI provider is registered - always needed for transport injection # Ensure OpenAI provider is registered - always needed for transport injection
from providers.base import ProviderType
from providers.openai_provider import OpenAIModelProvider from providers.openai_provider import OpenAIModelProvider
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
# Always register OpenAI provider for transport tests (API key might be dummy) # Always register OpenAI provider for transport tests (API key might be dummy)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider) ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

View File

@@ -79,9 +79,9 @@ class ListModelsTool(BaseTool):
Returns: Returns:
Formatted list of models by provider Formatted list of models by provider
""" """
from providers.base import ProviderType
from providers.openrouter_registry import OpenRouterModelRegistry from providers.openrouter_registry import OpenRouterModelRegistry
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
output_lines = ["# Available AI Models\n"] output_lines = ["# Available AI Models\n"]
@@ -162,8 +162,8 @@ class ListModelsTool(BaseTool):
try: try:
# Get OpenRouter provider from registry to properly apply restrictions # Get OpenRouter provider from registry to properly apply restrictions
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER) provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
if provider: if provider:

View File

@@ -1341,7 +1341,7 @@ When recommending searches, be specific about what information you need and why
# Apply 40MB cap for custom models if needed # Apply 40MB cap for custom models if needed
effective_limit_mb = max_size_mb effective_limit_mb = max_size_mb
try: try:
from providers.base import ProviderType from providers.shared import ProviderType
# ModelCapabilities dataclass has provider field defined # ModelCapabilities dataclass has provider field defined
if capabilities.provider == ProviderType.CUSTOM: if capabilities.provider == ProviderType.CUSTOM:

View File

@@ -306,8 +306,8 @@ class VersionTool(BaseTool):
# Check for configured providers # Check for configured providers
try: try:
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
provider_status = [] provider_status = []

View File

@@ -24,7 +24,7 @@ import logging
import os import os
from typing import Optional from typing import Optional
from providers.base import ProviderType from providers.shared import ProviderType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)