Files
my-pal-mcp-server/tests/test_model_restrictions.py
Fahad 7c36b9255a refactor: moved registries into a separate module and code cleanup
fix: refactored dial provider to follow the same pattern
2025-10-07 12:59:09 +04:00

790 lines
35 KiB
Python

"""Tests for model restriction functionality."""
import os
from unittest.mock import MagicMock, patch
import pytest
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService
class TestModelRestrictionService:
    """Test cases for ModelRestrictionService.

    Each test constructs ``ModelRestrictionService`` directly while the
    relevant ``*_ALLOWED_MODELS`` environment variables are patched.
    """

    @staticmethod
    def _patched_registry(openai_provider, gemini_provider):
        """Patch ``ModelProviderRegistry.get_provider`` to hand out the given providers.

        Alias expansion for allow-lists needs provider instances (see the
        "alias resolution" comments in the tests below). Pinning the registry
        lookup makes that expansion independent of whatever the real registry
        happens to hold when the test runs.

        Args:
            openai_provider: Instance returned for ``ProviderType.OPENAI``.
            gemini_provider: Instance returned for ``ProviderType.GOOGLE``.

        Returns:
            An unstarted ``patch.object`` context manager. Any other provider
            type is looked up as ``None``.
        """
        from providers.registry import ModelProviderRegistry

        providers_by_type = {
            ProviderType.OPENAI: openai_provider,
            ProviderType.GOOGLE: gemini_provider,
        }

        def fake_get_provider(provider_type, force_new=False):
            return providers_by_type.get(provider_type)

        return patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider)

    def test_no_restrictions_by_default(self):
        """Test that no restrictions exist when env vars are not set."""
        with patch.dict(os.environ, {}, clear=True):
            service = ModelRestrictionService()

            # Should allow all models
            assert service.is_allowed(ProviderType.OPENAI, "o3")
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
            assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4")
            assert service.is_allowed(ProviderType.OPENROUTER, "openai/o3")

            # Should have no restrictions
            assert not service.has_restrictions(ProviderType.OPENAI)
            assert not service.has_restrictions(ProviderType.GOOGLE)
            assert not service.has_restrictions(ProviderType.OPENROUTER)

    def test_load_single_model_restriction(self):
        """Test loading a single allowed model."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini"}):
            service = ModelRestrictionService()

            # Should only allow o3-mini
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert not service.is_allowed(ProviderType.OPENAI, "o3")
            assert not service.is_allowed(ProviderType.OPENAI, "o4-mini")

            # Google and OpenRouter should have no restrictions
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
            assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4")

    def test_load_multiple_models_restriction(self):
        """Test loading multiple allowed models."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
            # Instantiate providers so alias resolution for allow-lists is available
            openai_provider = OpenAIModelProvider(api_key="test-key")
            gemini_provider = GeminiModelProvider(api_key="test-key")

            with self._patched_registry(openai_provider, gemini_provider):
                service = ModelRestrictionService()

                # Check OpenAI models
                assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
                assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
                assert not service.is_allowed(ProviderType.OPENAI, "o3")

                # Check Google models
                assert service.is_allowed(ProviderType.GOOGLE, "flash")
                assert service.is_allowed(ProviderType.GOOGLE, "pro")
                assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")

    def test_case_insensitive_and_whitespace_handling(self):
        """Test that model names are case-insensitive and whitespace is trimmed."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": " O3-MINI , o4-Mini "}):
            service = ModelRestrictionService()

            # Should work with any case
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert service.is_allowed(ProviderType.OPENAI, "O3-MINI")
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
            assert service.is_allowed(ProviderType.OPENAI, "O4-Mini")

    def test_empty_string_allows_all(self):
        """Test that empty string allows all models (same as unset)."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "", "GOOGLE_ALLOWED_MODELS": "flash"}):
            service = ModelRestrictionService()

            # OpenAI should allow all models (empty string = no restrictions)
            assert service.is_allowed(ProviderType.OPENAI, "o3")
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini")

            # Google should only allow flash (and its resolved name)
            assert service.is_allowed(ProviderType.GOOGLE, "flash")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash", "flash")
            assert not service.is_allowed(ProviderType.GOOGLE, "pro")
            assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro", "pro")

    def test_filter_models(self):
        """Test filtering a list of models based on restrictions."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
            service = ModelRestrictionService()

            models = ["o3", "o3-mini", "o4-mini", "o3-pro"]
            filtered = service.filter_models(ProviderType.OPENAI, models)

            assert filtered == ["o3-mini", "o4-mini"]

    def test_get_allowed_models(self):
        """Test getting the set of allowed models."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
            service = ModelRestrictionService()

            allowed = service.get_allowed_models(ProviderType.OPENAI)
            assert allowed == {"o3-mini", "o4-mini"}

            # No restrictions for Google
            assert service.get_allowed_models(ProviderType.GOOGLE) is None

    def test_shorthand_names_in_restrictions(self):
        """Test that shorthand names work in restrictions."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
            # Instantiate providers and serve them from the registry so aliases resolve.
            # Merely instantiating them does not guarantee the registry returns them.
            openai_provider = OpenAIModelProvider(api_key="test-key")
            gemini_provider = GeminiModelProvider(api_key="test-key")

            with self._patched_registry(openai_provider, gemini_provider):
                service = ModelRestrictionService()

                # When providers check models, they pass both resolved and original names
                # OpenAI: 'o4mini' shorthand allows o4-mini
                assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini")  # How providers actually call it
                assert service.is_allowed(ProviderType.OPENAI, "o4-mini")  # Canonical should also be allowed

                # OpenAI: o3-mini allowed directly
                assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
                assert not service.is_allowed(ProviderType.OPENAI, "o3")

                # Google should allow both models via shorthands
                assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash", "flash")
                assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro", "pro")

                # The allow-listed shorthand also admits the call that passes both names
                assert service.is_allowed(ProviderType.OPENAI, "o3-mini", "o3mini")

    def test_validation_against_known_models(self, caplog):
        """Test validation warnings for unknown models."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mimi"}):  # Note the typo: o4-mimi
            service = ModelRestrictionService()

            # Create mock provider with known models
            mock_provider = MagicMock()
            mock_provider.MODEL_CAPABILITIES = {
                "o3": {"context_window": 200000},
                "o3-mini": {"context_window": 200000},
                "o4-mini": {"context_window": 200000},
            }
            mock_provider.list_models.return_value = ["o3", "o3-mini", "o4-mini"]

            provider_instances = {ProviderType.OPENAI: mock_provider}
            service.validate_against_known_models(provider_instances)

            # Should have logged a warning about the typo
            assert "o4-mimi" in caplog.text
            assert "not a recognized" in caplog.text

    def test_openrouter_model_restrictions(self):
        """Test OpenRouter model restrictions functionality."""
        with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet"}):
            service = ModelRestrictionService()

            # Should only allow specified OpenRouter models
            assert service.is_allowed(ProviderType.OPENROUTER, "opus")
            assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
            assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4", "opus")  # With original name
            assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")
            assert not service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-haiku")
            assert not service.is_allowed(ProviderType.OPENROUTER, "mistral-large")

            # Other providers should have no restrictions
            assert service.is_allowed(ProviderType.OPENAI, "o3")
            assert service.is_allowed(ProviderType.GOOGLE, "pro")

            # Should have restrictions for OpenRouter
            assert service.has_restrictions(ProviderType.OPENROUTER)
            assert not service.has_restrictions(ProviderType.OPENAI)
            assert not service.has_restrictions(ProviderType.GOOGLE)

    def test_openrouter_filter_models(self):
        """Test filtering OpenRouter models based on restrictions."""
        with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,mistral"}):
            service = ModelRestrictionService()

            models = ["opus", "sonnet", "haiku", "mistral", "llama"]
            filtered = service.filter_models(ProviderType.OPENROUTER, models)

            assert filtered == ["opus", "mistral"]

    def test_combined_provider_restrictions(self):
        """Test that restrictions work correctly when set for multiple providers."""
        with patch.dict(
            os.environ,
            {
                "OPENAI_ALLOWED_MODELS": "o3-mini",
                "GOOGLE_ALLOWED_MODELS": "flash",
                "OPENROUTER_ALLOWED_MODELS": "opus,sonnet",
            },
        ):
            service = ModelRestrictionService()

            # OpenAI restrictions
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert not service.is_allowed(ProviderType.OPENAI, "o3")

            # Google restrictions
            assert service.is_allowed(ProviderType.GOOGLE, "flash")
            assert not service.is_allowed(ProviderType.GOOGLE, "pro")

            # OpenRouter restrictions
            assert service.is_allowed(ProviderType.OPENROUTER, "opus")
            assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
            assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")

            # All providers should have restrictions
            assert service.has_restrictions(ProviderType.OPENAI)
            assert service.has_restrictions(ProviderType.GOOGLE)
            assert service.has_restrictions(ProviderType.OPENROUTER)
