refactor: code cleanup

This commit is contained in:
Fahad
2025-10-02 08:09:44 +04:00
parent 218fbdf49c
commit 182aa627df
49 changed files with 370 additions and 249 deletions

View File

@@ -1,11 +1,12 @@
"""Model provider abstractions for supporting multiple AI providers."""
from .base import ModelCapabilities, ModelProvider, ModelResponse
from .base import ModelProvider
from .gemini import GeminiModelProvider
from .openai_compatible import OpenAICompatibleProvider
from .openai_provider import OpenAIModelProvider
from .openrouter import OpenRouterProvider
from .registry import ModelProviderRegistry
from .shared import ModelCapabilities, ModelResponse
__all__ = [
"ModelProvider",

View File

@@ -1,12 +1,10 @@
"""Base model provider interface and data classes."""
"""Base interfaces and common behaviour for model providers."""
import base64
import binascii
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
@@ -14,179 +12,20 @@ if TYPE_CHECKING:
from utils.file_types import IMAGES, get_image_mime_type
from .shared import ModelCapabilities, ModelResponse, ProviderType
logger = logging.getLogger(__name__)
class ProviderType(Enum):
"""Supported model provider types."""
GOOGLE = "google"
OPENAI = "openai"
XAI = "xai"
OPENROUTER = "openrouter"
CUSTOM = "custom"
DIAL = "dial"
class TemperatureConstraint(ABC):
"""Abstract base class for temperature constraints."""
@abstractmethod
def validate(self, temperature: float) -> bool:
"""Check if temperature is valid."""
pass
@abstractmethod
def get_corrected_value(self, temperature: float) -> float:
"""Get nearest valid temperature."""
pass
@abstractmethod
def get_description(self) -> str:
"""Get human-readable description of constraint."""
pass
@abstractmethod
def get_default(self) -> float:
"""Get model's default temperature."""
pass
class FixedTemperatureConstraint(TemperatureConstraint):
"""For models that only support one temperature value (e.g., O3)."""
def __init__(self, value: float):
self.value = value
def validate(self, temperature: float) -> bool:
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
def get_corrected_value(self, temperature: float) -> float:
return self.value
def get_description(self) -> str:
return f"Only supports temperature={self.value}"
def get_default(self) -> float:
return self.value
class RangeTemperatureConstraint(TemperatureConstraint):
"""For models supporting continuous temperature ranges."""
def __init__(self, min_temp: float, max_temp: float, default: float = None):
self.min_temp = min_temp
self.max_temp = max_temp
self.default_temp = default or (min_temp + max_temp) / 2
def validate(self, temperature: float) -> bool:
return self.min_temp <= temperature <= self.max_temp
def get_corrected_value(self, temperature: float) -> float:
return max(self.min_temp, min(self.max_temp, temperature))
def get_description(self) -> str:
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
def get_default(self) -> float:
return self.default_temp
class DiscreteTemperatureConstraint(TemperatureConstraint):
"""For models supporting only specific temperature values."""
def __init__(self, allowed_values: list[float], default: float = None):
self.allowed_values = sorted(allowed_values)
self.default_temp = default or allowed_values[len(allowed_values) // 2]
def validate(self, temperature: float) -> bool:
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
def get_corrected_value(self, temperature: float) -> float:
return min(self.allowed_values, key=lambda x: abs(x - temperature))
def get_description(self) -> str:
return f"Supports temperatures: {self.allowed_values}"
def get_default(self) -> float:
return self.default_temp
def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
"""Create temperature constraint object from configuration string.
Args:
constraint_type: Type of constraint ("fixed", "range", "discrete")
Returns:
TemperatureConstraint object based on configuration
"""
if constraint_type == "fixed":
# Fixed temperature models (O3/O4) only support temperature=1.0
return FixedTemperatureConstraint(1.0)
elif constraint_type == "discrete":
# For models with specific allowed values - using common OpenAI values as default
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
else:
# Default range constraint (for "range" or None)
return RangeTemperatureConstraint(0.0, 2.0, 0.3)
@dataclass
class ModelCapabilities:
"""Capabilities and constraints for a specific model."""
provider: ProviderType
model_name: str
friendly_name: str # Human-friendly name like "Gemini" or "OpenAI"
context_window: int # Total context window size in tokens
max_output_tokens: int # Maximum output tokens per request
supports_extended_thinking: bool = False
supports_system_prompts: bool = True
supports_streaming: bool = True
supports_function_calling: bool = False
supports_images: bool = False # Whether model can process images
max_image_size_mb: float = 0.0 # Maximum total size for all images in MB
supports_temperature: bool = True # Whether model accepts temperature parameter in API calls
# Additional fields for comprehensive model information
description: str = "" # Human-readable description of the model
aliases: list[str] = field(default_factory=list) # Alternative names/shortcuts for the model
# JSON mode support (for providers that support structured output)
supports_json_mode: bool = False
# Thinking mode support (for models with thinking capabilities)
max_thinking_tokens: int = 0 # Maximum thinking tokens for extended reasoning models
# Custom model flag (for models that only work with custom endpoints)
is_custom: bool = False # Whether this model requires custom API endpoints
# Temperature constraint object - defines temperature limits and behavior
temperature_constraint: TemperatureConstraint = field(
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
)
@dataclass
class ModelResponse:
"""Response from a model provider."""
content: str
usage: dict[str, int] = field(default_factory=dict) # input_tokens, output_tokens, total_tokens
model_name: str = ""
friendly_name: str = "" # Human-friendly name like "Gemini" or "OpenAI"
provider: ProviderType = ProviderType.GOOGLE
metadata: dict[str, Any] = field(default_factory=dict) # Provider-specific metadata
@property
def total_tokens(self) -> int:
"""Get total tokens used."""
return self.usage.get("total_tokens", 0)
class ModelProvider(ABC):
"""Abstract base class for model providers."""
"""Defines the contract implemented by every model provider backend.
Subclasses adapt third-party SDKs into the MCP server by exposing
capability metadata, request execution, and token counting through a
consistent interface. Shared helper methods (temperature validation,
alias resolution, image handling, etc.) live here so individual providers
only need to focus on provider-specific details.
"""
# All concrete providers must define their supported models
SUPPORTED_MODELS: dict[str, Any] = {}

View File

@@ -4,15 +4,15 @@ import logging
import os
from typing import Optional
from .base import (
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import (
FixedTemperatureConstraint,
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
# Temperature inference patterns
_TEMP_UNSUPPORTED_PATTERNS = [
@@ -30,10 +30,13 @@ _TEMP_UNSUPPORTED_KEYWORDS = [
class CustomProvider(OpenAICompatibleProvider):
"""Custom API provider for local models.
"""Adapter for self-hosted or local OpenAI-compatible endpoints.
Supports local inference servers like Ollama, vLLM, LM Studio,
and any OpenAI-compatible API endpoint.
The provider reuses the :mod:`providers.shared` registry to surface
user-defined aliases and capability metadata. It also normalises
Ollama-style version tags (``model:latest``) and enforces the same
restriction policies used by cloud providers, ensuring consistent
behaviour regardless of where the model is hosted.
"""
FRIENDLY_NAME = "Custom API"

View File

@@ -6,22 +6,24 @@ import threading
import time
from typing import Optional
from .base import (
from .openai_compatible import OpenAICompatibleProvider
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
create_temperature_constraint,
)
from .openai_compatible import OpenAICompatibleProvider
logger = logging.getLogger(__name__)
class DIALModelProvider(OpenAICompatibleProvider):
"""DIAL provider using OpenAI-compatible API.
"""Client for the DIAL (Data & AI Layer) aggregation service.
DIAL provides access to various AI models through a unified API interface.
Supports GPT, Claude, Gemini, and other models via DIAL deployments.
DIAL exposes several third-party models behind a single OpenAI-compatible
endpoint. This provider wraps the service, publishes capability metadata
for the known deployments, and centralises retry/backoff settings tailored
to DIAL's latency characteristics.
"""
FRIENDLY_NAME = "DIAL"

View File

@@ -11,13 +11,24 @@ if TYPE_CHECKING:
from google import genai
from google.genai import types
from .base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType, create_temperature_constraint
from .base import ModelProvider
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
create_temperature_constraint,
)
logger = logging.getLogger(__name__)
class GeminiModelProvider(ModelProvider):
"""Google Gemini model provider implementation."""
"""First-party Gemini integration built on the official Google SDK.
The provider advertises detailed thinking-mode budgets, handles optional
custom endpoints, and performs image pre-processing before forwarding a
request to the Gemini APIs.
"""
# Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = {

View File

@@ -11,21 +11,21 @@ from urllib.parse import urlparse
from openai import OpenAI
from .base import (
from .base import ModelProvider
from .shared import (
ModelCapabilities,
ModelProvider,
ModelResponse,
ProviderType,
)
class OpenAICompatibleProvider(ModelProvider):
"""Base class for any provider using an OpenAI-compatible API.
"""Shared implementation for OpenAI API lookalikes.
This includes:
- Direct OpenAI API
- OpenRouter
- Any other OpenAI-compatible endpoint
The class owns HTTP client configuration (timeouts, proxy hardening,
custom headers) and normalises the OpenAI SDK responses into
:class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
provide capability metadata and any provider-specific request tweaks.
"""
DEFAULT_HEADERS = {}

View File

@@ -6,19 +6,24 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .base import (
from .openai_compatible import OpenAICompatibleProvider
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
create_temperature_constraint,
)
from .openai_compatible import OpenAICompatibleProvider
logger = logging.getLogger(__name__)
class OpenAIModelProvider(OpenAICompatibleProvider):
"""Official OpenAI API provider (api.openai.com)."""
"""Implementation that talks to api.openai.com using rich model metadata.
In addition to the built-in catalogue, the provider can surface models
defined in ``conf/custom_models.json`` (for organisations running their own
OpenAI-compatible gateways) while still respecting restriction policies.
"""
# Model configurations using ModelCapabilities objects
SUPPORTED_MODELS = {

View File

@@ -4,21 +4,22 @@ import logging
import os
from typing import Optional
from .base import (
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
class OpenRouterProvider(OpenAICompatibleProvider):
"""OpenRouter unified API provider.
"""Client for OpenRouter's multi-model aggregation service.
OpenRouter provides access to multiple AI models through a single API endpoint.
See https://openrouter.ai for available models and pricing.
OpenRouter surfaces dozens of upstream vendors. This provider layers alias
resolution, restriction-aware filtering, and sensible capability defaults
on top of the generic OpenAI-compatible plumbing.
"""
FRIENDLY_NAME = "OpenRouter"

View File

@@ -9,7 +9,7 @@ from typing import Optional
# Import handled via importlib.resources.files() calls directly
from utils.file_utils import read_json_file
from .base import (
from .shared import (
ModelCapabilities,
ProviderType,
create_temperature_constraint,
@@ -17,7 +17,13 @@ from .base import (
class OpenRouterModelRegistry:
"""Registry for managing OpenRouter model configurations and aliases."""
"""Loads and validates the OpenRouter/custom model catalogue.
The registry parses ``conf/custom_models.json`` (or an override supplied via
environment variable), builds case-insensitive alias maps, and exposes
:class:`~providers.shared.ModelCapabilities` objects used by several
providers.
"""
def __init__(self, config_path: Optional[str] = None):
"""Initialize the registry.
@@ -263,6 +269,11 @@ class OpenRouterModelRegistry:
# Registry now returns ModelCapabilities directly
return self.resolve(name_or_alias)
def get_model_config(self, name_or_alias: str) -> Optional[ModelCapabilities]:
"""Backward-compatible wrapper used by providers and older tests."""
return self.resolve(name_or_alias)
def list_models(self) -> list[str]:
"""List all available model names."""
return list(self.model_map.keys())

View File

@@ -4,14 +4,20 @@ import logging
import os
from typing import TYPE_CHECKING, Optional
from .base import ModelProvider, ProviderType
from .base import ModelProvider
from .shared import ProviderType
if TYPE_CHECKING:
from tools.models import ToolModelCategory
class ModelProviderRegistry:
"""Registry for managing model providers."""
"""Singleton that caches provider instances and coordinates priority order.
Responsibilities include resolving API keys from the environment, lazily
instantiating providers, and choosing the best provider for a model based
on restriction policies and provider priority.
"""
_instance = None

View File

@@ -0,0 +1,23 @@
"""Shared data structures and helpers for model providers."""
from .model_capabilities import ModelCapabilities
from .model_response import ModelResponse
from .provider_type import ProviderType
from .temperature import (
DiscreteTemperatureConstraint,
FixedTemperatureConstraint,
RangeTemperatureConstraint,
TemperatureConstraint,
create_temperature_constraint,
)
__all__ = [
"ModelCapabilities",
"ModelResponse",
"ProviderType",
"TemperatureConstraint",
"FixedTemperatureConstraint",
"RangeTemperatureConstraint",
"DiscreteTemperatureConstraint",
"create_temperature_constraint",
]

View File

@@ -0,0 +1,34 @@
"""Dataclass describing the feature set of a model exposed by a provider."""
from dataclasses import dataclass, field
from .provider_type import ProviderType
from .temperature import RangeTemperatureConstraint, TemperatureConstraint
__all__ = ["ModelCapabilities"]
@dataclass
class ModelCapabilities:
"""Static capabilities and constraints for a provider-managed model."""
provider: ProviderType
model_name: str
friendly_name: str
context_window: int
max_output_tokens: int
supports_extended_thinking: bool = False
supports_system_prompts: bool = True
supports_streaming: bool = True
supports_function_calling: bool = False
supports_images: bool = False
max_image_size_mb: float = 0.0
supports_temperature: bool = True
description: str = ""
aliases: list[str] = field(default_factory=list)
supports_json_mode: bool = False
max_thinking_tokens: int = 0
is_custom: bool = False
temperature_constraint: TemperatureConstraint = field(
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
)

View File

@@ -0,0 +1,26 @@
"""Dataclass used to normalise provider SDK responses."""
from dataclasses import dataclass, field
from typing import Any
from .provider_type import ProviderType
__all__ = ["ModelResponse"]
@dataclass
class ModelResponse:
"""Portable representation of a provider completion."""
content: str
usage: dict[str, int] = field(default_factory=dict)
model_name: str = ""
friendly_name: str = ""
provider: ProviderType = ProviderType.GOOGLE
metadata: dict[str, Any] = field(default_factory=dict)
@property
def total_tokens(self) -> int:
"""Return the total token count if the provider reported usage data."""
return self.usage.get("total_tokens", 0)

View File

@@ -0,0 +1,16 @@
"""Enumeration describing which backend owns a given model."""
from enum import Enum
__all__ = ["ProviderType"]
class ProviderType(Enum):
"""Canonical identifiers for every supported provider backend."""
GOOGLE = "google"
OPENAI = "openai"
XAI = "xai"
OPENROUTER = "openrouter"
CUSTOM = "custom"
DIAL = "dial"

View File

@@ -0,0 +1,121 @@
"""Helper types for validating model temperature parameters."""
from abc import ABC, abstractmethod
from typing import Optional
__all__ = [
"TemperatureConstraint",
"FixedTemperatureConstraint",
"RangeTemperatureConstraint",
"DiscreteTemperatureConstraint",
"create_temperature_constraint",
]
class TemperatureConstraint(ABC):
"""Contract for temperature validation used by `ModelCapabilities`.
Concrete providers describe their temperature behaviour by creating
subclasses that expose three operations:
* `validate` decide whether a requested temperature is acceptable.
* `get_corrected_value` coerce out-of-range values into a safe default.
* `get_description` provide a human readable error message for users.
Providers call these hooks before sending traffic to the underlying API so
that unsupported temperatures never reach the remote service.
"""
@abstractmethod
def validate(self, temperature: float) -> bool:
"""Return ``True`` when the temperature may be sent to the backend."""
@abstractmethod
def get_corrected_value(self, temperature: float) -> float:
"""Return a valid substitute for an out-of-range temperature."""
@abstractmethod
def get_description(self) -> str:
"""Describe the acceptable range to include in error messages."""
@abstractmethod
def get_default(self) -> float:
"""Return the default temperature for the model."""
class FixedTemperatureConstraint(TemperatureConstraint):
"""Constraint for models that enforce an exact temperature (for example O3)."""
def __init__(self, value: float):
self.value = value
def validate(self, temperature: float) -> bool:
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
def get_corrected_value(self, temperature: float) -> float:
return self.value
def get_description(self) -> str:
return f"Only supports temperature={self.value}"
def get_default(self) -> float:
return self.value
class RangeTemperatureConstraint(TemperatureConstraint):
"""Constraint for providers that expose a continuous min/max temperature range."""
def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None):
self.min_temp = min_temp
self.max_temp = max_temp
self.default_temp = default or (min_temp + max_temp) / 2
def validate(self, temperature: float) -> bool:
return self.min_temp <= temperature <= self.max_temp
def get_corrected_value(self, temperature: float) -> float:
return max(self.min_temp, min(self.max_temp, temperature))
def get_description(self) -> str:
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
def get_default(self) -> float:
return self.default_temp
class DiscreteTemperatureConstraint(TemperatureConstraint):
"""Constraint for models that permit a discrete list of temperature values."""
def __init__(self, allowed_values: list[float], default: Optional[float] = None):
self.allowed_values = sorted(allowed_values)
self.default_temp = default or allowed_values[len(allowed_values) // 2]
def validate(self, temperature: float) -> bool:
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
def get_corrected_value(self, temperature: float) -> float:
return min(self.allowed_values, key=lambda x: abs(x - temperature))
def get_description(self) -> str:
return f"Supports temperatures: {self.allowed_values}"
def get_default(self) -> float:
return self.default_temp
def create_temperature_constraint(constraint_type: str) -> TemperatureConstraint:
"""Factory that yields the appropriate constraint for a model configuration.
The JSON configuration stored in ``conf/custom_models.json`` references this
helper via human-readable strings. Providers feed those values into this
function so that runtime logic can rely on strongly typed constraint
objects.
"""
if constraint_type == "fixed":
# Fixed temperature models (O3/O4) only support temperature=1.0
return FixedTemperatureConstraint(1.0)
if constraint_type == "discrete":
# For models with specific allowed values - using common OpenAI values as default
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
# Default range constraint (for "range" or None)
return RangeTemperatureConstraint(0.0, 2.0, 0.3)

View File

@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .base import (
from .openai_compatible import OpenAICompatibleProvider
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
create_temperature_constraint,
)
from .openai_compatible import OpenAICompatibleProvider
logger = logging.getLogger(__name__)
class XAIModelProvider(OpenAICompatibleProvider):
"""X.AI GROK API provider (api.x.ai)."""
"""Integration for X.AI's GROK models exposed over an OpenAI-style API.
Publishes capability metadata for the officially supported deployments and
maps tool-category preferences to the appropriate GROK model.
"""
FRIENDLY_NAME = "X.AI"

View File

@@ -412,12 +412,12 @@ def configure_providers():
value = os.getenv(key)
logger.debug(f" {key}: {'[PRESENT]' if value else '[MISSING]'}")
from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.custom import CustomProvider
from providers.dial import DIALModelProvider
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.openrouter import OpenRouterProvider
from providers.shared import ProviderType
from providers.xai import XAIModelProvider
from utils.model_restrictions import get_restriction_service

View File

@@ -34,9 +34,9 @@ if sys.platform == "win32":
# Register providers for all tests
from providers import ModelProviderRegistry # noqa: E402
from providers.base import ProviderType # noqa: E402
from providers.gemini import GeminiModelProvider # noqa: E402
from providers.openai_provider import OpenAIModelProvider # noqa: E402
from providers.shared import ProviderType # noqa: E402
from providers.xai import XAIModelProvider # noqa: E402
# Register providers at test startup
@@ -109,7 +109,7 @@ def mock_provider_availability(request, monkeypatch):
return
# Ensure providers are registered (in case other tests cleared the registry)
from providers.base import ProviderType
from providers.shared import ProviderType
registry = ModelProviderRegistry()
@@ -197,3 +197,19 @@ def mock_provider_availability(request, monkeypatch):
return False
monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
@pytest.fixture(autouse=True)
def clear_model_restriction_env(monkeypatch):
"""Ensure per-test isolation from user-defined model restriction env vars."""
restriction_vars = [
"OPENAI_ALLOWED_MODELS",
"GOOGLE_ALLOWED_MODELS",
"XAI_ALLOWED_MODELS",
"OPENROUTER_ALLOWED_MODELS",
"DIAL_ALLOWED_MODELS",
]
for var in restriction_vars:
monkeypatch.delenv(var, raising=False)

View File

@@ -2,7 +2,7 @@
from unittest.mock import Mock
from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
from providers.shared import ModelCapabilities, ProviderType, RangeTemperatureConstraint
def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576):

View File

@@ -8,9 +8,9 @@ both alias names and their target models, preventing policy bypass vulnerabiliti
import os
from unittest.mock import patch
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService

View File

@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tools.analyze import AnalyzeTool
from tools.chat import ChatTool
from tools.debug import DebugIssueTool

View File

@@ -6,8 +6,8 @@ from unittest.mock import patch
import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
@pytest.mark.no_mock_provider

View File

@@ -4,8 +4,8 @@ import os
import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tools.models import ToolModelCategory

View File

@@ -14,9 +14,9 @@ from unittest.mock import MagicMock, patch
import pytest
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService

View File

@@ -86,7 +86,7 @@ class TestCustomOpenAITemperatureParameterFix:
mock_registry_class.return_value = mock_registry
# Mock get_model_config to return our test model
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
test_capabilities = ModelCapabilities(
provider=ProviderType.OPENAI,
@@ -170,7 +170,7 @@ class TestCustomOpenAITemperatureParameterFix:
mock_registry_class.return_value = mock_registry
# Mock get_model_config to return a model that supports temperature
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
test_capabilities = ModelCapabilities(
provider=ProviderType.OPENAI,
@@ -227,7 +227,7 @@ class TestCustomOpenAITemperatureParameterFix:
mock_registry = Mock()
mock_registry_class.return_value = mock_registry
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
test_capabilities = ModelCapabilities(
provider=ProviderType.OPENAI,

View File

@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
import pytest
from providers import ModelProviderRegistry
from providers.base import ProviderType
from providers.custom import CustomProvider
from providers.shared import ProviderType
class TestCustomProvider:

View File

@@ -5,8 +5,8 @@ from unittest.mock import MagicMock, patch
import pytest
from providers.base import ProviderType
from providers.dial import DIALModelProvider
from providers.shared import ProviderType
class TestDIALProvider:

View File

@@ -8,7 +8,8 @@ from unittest.mock import Mock, patch
import pytest
from providers.base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType
from providers.base import ModelProvider
from providers.shared import ModelCapabilities, ModelResponse, ProviderType
class MinimalTestProvider(ModelProvider):

View File

@@ -9,8 +9,8 @@ from unittest.mock import Mock, patch
import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
class TestIntelligentFallback:

View File

@@ -41,7 +41,7 @@ def test_issue_245_custom_openai_temperature_ignored():
mock_registry = Mock()
mock_registry_class.return_value = mock_registry
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
# This is what the user configured in their custom_models.json
custom_config = ModelCapabilities(

View File

@@ -5,8 +5,9 @@ import os
import unittest
from unittest.mock import MagicMock, patch
from providers.base import ModelProvider, ProviderType
from providers.base import ModelProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tools.listmodels import ListModelsTool

View File

@@ -214,7 +214,7 @@ class TestModelEnumeration:
# Rebuild the provider registry with OpenRouter registered
ModelProviderRegistry._instance = None
from providers.base import ProviderType
from providers.shared import ProviderType
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

View File

@@ -9,8 +9,8 @@ This test specifically targets the bug where:
from unittest.mock import Mock, patch
from providers.base import ProviderType
from providers.openrouter import OpenRouterProvider
from providers.shared import ProviderType
from tools.consensus import ConsensusTool

View File

@@ -5,9 +5,9 @@ from unittest.mock import MagicMock, patch
import pytest
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService

View File

@@ -10,7 +10,7 @@ They prove that our fix was necessary and actually addresses real problems.
from unittest.mock import MagicMock, patch
from providers.base import ProviderType
from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService

View File

@@ -3,8 +3,8 @@
import os
from unittest.mock import MagicMock, patch
from providers.base import ProviderType
from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
class TestOpenAIProvider:

View File

@@ -5,9 +5,9 @@ from unittest.mock import Mock, patch
import pytest
from providers.base import ProviderType
from providers.openrouter import OpenRouterProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
class TestOpenRouterProvider:

View File

@@ -6,8 +6,8 @@ import tempfile
import pytest
from providers.base import ModelCapabilities, ProviderType
from providers.openrouter_registry import OpenRouterModelRegistry
from providers.shared import ModelCapabilities, ProviderType
class TestOpenRouterModelRegistry:
@@ -213,7 +213,7 @@ class TestOpenRouterModelRegistry:
def test_model_with_all_capabilities(self):
"""Test model with all capability flags."""
from providers.base import create_temperature_constraint
from providers.shared import create_temperature_constraint
caps = ModelCapabilities(
provider=ProviderType.OPENROUTER,

View File

@@ -13,8 +13,8 @@ from unittest.mock import Mock
import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tools.chat import ChatTool
from tools.shared.base_models import ToolRequest

View File

@@ -10,9 +10,9 @@ from unittest.mock import Mock, patch
import pytest
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
class TestProviderUTF8Encoding(unittest.TestCase):
@@ -177,7 +177,7 @@ class TestProviderUTF8Encoding(unittest.TestCase):
def test_model_response_utf8_serialization(self):
"""Test UTF-8 serialization of model responses."""
from providers.base import ModelResponse
from providers.shared import ModelResponse
response = ModelResponse(
content="Development successful! Code generated successfully. 🎉✅",

View File

@@ -6,9 +6,9 @@ from unittest.mock import Mock, patch
import pytest
from providers import ModelProviderRegistry, ModelResponse
from providers.base import ProviderType
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.shared import ProviderType
class TestModelProviderRegistry:

View File

@@ -185,7 +185,7 @@ class TestSupportedModelsAliases:
for provider in providers:
for model_name, config in provider.SUPPORTED_MODELS.items():
# All values must be ModelCapabilities objects, not strings or dicts
from providers.base import ModelCapabilities
from providers.shared import ModelCapabilities
assert isinstance(config, ModelCapabilities), (
f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "

View File

@@ -10,8 +10,8 @@ import os
import pytest
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from tools.debug import DebugIssueTool

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest
from providers.base import ProviderType
from providers.shared import ProviderType
from providers.xai import XAIModelProvider
@@ -265,7 +265,7 @@ class TestXAIProvider:
assert "grok-3-fast" in provider.SUPPORTED_MODELS
# Check model configs have required fields
from providers.base import ModelCapabilities
from providers.shared import ModelCapabilities
grok4_config = provider.SUPPORTED_MODELS["grok-4"]
assert isinstance(grok4_config, ModelCapabilities)

View File

@@ -22,9 +22,9 @@ def inject_transport(monkeypatch, cassette_path: str):
transport = inject_transport(monkeypatch, "path/to/cassette.json")
"""
# Ensure OpenAI provider is registered - always needed for transport injection
from providers.base import ProviderType
from providers.openai_provider import OpenAIModelProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
# Always register OpenAI provider for transport tests (API key might be dummy)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)

View File

@@ -79,9 +79,9 @@ class ListModelsTool(BaseTool):
Returns:
Formatted list of models by provider
"""
from providers.base import ProviderType
from providers.openrouter_registry import OpenRouterModelRegistry
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
output_lines = ["# Available AI Models\n"]
@@ -162,8 +162,8 @@ class ListModelsTool(BaseTool):
try:
# Get OpenRouter provider from registry to properly apply restrictions
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
if provider:

View File

@@ -1341,7 +1341,7 @@ When recommending searches, be specific about what information you need and why
# Apply 40MB cap for custom models if needed
effective_limit_mb = max_size_mb
try:
from providers.base import ProviderType
from providers.shared import ProviderType
# ModelCapabilities dataclass has provider field defined
if capabilities.provider == ProviderType.CUSTOM:

View File

@@ -306,8 +306,8 @@ class VersionTool(BaseTool):
# Check for configured providers
try:
from providers.base import ProviderType
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
provider_status = []

View File

@@ -24,7 +24,7 @@ import logging
import os
from typing import Optional
from providers.base import ProviderType
from providers.shared import ProviderType
logger = logging.getLogger(__name__)