Merge branch 'BeehiveInnovations:main' into feat-local_support_with_UTF-8_encoding-update
This commit is contained in:
@@ -527,7 +527,7 @@ class TestAutoModeComprehensive:
|
||||
"google/gemini-2.5-pro",
|
||||
"openai/o3",
|
||||
"openai/o4-mini",
|
||||
"anthropic/claude-3-opus",
|
||||
"anthropic/claude-opus-4",
|
||||
]
|
||||
|
||||
with patch.object(OpenRouterProvider, "_registry", mock_registry):
|
||||
|
||||
273
tests/test_dial_provider.py
Normal file
273
tests/test_dial_provider.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Tests for DIAL provider implementation."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.dial import DIALModelProvider
|
||||
|
||||
|
||||
class TestDIALProvider:
|
||||
"""Test DIAL provider functionality."""
|
||||
|
||||
@patch.dict(os.environ, {"DIAL_API_KEY": "test-key", "DIAL_API_HOST": "https://test.dialx.ai"})
|
||||
def test_initialization_with_host(self):
|
||||
"""Test provider initialization with custom host."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
assert provider._dial_api_key == "test-key" # Check internal API key storage
|
||||
assert provider.api_key == "placeholder-not-used" # OpenAI client uses placeholder, auth header removed by hook
|
||||
assert provider.base_url == "https://test.dialx.ai/openai"
|
||||
assert provider.get_provider_type() == ProviderType.DIAL
|
||||
|
||||
@patch.dict(os.environ, {"DIAL_API_KEY": "test-key", "DIAL_API_HOST": ""}, clear=True)
|
||||
def test_initialization_default_host(self):
|
||||
"""Test provider initialization with default host."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
assert provider._dial_api_key == "test-key" # Check internal API key storage
|
||||
assert provider.api_key == "placeholder-not-used" # OpenAI client uses placeholder, auth header removed by hook
|
||||
assert provider.base_url == "https://core.dialx.ai/openai"
|
||||
|
||||
def test_initialization_host_normalization(self):
|
||||
"""Test that host URL is normalized to include /openai suffix."""
|
||||
# Test with host missing /openai
|
||||
provider = DIALModelProvider("test-key", base_url="https://custom.dialx.ai")
|
||||
assert provider.base_url == "https://custom.dialx.ai/openai"
|
||||
|
||||
# Test with host already having /openai
|
||||
provider = DIALModelProvider("test-key", base_url="https://custom.dialx.ai/openai")
|
||||
assert provider.base_url == "https://custom.dialx.ai/openai"
|
||||
|
||||
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
|
||||
@patch("utils.model_restrictions._restriction_service", None)
|
||||
def test_model_validation(self):
|
||||
"""Test model name validation."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
|
||||
# Test valid models
|
||||
assert provider.validate_model_name("o3-2025-04-16") is True
|
||||
assert provider.validate_model_name("o3") is True # Shorthand
|
||||
assert provider.validate_model_name("anthropic.claude-opus-4-20250514-v1:0") is True
|
||||
assert provider.validate_model_name("opus-4") is True # Shorthand
|
||||
assert provider.validate_model_name("gemini-2.5-pro-preview-05-06") is True
|
||||
assert provider.validate_model_name("gemini-2.5-pro") is True # Shorthand
|
||||
|
||||
# Test invalid model
|
||||
assert provider.validate_model_name("invalid-model") is False
|
||||
|
||||
def test_resolve_model_name(self):
|
||||
"""Test model name resolution for shorthands."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
|
||||
# Test shorthand resolution
|
||||
assert provider._resolve_model_name("o3") == "o3-2025-04-16"
|
||||
assert provider._resolve_model_name("o4-mini") == "o4-mini-2025-04-16"
|
||||
assert provider._resolve_model_name("opus-4") == "anthropic.claude-opus-4-20250514-v1:0"
|
||||
assert provider._resolve_model_name("sonnet-4") == "anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
assert provider._resolve_model_name("gemini-2.5-pro") == "gemini-2.5-pro-preview-05-06"
|
||||
assert provider._resolve_model_name("gemini-2.5-flash") == "gemini-2.5-flash-preview-05-20"
|
||||
|
||||
# Test full name passthrough
|
||||
assert provider._resolve_model_name("o3-2025-04-16") == "o3-2025-04-16"
|
||||
assert (
|
||||
provider._resolve_model_name("anthropic.claude-opus-4-20250514-v1:0")
|
||||
== "anthropic.claude-opus-4-20250514-v1:0"
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
|
||||
@patch("utils.model_restrictions._restriction_service", None)
|
||||
def test_get_capabilities(self):
|
||||
"""Test getting model capabilities."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
|
||||
# Test O3 capabilities
|
||||
capabilities = provider.get_capabilities("o3")
|
||||
assert capabilities.model_name == "o3-2025-04-16"
|
||||
assert capabilities.friendly_name == "DIAL"
|
||||
assert capabilities.context_window == 200_000
|
||||
assert capabilities.provider == ProviderType.DIAL
|
||||
assert capabilities.supports_images is True
|
||||
assert capabilities.supports_extended_thinking is False
|
||||
|
||||
# Test Claude 4 capabilities
|
||||
capabilities = provider.get_capabilities("opus-4")
|
||||
assert capabilities.model_name == "anthropic.claude-opus-4-20250514-v1:0"
|
||||
assert capabilities.context_window == 200_000
|
||||
assert capabilities.supports_images is True
|
||||
assert capabilities.supports_extended_thinking is False
|
||||
|
||||
# Test Claude 4 with thinking mode
|
||||
capabilities = provider.get_capabilities("opus-4-thinking")
|
||||
assert capabilities.model_name == "anthropic.claude-opus-4-20250514-v1:0-with-thinking"
|
||||
assert capabilities.context_window == 200_000
|
||||
assert capabilities.supports_images is True
|
||||
assert capabilities.supports_extended_thinking is True
|
||||
|
||||
# Test Gemini capabilities
|
||||
capabilities = provider.get_capabilities("gemini-2.5-pro")
|
||||
assert capabilities.model_name == "gemini-2.5-pro-preview-05-06"
|
||||
assert capabilities.context_window == 1_000_000
|
||||
assert capabilities.supports_images is True
|
||||
|
||||
# Test temperature constraint
|
||||
assert capabilities.temperature_constraint.min_temp == 0.0
|
||||
assert capabilities.temperature_constraint.max_temp == 2.0
|
||||
assert capabilities.temperature_constraint.default_temp == 0.7
|
||||
|
||||
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
|
||||
@patch("utils.model_restrictions._restriction_service", None)
|
||||
def test_get_capabilities_invalid_model(self):
|
||||
"""Test that get_capabilities raises for invalid models."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported DIAL model"):
|
||||
provider.get_capabilities("invalid-model")
|
||||
|
||||
@patch("utils.model_restrictions.get_restriction_service")
|
||||
def test_get_capabilities_restricted_model(self, mock_get_restriction):
|
||||
"""Test that get_capabilities respects model restrictions."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
|
||||
# Mock restriction service to block the model
|
||||
mock_service = MagicMock()
|
||||
mock_service.is_allowed.return_value = False
|
||||
mock_get_restriction.return_value = mock_service
|
||||
|
||||
with pytest.raises(ValueError, match="not allowed by restriction policy"):
|
||||
provider.get_capabilities("o3")
|
||||
|
||||
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": ""}, clear=False)
|
||||
@patch("utils.model_restrictions._restriction_service", None)
|
||||
def test_supports_vision(self):
|
||||
"""Test vision support detection."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
|
||||
# Test models with vision support
|
||||
assert provider._supports_vision("o3-2025-04-16") is True
|
||||
assert provider._supports_vision("o3") is True # Via resolution
|
||||
assert provider._supports_vision("anthropic.claude-opus-4-20250514-v1:0") is True
|
||||
assert provider._supports_vision("gemini-2.5-pro-preview-05-06") is True
|
||||
|
||||
# Test unknown model (falls back to parent implementation)
|
||||
assert provider._supports_vision("unknown-model") is False
|
||||
|
||||
@patch("openai.OpenAI") # Mock the OpenAI class directly from openai module
|
||||
def test_generate_content_with_alias(self, mock_openai_class):
|
||||
"""Test that generate_content properly resolves aliases and uses deployment routing."""
|
||||
# Create mock client
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
|
||||
mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
mock_response.model = "gpt-4"
|
||||
mock_response.id = "test-id"
|
||||
mock_response.created = 1234567890
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
provider = DIALModelProvider("test-key")
|
||||
|
||||
# Generate content with shorthand
|
||||
response = provider.generate_content(prompt="Test prompt", model_name="o3", temperature=0.7) # Shorthand
|
||||
|
||||
# Verify OpenAI was instantiated with deployment-specific URL
|
||||
mock_openai_class.assert_called_once()
|
||||
call_args = mock_openai_class.call_args
|
||||
assert "/deployments/o3-2025-04-16" in call_args[1]["base_url"]
|
||||
|
||||
# Verify the resolved model name was passed to the API
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
create_call_args = mock_client.chat.completions.create.call_args
|
||||
assert create_call_args[1]["model"] == "o3-2025-04-16" # Resolved name
|
||||
|
||||
# Verify response
|
||||
assert response.content == "Test response"
|
||||
assert response.model_name == "o3" # Original name preserved
|
||||
assert response.metadata["model"] == "gpt-4" # API returned model name from mock
|
||||
|
||||
def test_provider_type(self):
|
||||
"""Test provider type identification."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
assert provider.get_provider_type() == ProviderType.DIAL
|
||||
|
||||
def test_friendly_name(self):
|
||||
"""Test provider friendly name."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
assert provider.FRIENDLY_NAME == "DIAL"
|
||||
|
||||
@patch.dict(os.environ, {"DIAL_API_VERSION": "2024-12-01"})
|
||||
def test_configurable_api_version(self):
|
||||
"""Test that API version can be configured via environment variable."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
# Check that the custom API version is stored
|
||||
assert provider.api_version == "2024-12-01"
|
||||
|
||||
def test_default_api_version(self):
|
||||
"""Test that default API version is used when not configured."""
|
||||
# Clear any existing DIAL_API_VERSION from environment
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Keep other env vars but ensure DIAL_API_VERSION is not set
|
||||
if "DIAL_API_VERSION" in os.environ:
|
||||
del os.environ["DIAL_API_VERSION"]
|
||||
|
||||
provider = DIALModelProvider("test-key")
|
||||
# Check that the default API version is used
|
||||
assert provider.api_version == "2024-12-01-preview"
|
||||
# Check that Api-Key header is set
|
||||
assert provider.DEFAULT_HEADERS["Api-Key"] == "test-key"
|
||||
|
||||
@patch.dict(os.environ, {"DIAL_ALLOWED_MODELS": "o3-2025-04-16,anthropic.claude-opus-4-20250514-v1:0"})
|
||||
@patch("utils.model_restrictions._restriction_service", None)
|
||||
def test_allowed_models_restriction(self):
|
||||
"""Test model allow-list functionality."""
|
||||
provider = DIALModelProvider("test-key")
|
||||
|
||||
# These should be allowed
|
||||
assert provider.validate_model_name("o3-2025-04-16") is True
|
||||
assert provider.validate_model_name("o3") is True # Alias for o3-2025-04-16
|
||||
assert provider.validate_model_name("anthropic.claude-opus-4-20250514-v1:0") is True
|
||||
assert provider.validate_model_name("opus-4") is True # Resolves to anthropic.claude-opus-4-20250514-v1:0
|
||||
|
||||
# These should be blocked
|
||||
assert provider.validate_model_name("gemini-2.5-pro-preview-05-06") is False
|
||||
assert provider.validate_model_name("o4-mini-2025-04-16") is False
|
||||
assert provider.validate_model_name("sonnet-4") is False # sonnet-4 is not in allowed list
|
||||
|
||||
@patch("httpx.Client")
|
||||
@patch("openai.OpenAI")
|
||||
def test_close_method(self, mock_openai_class, mock_httpx_client_class):
|
||||
"""Test that the close method properly closes HTTP clients."""
|
||||
# Mock the httpx.Client instance that DIALModelProvider will create
|
||||
mock_shared_http_client = MagicMock()
|
||||
mock_httpx_client_class.return_value = mock_shared_http_client
|
||||
|
||||
# Mock the OpenAI client instances
|
||||
mock_openai_client_1 = MagicMock()
|
||||
mock_openai_client_2 = MagicMock()
|
||||
# Configure side_effect to return different mocks for subsequent calls
|
||||
mock_openai_class.side_effect = [mock_openai_client_1, mock_openai_client_2]
|
||||
|
||||
provider = DIALModelProvider("test-key")
|
||||
|
||||
# Mock the superclass's _client attribute directly
|
||||
mock_superclass_client = MagicMock()
|
||||
provider._client = mock_superclass_client
|
||||
|
||||
# Simulate getting clients for two different deployments to populate _deployment_clients
|
||||
provider._get_deployment_client("model_a")
|
||||
provider._get_deployment_client("model_b")
|
||||
|
||||
# Now call close
|
||||
provider.close()
|
||||
|
||||
# Assert that the shared httpx client's close method was called
|
||||
mock_shared_http_client.close.assert_called_once()
|
||||
|
||||
# Assert that the superclass client's close method was called
|
||||
mock_superclass_client.close.assert_called_once()
|
||||
|
||||
# Assert that the deployment clients cache is cleared
|
||||
assert not provider._deployment_clients
|
||||
@@ -53,8 +53,8 @@ class TestListModelsRestrictions(unittest.TestCase):
|
||||
# Set up mock to return only allowed models when restrictions are respected
|
||||
# Include both aliased models and full model names without aliases
|
||||
self.mock_openrouter.list_models.return_value = [
|
||||
"anthropic/claude-3-opus-20240229", # Has alias "opus"
|
||||
"anthropic/claude-3-sonnet-20240229", # Has alias "sonnet"
|
||||
"anthropic/claude-opus-4", # Has alias "opus"
|
||||
"anthropic/claude-sonnet-4", # Has alias "sonnet"
|
||||
"deepseek/deepseek-r1-0528:free", # No alias, full name
|
||||
"qwen/qwen3-235b-a22b-04-28:free", # No alias, full name
|
||||
]
|
||||
@@ -67,12 +67,12 @@ class TestListModelsRestrictions(unittest.TestCase):
|
||||
def resolve_side_effect(model_name):
|
||||
if "opus" in model_name.lower():
|
||||
config = MagicMock()
|
||||
config.model_name = "anthropic/claude-3-opus-20240229"
|
||||
config.model_name = "anthropic/claude-opus-4-20240229"
|
||||
config.context_window = 200000
|
||||
return config
|
||||
elif "sonnet" in model_name.lower():
|
||||
config = MagicMock()
|
||||
config.model_name = "anthropic/claude-3-sonnet-20240229"
|
||||
config.model_name = "anthropic/claude-sonnet-4-20240229"
|
||||
config.context_window = 200000
|
||||
return config
|
||||
return None # No config for models without aliases
|
||||
@@ -93,8 +93,8 @@ class TestListModelsRestrictions(unittest.TestCase):
|
||||
mock_get_models.return_value = {
|
||||
"gemini-2.5-flash": ProviderType.GOOGLE,
|
||||
"gemini-2.5-pro": ProviderType.GOOGLE,
|
||||
"anthropic/claude-3-opus-20240229": ProviderType.OPENROUTER,
|
||||
"anthropic/claude-3-sonnet-20240229": ProviderType.OPENROUTER,
|
||||
"anthropic/claude-opus-4-20240229": ProviderType.OPENROUTER,
|
||||
"anthropic/claude-sonnet-4-20240229": ProviderType.OPENROUTER,
|
||||
"deepseek/deepseek-r1-0528:free": ProviderType.OPENROUTER,
|
||||
"qwen/qwen3-235b-a22b-04-28:free": ProviderType.OPENROUTER,
|
||||
}
|
||||
@@ -172,7 +172,7 @@ class TestListModelsRestrictions(unittest.TestCase):
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
# Set up mock to return many models when no restrictions
|
||||
all_models = [f"provider{i//10}/model-{i}" for i in range(50)] # Simulate 50 models from different providers
|
||||
all_models = [f"provider{i // 10}/model-{i}" for i in range(50)] # Simulate 50 models from different providers
|
||||
self.mock_openrouter.list_models.return_value = all_models
|
||||
|
||||
# Mock registry instance
|
||||
|
||||
@@ -24,7 +24,7 @@ class TestModelRestrictionService:
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "openai/o3")
|
||||
|
||||
# Should have no restrictions
|
||||
@@ -44,7 +44,7 @@ class TestModelRestrictionService:
|
||||
|
||||
# Google and OpenRouter should have no restrictions
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4")
|
||||
|
||||
def test_load_multiple_models_restriction(self):
|
||||
"""Test loading multiple allowed models."""
|
||||
@@ -159,7 +159,7 @@ class TestModelRestrictionService:
|
||||
# Should only allow specified OpenRouter models
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "opus")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-opus", "opus") # With original name
|
||||
assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4", "opus") # With original name
|
||||
assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")
|
||||
assert not service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-haiku")
|
||||
assert not service.is_allowed(ProviderType.OPENROUTER, "mistral-large")
|
||||
|
||||
@@ -44,7 +44,7 @@ class TestOpenRouterProvider:
|
||||
|
||||
# Should accept any model - OpenRouter handles validation
|
||||
assert provider.validate_model_name("gpt-4") is True
|
||||
assert provider.validate_model_name("claude-3-opus") is True
|
||||
assert provider.validate_model_name("claude-4-opus") is True
|
||||
assert provider.validate_model_name("any-model-name") is True
|
||||
assert provider.validate_model_name("GPT-4") is True
|
||||
assert provider.validate_model_name("unknown-model") is True
|
||||
@@ -71,26 +71,26 @@ class TestOpenRouterProvider:
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
# Test alias resolution
|
||||
assert provider._resolve_model_name("opus") == "anthropic/claude-3-opus"
|
||||
assert provider._resolve_model_name("sonnet") == "anthropic/claude-3-sonnet"
|
||||
assert provider._resolve_model_name("opus") == "anthropic/claude-opus-4"
|
||||
assert provider._resolve_model_name("sonnet") == "anthropic/claude-sonnet-4"
|
||||
assert provider._resolve_model_name("o3") == "openai/o3"
|
||||
assert provider._resolve_model_name("o3-mini") == "openai/o3-mini"
|
||||
assert provider._resolve_model_name("o3mini") == "openai/o3-mini"
|
||||
assert provider._resolve_model_name("o4-mini") == "openai/o4-mini"
|
||||
assert provider._resolve_model_name("o4-mini-high") == "openai/o4-mini-high"
|
||||
assert provider._resolve_model_name("claude") == "anthropic/claude-3-sonnet"
|
||||
assert provider._resolve_model_name("claude") == "anthropic/claude-sonnet-4"
|
||||
assert provider._resolve_model_name("mistral") == "mistralai/mistral-large-2411"
|
||||
assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-r1-0528"
|
||||
assert provider._resolve_model_name("r1") == "deepseek/deepseek-r1-0528"
|
||||
|
||||
# Test case-insensitive
|
||||
assert provider._resolve_model_name("OPUS") == "anthropic/claude-3-opus"
|
||||
assert provider._resolve_model_name("OPUS") == "anthropic/claude-opus-4"
|
||||
assert provider._resolve_model_name("O3") == "openai/o3"
|
||||
assert provider._resolve_model_name("Mistral") == "mistralai/mistral-large-2411"
|
||||
assert provider._resolve_model_name("CLAUDE") == "anthropic/claude-3-sonnet"
|
||||
assert provider._resolve_model_name("CLAUDE") == "anthropic/claude-sonnet-4"
|
||||
|
||||
# Test direct model names (should pass through unchanged)
|
||||
assert provider._resolve_model_name("anthropic/claude-3-opus") == "anthropic/claude-3-opus"
|
||||
assert provider._resolve_model_name("anthropic/claude-opus-4") == "anthropic/claude-opus-4"
|
||||
assert provider._resolve_model_name("openai/o3") == "openai/o3"
|
||||
|
||||
# Test unknown models pass through
|
||||
@@ -155,8 +155,8 @@ class TestOpenRouterAutoMode:
|
||||
"google/gemini-2.5-pro",
|
||||
"openai/o3",
|
||||
"openai/o3-mini",
|
||||
"anthropic/claude-3-opus",
|
||||
"anthropic/claude-3-sonnet",
|
||||
"anthropic/claude-opus-4",
|
||||
"anthropic/claude-sonnet-4",
|
||||
]
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||
@@ -181,7 +181,7 @@ class TestOpenRouterAutoMode:
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
os.environ["OPENROUTER_API_KEY"] = "test-openrouter-key"
|
||||
os.environ.pop("OPENROUTER_ALLOWED_MODELS", None)
|
||||
os.environ["OPENROUTER_ALLOWED_MODELS"] = "anthropic/claude-3-opus,google/gemini-2.5-flash"
|
||||
os.environ["OPENROUTER_ALLOWED_MODELS"] = "anthropic/claude-opus-4,google/gemini-2.5-flash"
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
|
||||
# Force reload to pick up new environment variable
|
||||
@@ -193,8 +193,8 @@ class TestOpenRouterAutoMode:
|
||||
mock_models = [
|
||||
"google/gemini-2.5-flash",
|
||||
"google/gemini-2.5-pro",
|
||||
"anthropic/claude-3-opus",
|
||||
"anthropic/claude-3-sonnet",
|
||||
"anthropic/claude-opus-4",
|
||||
"anthropic/claude-sonnet-4",
|
||||
]
|
||||
mock_registry.list_models.return_value = mock_models
|
||||
|
||||
@@ -212,7 +212,7 @@ class TestOpenRouterAutoMode:
|
||||
|
||||
assert len(available_models) > 0, "Should have some allowed models"
|
||||
|
||||
expected_allowed = {"google/gemini-2.5-flash", "anthropic/claude-3-opus"}
|
||||
expected_allowed = {"google/gemini-2.5-flash", "anthropic/claude-opus-4"}
|
||||
|
||||
assert (
|
||||
set(available_models.keys()) == expected_allowed
|
||||
@@ -263,7 +263,7 @@ class TestOpenRouterRegistry:
|
||||
# Should have loaded models
|
||||
models = registry.list_models()
|
||||
assert len(models) > 0
|
||||
assert "anthropic/claude-3-opus" in models
|
||||
assert "anthropic/claude-opus-4" in models
|
||||
assert "openai/o3" in models
|
||||
|
||||
# Should have loaded aliases
|
||||
@@ -282,13 +282,13 @@ class TestOpenRouterRegistry:
|
||||
# Test known model
|
||||
caps = registry.get_capabilities("opus")
|
||||
assert caps is not None
|
||||
assert caps.model_name == "anthropic/claude-3-opus"
|
||||
assert caps.model_name == "anthropic/claude-opus-4"
|
||||
assert caps.context_window == 200000 # Claude's context window
|
||||
|
||||
# Test using full model name
|
||||
caps = registry.get_capabilities("anthropic/claude-3-opus")
|
||||
caps = registry.get_capabilities("anthropic/claude-opus-4")
|
||||
assert caps is not None
|
||||
assert caps.model_name == "anthropic/claude-3-opus"
|
||||
assert caps.model_name == "anthropic/claude-opus-4"
|
||||
|
||||
# Test unknown model
|
||||
caps = registry.get_capabilities("non-existent-model")
|
||||
@@ -301,11 +301,11 @@ class TestOpenRouterRegistry:
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
# All these should resolve to Claude Sonnet
|
||||
sonnet_aliases = ["sonnet", "claude", "claude-sonnet", "claude3-sonnet"]
|
||||
sonnet_aliases = ["sonnet", "claude", "claude-sonnet", "claude4-sonnet"]
|
||||
for alias in sonnet_aliases:
|
||||
config = registry.resolve(alias)
|
||||
assert config is not None
|
||||
assert config.model_name == "anthropic/claude-3-sonnet"
|
||||
assert config.model_name == "anthropic/claude-sonnet-4"
|
||||
|
||||
|
||||
class TestOpenRouterFunctionality:
|
||||
|
||||
@@ -74,9 +74,9 @@ class TestOpenRouterModelRegistry:
|
||||
|
||||
# Test various aliases
|
||||
test_cases = [
|
||||
("opus", "anthropic/claude-3-opus"),
|
||||
("OPUS", "anthropic/claude-3-opus"), # Case insensitive
|
||||
("claude", "anthropic/claude-3-sonnet"),
|
||||
("opus", "anthropic/claude-opus-4"),
|
||||
("OPUS", "anthropic/claude-opus-4"), # Case insensitive
|
||||
("claude", "anthropic/claude-sonnet-4"),
|
||||
("o3", "openai/o3"),
|
||||
("deepseek", "deepseek/deepseek-r1-0528"),
|
||||
("mistral", "mistralai/mistral-large-2411"),
|
||||
@@ -92,9 +92,9 @@ class TestOpenRouterModelRegistry:
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
# Should be able to look up by full model name
|
||||
config = registry.resolve("anthropic/claude-3-opus")
|
||||
config = registry.resolve("anthropic/claude-opus-4")
|
||||
assert config is not None
|
||||
assert config.model_name == "anthropic/claude-3-opus"
|
||||
assert config.model_name == "anthropic/claude-opus-4"
|
||||
|
||||
config = registry.resolve("openai/o3")
|
||||
assert config is not None
|
||||
@@ -118,7 +118,7 @@ class TestOpenRouterModelRegistry:
|
||||
|
||||
caps = config.to_capabilities()
|
||||
assert caps.provider == ProviderType.OPENROUTER
|
||||
assert caps.model_name == "anthropic/claude-3-opus"
|
||||
assert caps.model_name == "anthropic/claude-opus-4"
|
||||
assert caps.friendly_name == "OpenRouter"
|
||||
assert caps.context_window == 200000
|
||||
assert not caps.supports_extended_thinking
|
||||
|
||||
@@ -288,11 +288,11 @@ class TestProviderHelperMethods:
|
||||
with patch.object(ModelProviderRegistry, "get_provider") as mock_get_provider:
|
||||
# Mock openrouter provider
|
||||
mock_openrouter = MagicMock()
|
||||
mock_openrouter.validate_model_name.side_effect = lambda m: m == "anthropic/claude-3.5-sonnet"
|
||||
mock_openrouter.validate_model_name.side_effect = lambda m: m == "anthropic/claude-sonnet-4"
|
||||
mock_get_provider.side_effect = lambda ptype: mock_openrouter if ptype == ProviderType.OPENROUTER else None
|
||||
|
||||
model = ModelProviderRegistry._find_extended_thinking_model()
|
||||
assert model == "anthropic/claude-3.5-sonnet"
|
||||
assert model == "anthropic/claude-sonnet-4"
|
||||
|
||||
def test_find_extended_thinking_model_none_found(self):
|
||||
"""Test when no thinking model is found."""
|
||||
|
||||
@@ -318,7 +318,7 @@ class TestOpenRouterAliasRestrictions:
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
os.environ.pop("XAI_API_KEY", None)
|
||||
os.environ["OPENROUTER_API_KEY"] = "test-key"
|
||||
os.environ["OPENROUTER_ALLOWED_MODELS"] = "o3-mini,anthropic/claude-3-opus,flash"
|
||||
os.environ["OPENROUTER_ALLOWED_MODELS"] = "o3-mini,anthropic/claude-opus-4,flash"
|
||||
|
||||
# Register OpenRouter provider
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
@@ -330,7 +330,7 @@ class TestOpenRouterAliasRestrictions:
|
||||
|
||||
expected_models = {
|
||||
"openai/o3-mini", # from alias
|
||||
"anthropic/claude-3-opus", # full name
|
||||
"anthropic/claude-opus-4", # full name
|
||||
"google/gemini-2.5-flash", # from alias
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user