class TestProviderIntegration:
    """Test integration with actual providers.

    Providers read restrictions through the cached service held in
    ``utils.model_restrictions``. The cache is cleared before each test, so the
    service is rebuilt on first use under that test's patched environment. It
    is cleared again afterwards, so a service built from a patched environment
    cannot leak into later tests.
    """

    @staticmethod
    def _clear_restriction_service():
        """Drop the module-level cached restriction service."""
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def setup_method(self, method=None):
        """Start every test without a cached restriction service."""
        self._clear_restriction_service()

    def teardown_method(self, method=None):
        """Discard the service built during the test once its patched env is gone."""
        self._clear_restriction_service()

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini"})
    def test_openai_provider_respects_restrictions(self):
        """Test that OpenAI provider respects restrictions."""
        provider = OpenAIModelProvider(api_key="test-key")

        # Should validate allowed model
        assert provider.validate_model_name("o3-mini")

        # Should not validate disallowed model
        assert not provider.validate_model_name("o3")

        # get_capabilities should raise for disallowed model
        with pytest.raises(ValueError) as exc_info:
            provider.get_capabilities("o3")
        assert "not allowed by restriction policy" in str(exc_info.value)

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash,flash"})
    def test_gemini_provider_respects_restrictions(self):
        """Test that Gemini provider respects restrictions."""
        provider = GeminiModelProvider(api_key="test-key")

        # Should validate allowed models (both shorthand and full name allowed)
        assert provider.validate_model_name("flash")
        assert provider.validate_model_name("gemini-2.5-flash")

        # Should not validate disallowed model
        assert not provider.validate_model_name("pro")
        assert not provider.validate_model_name("gemini-2.5-pro")

        # get_capabilities should raise for disallowed model
        with pytest.raises(ValueError) as exc_info:
            provider.get_capabilities("pro")
        assert "not allowed by restriction policy" in str(exc_info.value)

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"})
    def test_gemini_parameter_order_regression_protection(self):
        """Test that prevents regression of parameter order bug in is_allowed calls.

        This test specifically catches the bug where parameters were incorrectly
        passed as (provider, user_input, resolved_name) instead of
        (provider, resolved_name, user_input).

        The bug was subtle because the is_allowed method uses OR logic, so it
        worked in most cases by accident. This test creates a scenario where
        the parameter order matters.
        """
        provider = GeminiModelProvider(api_key="test-key")

        from providers.registry import ModelProviderRegistry

        with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):
            # Only the alias "flash" is on the allow-list, not the full name.
            # Wrong parameter order in is_allowed calls would break the checks below.

            # Should allow "flash" alias
            assert provider.validate_model_name("flash")

            # Should allow getting capabilities for "flash"
            capabilities = provider.get_capabilities("flash")
            assert capabilities.model_name == "gemini-2.5-flash"

            # The canonical name is admitted too, because the allow-listed alias resolves to it
            assert provider.validate_model_name("gemini-2.5-flash")

            # Unrelated models remain blocked
            assert not provider.validate_model_name("pro")
            assert not provider.validate_model_name("gemini-2.5-pro")

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
    def test_gemini_parameter_order_edge_case_full_name_only(self):
        """Test parameter order with only full name allowed, not alias.

        This is the reverse scenario - only the full canonical name is allowed,
        not the shorthand alias. This tests that the parameter order is correct
        when resolving aliases.
        """
        provider = GeminiModelProvider(api_key="test-key")

        # Should allow full name
        assert provider.validate_model_name("gemini-2.5-flash")

        # Should also allow alias that resolves to allowed full name
        # This works because is_allowed checks both resolved_name and original_name
        assert provider.validate_model_name("flash")

        # Should not allow "pro" alias
        assert not provider.validate_model_name("pro")
        assert not provider.validate_model_name("gemini-2.5-pro")
