refactor: code cleanup
This commit is contained in:
@@ -34,9 +34,9 @@ if sys.platform == "win32":
|
||||
|
||||
# Register providers for all tests
|
||||
from providers import ModelProviderRegistry # noqa: E402
|
||||
from providers.base import ProviderType # noqa: E402
|
||||
from providers.gemini import GeminiModelProvider # noqa: E402
|
||||
from providers.openai_provider import OpenAIModelProvider # noqa: E402
|
||||
from providers.shared import ProviderType # noqa: E402
|
||||
from providers.xai import XAIModelProvider # noqa: E402
|
||||
|
||||
# Register providers at test startup
|
||||
@@ -109,7 +109,7 @@ def mock_provider_availability(request, monkeypatch):
|
||||
return
|
||||
|
||||
# Ensure providers are registered (in case other tests cleared the registry)
|
||||
from providers.base import ProviderType
|
||||
from providers.shared import ProviderType
|
||||
|
||||
registry = ModelProviderRegistry()
|
||||
|
||||
@@ -197,3 +197,19 @@ def mock_provider_availability(request, monkeypatch):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_model_restriction_env(monkeypatch):
|
||||
"""Ensure per-test isolation from user-defined model restriction env vars."""
|
||||
|
||||
restriction_vars = [
|
||||
"OPENAI_ALLOWED_MODELS",
|
||||
"GOOGLE_ALLOWED_MODELS",
|
||||
"XAI_ALLOWED_MODELS",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
"DIAL_ALLOWED_MODELS",
|
||||
]
|
||||
|
||||
for var in restriction_vars:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from providers.base import ModelCapabilities, ProviderType, RangeTemperatureConstraint
|
||||
from providers.shared import ModelCapabilities, ProviderType, RangeTemperatureConstraint
|
||||
|
||||
|
||||
def create_mock_provider(model_name="gemini-2.5-flash", context_window=1_048_576):
|
||||
|
||||
@@ -8,9 +8,9 @@ both alias names and their target models, preventing policy bypass vulnerabiliti
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.shared import ProviderType
|
||||
from utils.model_restrictions import ModelRestrictionService
|
||||
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from tools.analyze import AnalyzeTool
|
||||
from tools.chat import ChatTool
|
||||
from tools.debug import DebugIssueTool
|
||||
|
||||
@@ -6,8 +6,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
|
||||
|
||||
@pytest.mark.no_mock_provider
|
||||
|
||||
@@ -4,8 +4,8 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from tools.models import ToolModelCategory
|
||||
|
||||
|
||||
|
||||
@@ -14,9 +14,9 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.shared import ProviderType
|
||||
from utils.model_restrictions import ModelRestrictionService
|
||||
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ class TestCustomOpenAITemperatureParameterFix:
|
||||
mock_registry_class.return_value = mock_registry
|
||||
|
||||
# Mock get_model_config to return our test model
|
||||
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||
|
||||
test_capabilities = ModelCapabilities(
|
||||
provider=ProviderType.OPENAI,
|
||||
@@ -170,7 +170,7 @@ class TestCustomOpenAITemperatureParameterFix:
|
||||
mock_registry_class.return_value = mock_registry
|
||||
|
||||
# Mock get_model_config to return a model that supports temperature
|
||||
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||
|
||||
test_capabilities = ModelCapabilities(
|
||||
provider=ProviderType.OPENAI,
|
||||
@@ -227,7 +227,7 @@ class TestCustomOpenAITemperatureParameterFix:
|
||||
mock_registry = Mock()
|
||||
mock_registry_class.return_value = mock_registry
|
||||
|
||||
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||
|
||||
test_capabilities = ModelCapabilities(
|
||||
provider=ProviderType.OPENAI,
|
||||
|
||||
@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from providers import ModelProviderRegistry
|
||||
from providers.base import ProviderType
|
||||
from providers.custom import CustomProvider
|
||||
from providers.shared import ProviderType
|
||||
|
||||
|
||||
class TestCustomProvider:
|
||||
|
||||
@@ -5,8 +5,8 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.dial import DIALModelProvider
|
||||
from providers.shared import ProviderType
|
||||
|
||||
|
||||
class TestDIALProvider:
|
||||
|
||||
@@ -8,7 +8,8 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ModelCapabilities, ModelProvider, ModelResponse, ProviderType
|
||||
from providers.base import ModelProvider
|
||||
from providers.shared import ModelCapabilities, ModelResponse, ProviderType
|
||||
|
||||
|
||||
class MinimalTestProvider(ModelProvider):
|
||||
|
||||
@@ -9,8 +9,8 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
|
||||
|
||||
class TestIntelligentFallback:
|
||||
|
||||
@@ -41,7 +41,7 @@ def test_issue_245_custom_openai_temperature_ignored():
|
||||
mock_registry = Mock()
|
||||
mock_registry_class.return_value = mock_registry
|
||||
|
||||
from providers.base import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||
from providers.shared import ModelCapabilities, ProviderType, create_temperature_constraint
|
||||
|
||||
# This is what the user configured in their custom_models.json
|
||||
custom_config = ModelCapabilities(
|
||||
|
||||
@@ -5,8 +5,9 @@ import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from providers.base import ModelProvider, ProviderType
|
||||
from providers.base import ModelProvider
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from tools.listmodels import ListModelsTool
|
||||
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ class TestModelEnumeration:
|
||||
|
||||
# Rebuild the provider registry with OpenRouter registered
|
||||
ModelProviderRegistry._instance = None
|
||||
from providers.base import ProviderType
|
||||
from providers.shared import ProviderType
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ This test specifically targets the bug where:
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
from providers.shared import ProviderType
|
||||
from tools.consensus import ConsensusTool
|
||||
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.shared import ProviderType
|
||||
from utils.model_restrictions import ModelRestrictionService
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ They prove that our fix was necessary and actually addresses real problems.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.shared import ProviderType
|
||||
from utils.model_restrictions import ModelRestrictionService
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.shared import ProviderType
|
||||
|
||||
|
||||
class TestOpenAIProvider:
|
||||
|
||||
@@ -5,9 +5,9 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
|
||||
|
||||
class TestOpenRouterProvider:
|
||||
|
||||
@@ -6,8 +6,8 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ModelCapabilities, ProviderType
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||
from providers.shared import ModelCapabilities, ProviderType
|
||||
|
||||
|
||||
class TestOpenRouterModelRegistry:
|
||||
@@ -213,7 +213,7 @@ class TestOpenRouterModelRegistry:
|
||||
|
||||
def test_model_with_all_capabilities(self):
|
||||
"""Test model with all capability flags."""
|
||||
from providers.base import create_temperature_constraint
|
||||
from providers.shared import create_temperature_constraint
|
||||
|
||||
caps = ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
|
||||
@@ -13,8 +13,8 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from tools.chat import ChatTool
|
||||
from tools.shared.base_models import ToolRequest
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.shared import ProviderType
|
||||
|
||||
|
||||
class TestProviderUTF8Encoding(unittest.TestCase):
|
||||
@@ -177,7 +177,7 @@ class TestProviderUTF8Encoding(unittest.TestCase):
|
||||
|
||||
def test_model_response_utf8_serialization(self):
|
||||
"""Test UTF-8 serialization of model responses."""
|
||||
from providers.base import ModelResponse
|
||||
from providers.shared import ModelResponse
|
||||
|
||||
response = ModelResponse(
|
||||
content="Development successful! Code generated successfully. 🎉✅",
|
||||
|
||||
@@ -6,9 +6,9 @@ from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
|
||||
from providers import ModelProviderRegistry, ModelResponse
|
||||
from providers.base import ProviderType
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.shared import ProviderType
|
||||
|
||||
|
||||
class TestModelProviderRegistry:
|
||||
|
||||
@@ -185,7 +185,7 @@ class TestSupportedModelsAliases:
|
||||
for provider in providers:
|
||||
for model_name, config in provider.SUPPORTED_MODELS.items():
|
||||
# All values must be ModelCapabilities objects, not strings or dicts
|
||||
from providers.base import ModelCapabilities
|
||||
from providers.shared import ModelCapabilities
|
||||
|
||||
assert isinstance(config, ModelCapabilities), (
|
||||
f"{provider.__class__.__name__}.SUPPORTED_MODELS['{model_name}'] "
|
||||
|
||||
@@ -10,8 +10,8 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from tools.debug import DebugIssueTool
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.shared import ProviderType
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
|
||||
@@ -265,7 +265,7 @@ class TestXAIProvider:
|
||||
assert "grok-3-fast" in provider.SUPPORTED_MODELS
|
||||
|
||||
# Check model configs have required fields
|
||||
from providers.base import ModelCapabilities
|
||||
from providers.shared import ModelCapabilities
|
||||
|
||||
grok4_config = provider.SUPPORTED_MODELS["grok-4"]
|
||||
assert isinstance(grok4_config, ModelCapabilities)
|
||||
|
||||
@@ -22,9 +22,9 @@ def inject_transport(monkeypatch, cassette_path: str):
|
||||
transport = inject_transport(monkeypatch, "path/to/cassette.json")
|
||||
"""
|
||||
# Ensure OpenAI provider is registered - always needed for transport injection
|
||||
from providers.base import ProviderType
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
|
||||
# Always register OpenAI provider for transport tests (API key might be dummy)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
|
||||
Reference in New Issue
Block a user