class TestCustomProviderOpenRouterRestrictions:
    """Test custom provider integration with OpenRouter restrictions.

    The cached restriction service in ``utils.model_restrictions`` is cleared
    before and after each test. Each test therefore sees a service built from
    its own patched environment and leaves none behind.
    """

    @staticmethod
    def _clear_restriction_service():
        """Drop the module-level cached restriction service."""
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def setup_method(self, method=None):
        """Start every test without a cached restriction service."""
        self._clear_restriction_service()

    def teardown_method(self, method=None):
        """Discard the service built during the test once its patched env is gone."""
        self._clear_restriction_service()

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_respects_openrouter_restrictions(self):
        """Test that custom provider correctly defers OpenRouter models to OpenRouter provider."""
        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # CustomProvider should NOT validate OpenRouter models - they should be deferred to OpenRouter
        assert not provider.validate_model_name("opus")
        assert not provider.validate_model_name("sonnet")
        assert not provider.validate_model_name("haiku")

        # Should still validate custom models defined in conf/custom_models.json
        assert provider.validate_model_name("local-llama")

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_openrouter_capabilities_restrictions(self):
        """Test that custom provider's get_capabilities correctly handles OpenRouter models."""
        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # For OpenRouter models, CustomProvider should defer by raising
        with pytest.raises(ValueError):
            provider.get_capabilities("opus")

        # Should raise for disallowed OpenRouter model (still defers)
        with pytest.raises(ValueError):
            provider.get_capabilities("haiku")

        # Should still work for custom models
        capabilities = provider.get_capabilities("local-llama")
        assert capabilities.provider == ProviderType.CUSTOM

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus"}, clear=False)
    def test_custom_provider_no_openrouter_key_ignores_restrictions(self):
        """Test that when OpenRouter key is not set, cloud models are rejected regardless of restrictions."""
        # Make sure OPENROUTER_API_KEY is not set (patch.dict restores os.environ afterwards)
        os.environ.pop("OPENROUTER_API_KEY", None)

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # Should not validate OpenRouter models when key is not available
        assert not provider.validate_model_name("opus")  # Even though it's in allowed list
        assert not provider.validate_model_name("haiku")

        # Should still validate custom models
        assert provider.validate_model_name("local-llama")

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_empty_restrictions_allows_all_openrouter(self):
        """Test that an empty OpenRouter allow-list does not make CustomProvider claim OpenRouter models.

        Even with no OpenRouter restrictions in effect, CustomProvider must keep
        deferring OpenRouter models to the OpenRouter provider.
        """
        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # CustomProvider should NOT validate OpenRouter models - they should be deferred to OpenRouter
        assert not provider.validate_model_name("opus")
        assert not provider.validate_model_name("sonnet")
        assert not provider.validate_model_name("haiku")
class TestRegistryIntegration:
    """Test integration with ModelProviderRegistry.

    The cached restriction service in ``utils.model_restrictions`` is cleared
    before and after each test. This stops a service built from one test's
    patched environment from leaking into other tests.
    """

    @staticmethod
    def _clear_restriction_service():
        """Drop the module-level cached restriction service."""
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def setup_method(self, method=None):
        """Start every test without a cached restriction service."""
        self._clear_restriction_service()

    def teardown_method(self, method=None):
        """Discard the service built during the test once its patched env is gone."""
        self._clear_restriction_service()

    @staticmethod
    def _build_list_models(mock_provider, provider_type):
        """Build a ``list_models`` stand-in that applies the active restriction service.

        ``mock_provider.MODEL_CAPABILITIES`` is read at call time. Entries whose
        value is a ``str`` are treated as aliases and filtered on the target
        they name; every other entry is filtered on its own key.

        Args:
            mock_provider: Mock whose ``MODEL_CAPABILITIES`` dict drives the listing.
            provider_type: ``ProviderType`` passed to ``is_allowed``.

        Returns:
            A callable with the keyword-only ``list_models`` signature.
        """

        def list_models(
            *,
            respect_restrictions: bool = True,
            include_aliases: bool = True,
            lowercase: bool = False,
            unique: bool = False,
        ):
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service() if respect_restrictions else None
            models = []
            for model_name, config in mock_provider.MODEL_CAPABILITIES.items():
                if isinstance(config, str):
                    # Alias entry: the restriction applies to the target model
                    target_model = config
                    if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
                        continue
                    if include_aliases:
                        models.append(model_name)
                else:
                    if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
                        continue
                    models.append(model_name)
            if lowercase:
                models = [m.lower() for m in models]
            if unique:
                # dict.fromkeys drops duplicates while keeping first-seen order
                models = list(dict.fromkeys(models))
            return models

        return list_models

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GOOGLE_ALLOWED_MODELS": "flash"})
    def test_registry_with_shorthand_restrictions(self):
        """Test that registry handles shorthand restrictions correctly.

        Which allowed names appear depends on alias handling and on the
        providers configured in the environment. The test therefore pins only
        the exclusion side of the policy.
        """
        from providers.registry import ModelProviderRegistry

        # Clear registry cache
        ModelProviderRegistry.clear_cache()

        available = ModelProviderRegistry.get_available_models(respect_restrictions=True)

        # Models no allow-list entry refers to must never be attributed to the
        # restricted providers (other providers may legitimately expose similar names)
        assert available.get("o3") != ProviderType.OPENAI
        assert available.get("gemini-2.5-pro") != ProviderType.GOOGLE

    @patch("providers.registry.ModelProviderRegistry.get_provider")
    def test_get_available_models_respects_restrictions(self, mock_get_provider):
        """Test that registry filters models based on restrictions."""
        from providers.registry import ModelProviderRegistry

        # Mock providers
        mock_openai = MagicMock()
        mock_openai.MODEL_CAPABILITIES = {
            "o3": {"context_window": 200000},
            "o3-mini": {"context_window": 200000},
        }
        mock_openai.get_provider_type.return_value = ProviderType.OPENAI
        mock_openai.list_models = MagicMock(side_effect=self._build_list_models(mock_openai, ProviderType.OPENAI))

        mock_gemini = MagicMock()
        mock_gemini.MODEL_CAPABILITIES = {
            "gemini-2.5-pro": {"context_window": 1048576},
            "gemini-2.5-flash": {"context_window": 1048576},
        }
        mock_gemini.get_provider_type.return_value = ProviderType.GOOGLE
        mock_gemini.list_models = MagicMock(side_effect=self._build_list_models(mock_gemini, ProviderType.GOOGLE))

        mocks_by_type = {
            ProviderType.OPENAI: mock_openai,
            ProviderType.GOOGLE: mock_gemini,
        }

        def get_provider_side_effect(provider_type, force_new=False):
            return mocks_by_type.get(provider_type)

        mock_get_provider.side_effect = get_provider_side_effect

        # Register the mock providers only for the duration of the test. Assigning
        # registry._providers directly would leave MagicMock classes on the shared
        # registry instance after the test finishes.
        registry = ModelProviderRegistry()
        registry_patch = patch.object(
            registry,
            "_providers",
            {
                ProviderType.OPENAI: type(mock_openai),
                ProviderType.GOOGLE: type(mock_gemini),
            },
        )
        env_patch = patch.dict(
            os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini", "GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}
        )

        with registry_patch, env_patch:
            # Rebuild the restriction service from the allow-lists patched just above
            import utils.model_restrictions

            utils.model_restrictions._restriction_service = None

            available = ModelProviderRegistry.get_available_models(respect_restrictions=True)

            # Should only include allowed models
            assert "o3-mini" in available
            assert "o3" not in available
            assert "gemini-2.5-flash" in available
            assert "gemini-2.5-pro" not in available
class TestShorthandRestrictions:
    """Test that shorthand model names work correctly in restrictions.

    The cached restriction service in ``utils.model_restrictions`` is cleared
    before and after each test. Each test therefore reads its own patched
    allow-lists and leaves no service behind.
    """

    @staticmethod
    def _clear_restriction_service():
        """Drop the module-level cached restriction service."""
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def setup_method(self, method=None):
        """Start every test without a cached restriction service."""
        self._clear_restriction_service()

    def teardown_method(self, method=None):
        """Discard the service built during the test once its patched env is gone."""
        self._clear_restriction_service()

    @staticmethod
    def _patched_registry(providers_by_type):
        """Patch ``ModelProviderRegistry.get_provider`` to serve ``providers_by_type``.

        This keeps alias expansion of allow-list shorthands independent of
        whatever the real registry holds. Unknown provider types are looked up
        as ``None``.

        Args:
            providers_by_type: Mapping of ``ProviderType`` to provider instance.

        Returns:
            An unstarted ``patch.object`` context manager.
        """
        from providers.registry import ModelProviderRegistry

        def fake_get_provider(provider_type, force_new=False):
            return providers_by_type.get(provider_type)

        return patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider)

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GOOGLE_ALLOWED_MODELS": "flash"})
    def test_providers_validate_shorthands_correctly(self):
        """Test that providers correctly validate shorthand names."""
        openai_provider = OpenAIModelProvider(api_key="test-key")
        gemini_provider = GeminiModelProvider(api_key="test-key")

        registry = self._patched_registry(
            {
                ProviderType.OPENAI: openai_provider,
                ProviderType.GOOGLE: gemini_provider,
            }
        )
        with registry:
            # Test OpenAI provider
            assert openai_provider.validate_model_name("mini")  # Should work with shorthand
            assert openai_provider.validate_model_name("gpt-5-mini")  # Canonical resolved from shorthand
            assert not openai_provider.validate_model_name("o4-mini")  # Unrelated model still blocked
            assert not openai_provider.validate_model_name("o3-mini")

            # Test Gemini provider
            assert gemini_provider.validate_model_name("flash")  # Should work with shorthand
            assert gemini_provider.validate_model_name("gemini-2.5-flash")  # Canonical allowed
            assert not gemini_provider.validate_model_name("pro")  # Not allowed

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
    def test_multiple_shorthands_for_same_model(self):
        """Test that multiple shorthands work correctly."""
        openai_provider = OpenAIModelProvider(api_key="test-key")

        # Serve the provider from the registry so "o3mini" can be expanded to o3-mini
        with self._patched_registry({ProviderType.OPENAI: openai_provider}):
            # Both shorthands should work
            assert openai_provider.validate_model_name("mini")  # "mini" shorthand (resolves to gpt-5-mini)
            assert openai_provider.validate_model_name("o3mini")  # o3mini -> o3-mini

            # Resolved names should be allowed when their shorthands are present
            assert openai_provider.validate_model_name("o4-mini")  # Explicitly allowed
            assert openai_provider.validate_model_name("o3-mini")  # Allowed via shorthand

            # Other models should not work
            assert not openai_provider.validate_model_name("o3")
            assert not openai_provider.validate_model_name("o3-pro")

    @patch.dict(
        os.environ,
        {"OPENAI_ALLOWED_MODELS": "mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,gemini-2.5-flash"},
    )
    def test_both_shorthand_and_full_name_allowed(self):
        """Test that we can allow both shorthand and full names."""
        # OpenAI - both mini and o4-mini are allowed
        openai_provider = OpenAIModelProvider(api_key="test-key")
        assert openai_provider.validate_model_name("mini")
        assert openai_provider.validate_model_name("o4-mini")

        # Gemini - both flash and full name are allowed
        gemini_provider = GeminiModelProvider(api_key="test-key")
        assert gemini_provider.validate_model_name("flash")
        assert gemini_provider.validate_model_name("gemini-2.5-flash")
class TestAutoModeWithRestrictions:
    """Test auto mode behavior with restrictions.

    The cached restriction service in ``utils.model_restrictions`` is cleared
    before and after each test, so no service built from a patched environment
    survives the test.
    """

    @staticmethod
    def _clear_restriction_service():
        """Drop the module-level cached restriction service."""
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def setup_method(self, method=None):
        """Start every test without a cached restriction service."""
        self._clear_restriction_service()

    def teardown_method(self, method=None):
        """Discard the service built during the test once its patched env is gone."""
        self._clear_restriction_service()

    @patch("providers.registry.ModelProviderRegistry.get_provider")
    def test_fallback_model_respects_restrictions(self, mock_get_provider):
        """Test that fallback model selection respects restrictions."""
        from providers.registry import ModelProviderRegistry
        from tools.models import ToolModelCategory

        # Mock providers
        mock_openai = MagicMock()
        mock_openai.MODEL_CAPABILITIES = {
            "o3": {"context_window": 200000},
            "o3-mini": {"context_window": 200000},
            "o4-mini": {"context_window": 200000},
        }
        mock_openai.get_provider_type.return_value = ProviderType.OPENAI

        def openai_list_models(
            *,
            respect_restrictions: bool = True,
            include_aliases: bool = True,
            lowercase: bool = False,
            unique: bool = False,
        ):
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service() if respect_restrictions else None
            models = []
            for model_name, config in mock_openai.MODEL_CAPABILITIES.items():
                if isinstance(config, str):
                    # Alias entry: the restriction applies to the target model
                    target_model = config
                    if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
                        continue
                    if include_aliases:
                        models.append(model_name)
                else:
                    if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
                        continue
                    models.append(model_name)
            if lowercase:
                models = [m.lower() for m in models]
            if unique:
                # dict.fromkeys drops duplicates while keeping first-seen order
                models = list(dict.fromkeys(models))
            return models

        mock_openai.list_models = MagicMock(side_effect=openai_list_models)

        # Add get_preferred_model method to mock to match new implementation
        def get_preferred_model(category, allowed_models):
            # Simple preference logic for testing - just return first allowed model
            return allowed_models[0] if allowed_models else None

        mock_openai.get_preferred_model = get_preferred_model

        def get_provider_side_effect(provider_type, force_new=False):
            return mock_openai if provider_type == ProviderType.OPENAI else None

        mock_get_provider.side_effect = get_provider_side_effect

        # Register the mock provider only for the duration of the test. Assigning
        # registry._providers directly would leave a MagicMock class on the shared
        # registry instance after the test finishes.
        registry = ModelProviderRegistry()
        registry_patch = patch.object(registry, "_providers", {ProviderType.OPENAI: type(mock_openai)})
        env_patch = patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"})

        with registry_patch, env_patch:
            # Rebuild the restriction service from the allow-list patched just above
            import utils.model_restrictions

            utils.model_restrictions._restriction_service = None

            # Should pick o4-mini instead of o3-mini for fast response
            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
            assert model == "o4-mini"

    def test_fallback_with_shorthand_restrictions(self, monkeypatch):
        """Test fallback model selection with shorthand restrictions."""
        # Use monkeypatch to set environment variables with automatic cleanup
        monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "mini")
        monkeypatch.setenv("GEMINI_API_KEY", "")
        monkeypatch.setenv("OPENAI_API_KEY", "test-key")

        from providers.registry import ModelProviderRegistry
        from tools.models import ToolModelCategory

        # Keep the original singleton and copies of its provider maps for restoration
        original_registry = ModelProviderRegistry()
        original_providers = original_registry._providers.copy()
        original_initialized = original_registry._initialized_providers.copy()

        try:
            # Start from a fresh singleton holding only the OpenAI and Gemini providers
            ModelProviderRegistry._instance = None
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Even with "mini" restriction, fallback should work if provider handles it correctly
            # This tests the real-world scenario
            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)

            # The fallback will depend on how get_available_models handles aliases
            # When "mini" is allowed, it's returned as the allowed model
            # "mini" is now an alias for gpt-5-mini, but the list shows "mini" itself
            assert model in ["mini", "gpt-5-mini", "o4-mini", "gemini-2.5-flash"]
        finally:
            # Reinstate the original singleton object (not just its contents) and its provider maps
            ModelProviderRegistry._instance = original_registry
            original_registry._providers.clear()
            original_registry._initialized_providers.clear()
            original_registry._providers.update(original_providers)
            original_registry._initialized_providers.update(original_initialized)