fix: listmodels now always honors restricted models

fix: restrictions now resolve canonical names for OpenRouter aliases
fix: tools now correctly present the restricted model list in their schemas
fix: tests updated to manage their expected env vars properly
perf: cache model alias resolution to avoid repeated checks
Fahad
2025-10-04 13:46:22 +04:00
parent 054e34e31c
commit 4015e917ed
17 changed files with 885 additions and 253 deletions
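The behavioural change these tests pin down is that an allow-list entry now matches both an alias and the model it resolves to, with alias resolution cached. Below is a minimal sketch of that idea in isolation; the resolve_alias helper and its alias table are hypothetical stand-ins, not this repo's actual restriction service.

from functools import lru_cache

# Hypothetical alias table; the real providers expose richer capability metadata.
_ALIASES = {"mini": "gpt-5-mini", "flash": "gemini-2.5-flash", "gpt5": "gpt-5"}

@lru_cache(maxsize=None)
def resolve_alias(name: str) -> str:
    """Resolve an alias to its canonical model name (cached, per the perf note above)."""
    return _ALIASES.get(name.lower(), name.lower())

def is_allowed(allowed: frozenset[str], model_name: str) -> bool:
    """Permit a model if it, or anything resolving to the same canonical name, is allow-listed."""
    canonical = resolve_alias(model_name)
    return model_name.lower() in allowed or canonical in {resolve_alias(entry) for entry in allowed}

# Restricting to the alias "mini" now admits the canonical "gpt-5-mini" as well.
allowed = frozenset({"mini"})
assert is_allowed(allowed, "mini")
assert is_allowed(allowed, "gpt-5-mini")
assert not is_allowed(allowed, "o4-mini")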

View File

@@ -63,27 +63,30 @@ class TestAliasTargetRestrictions:
assert provider.validate_model_name("o4mini")
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only
def test_restriction_policy_alias_allows_canonical(self):
"""Alias-only allowlists should permit both the alias and its canonical target."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = OpenAIModelProvider(api_key="test-key")
assert provider.validate_model_name("mini")
assert provider.validate_model_name("gpt-5-mini")
assert not provider.validate_model_name("o4-mini")
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"})
def test_restriction_policy_alias_allows_short_name(self):
"""Common aliases like 'gpt5' should allow their canonical forms."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = OpenAIModelProvider(api_key="test-key")
assert provider.validate_model_name("gpt5")
assert provider.validate_model_name("gpt-5")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target
def test_gemini_restriction_policy_allows_alias_when_target_allowed(self):
"""Test Gemini restriction policy allows alias when target is allowed."""
@@ -99,19 +102,16 @@ class TestAliasTargetRestrictions:
assert provider.validate_model_name("flash")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}) # Allow alias only
def test_gemini_restriction_policy_alias_allows_canonical(self):
"""Gemini alias allowlists should permit canonical forms."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = GeminiModelProvider(api_key="test-key")
assert provider.validate_model_name("flash")
assert provider.validate_model_name("gemini-2.5-flash")
def test_restriction_service_validation_includes_all_targets(self):
"""Test that restriction service validation knows about all aliases and targets."""
@@ -153,6 +153,30 @@ class TestAliasTargetRestrictions:
assert provider.validate_model_name("o4-mini") # target
assert provider.validate_model_name("o4mini") # alias for o4-mini
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"}, clear=True)
def test_service_alias_allows_canonical_openai(self):
"""ModelRestrictionService should permit canonical names resolved from aliases."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = OpenAIModelProvider(api_key="test-key")
service = ModelRestrictionService()
assert service.is_allowed(ProviderType.OPENAI, "gpt-5")
assert provider.validate_model_name("gpt-5")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}, clear=True)
def test_service_alias_allows_canonical_gemini(self):
"""Gemini alias allowlists should permit canonical forms."""
import utils.model_restrictions
utils.model_restrictions._restriction_service = None
provider = GeminiModelProvider(api_key="test-key")
service = ModelRestrictionService()
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
assert provider.validate_model_name("gemini-2.5-flash")
def test_alias_target_policy_regression_prevention(self):
"""Regression test to ensure aliases and targets are both validated properly.

View File

@@ -106,19 +106,35 @@ class TestAutoMode:
def test_tool_schema_in_normal_mode(self):
"""Test that tool schemas don't require model in normal mode"""
# Save original
original = os.environ.get("DEFAULT_MODEL", "")
try:
# Set to a specific model (not auto mode)
os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
import config
importlib.reload(config)
tool = ChatTool()
schema = tool.get_input_schema()
# Model should not be required when default model is configured
assert "model" not in schema["required"]
# Model field should have simpler description
model_schema = schema["properties"]["model"]
assert "enum" not in model_schema
assert "listmodels" in model_schema["description"]
assert "default model" in model_schema["description"].lower()
finally:
# Restore
if original:
os.environ["DEFAULT_MODEL"] = original
else:
os.environ.pop("DEFAULT_MODEL", None)
importlib.reload(config)
@pytest.mark.asyncio
async def test_auto_mode_requires_model_parameter(self):

View File

@@ -0,0 +1,203 @@
"""Tests covering model restriction-aware error messaging in auto mode."""
import asyncio
import importlib
import json
import pytest
import utils.model_restrictions as model_restrictions
from providers.gemini import GeminiModelProvider
from providers.openai_provider import OpenAIModelProvider
from providers.openrouter import OpenRouterProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
from providers.xai import XAIModelProvider
def _extract_available_models(message: str) -> list[str]:
"""Parse the available model list from the error message."""
marker = "Available models: "
if marker not in message:
raise AssertionError(f"Expected '{marker}' in message: {message}")
start = message.index(marker) + len(marker)
end = message.find(". Suggested", start)
if end == -1:
end = len(message)
available_segment = message[start:end].strip()
if not available_segment:
return []
return [item.strip() for item in available_segment.split(",")]
@pytest.fixture
def reset_registry():
"""Ensure registry and restriction service state is isolated."""
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
yield
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
def _register_core_providers(*, include_xai: bool = False):
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
if include_xai:
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
@pytest.mark.no_mock_provider
def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry):
"""Error payload should surface only the allowed models for each provider."""
monkeypatch.setenv("DEFAULT_MODEL", "auto")
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
monkeypatch.delenv("XAI_API_KEY", raising=False)
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
try:
import dotenv
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
except ModuleNotFoundError:
pass
monkeypatch.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro")
monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "gpt-5")
monkeypatch.setenv("OPENROUTER_ALLOWED_MODELS", "gpt5nano")
monkeypatch.setenv("XAI_ALLOWED_MODELS", "")
import config
importlib.reload(config)
_register_core_providers()
import server
importlib.reload(server)
# Reload may have re-applied .env overrides; enforce our test configuration
for key, value in (
("DEFAULT_MODEL", "auto"),
("GEMINI_API_KEY", "test-gemini"),
("OPENAI_API_KEY", "test-openai"),
("OPENROUTER_API_KEY", "test-openrouter"),
("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro"),
("OPENAI_ALLOWED_MODELS", "gpt-5"),
("OPENROUTER_ALLOWED_MODELS", "gpt5nano"),
("XAI_ALLOWED_MODELS", ""),
):
monkeypatch.setenv(key, value)
for var in ("XAI_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"):
monkeypatch.delenv(var, raising=False)
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
server.configure_providers()
result = asyncio.run(
server.handle_call_tool(
"chat",
{
"model": "gpt5mini",
"prompt": "Tell me about your strengths",
},
)
)
assert len(result) == 1
payload = json.loads(result[0].text)
assert payload["status"] == "error"
available_models = _extract_available_models(payload["content"])
assert set(available_models) == {"gemini-2.5-pro", "gpt-5", "gpt5nano", "openai/gpt-5-nano"}
@pytest.mark.no_mock_provider
def test_error_listing_without_restrictions_shows_full_catalog(monkeypatch, reset_registry):
"""When no restrictions are set, the full high-capability catalogue should appear."""
monkeypatch.setenv("DEFAULT_MODEL", "auto")
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
monkeypatch.setenv("XAI_API_KEY", "test-xai")
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
try:
import dotenv
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
except ModuleNotFoundError:
pass
for var in (
"GOOGLE_ALLOWED_MODELS",
"OPENAI_ALLOWED_MODELS",
"OPENROUTER_ALLOWED_MODELS",
"XAI_ALLOWED_MODELS",
"DIAL_ALLOWED_MODELS",
):
monkeypatch.delenv(var, raising=False)
import config
importlib.reload(config)
_register_core_providers(include_xai=True)
import server
importlib.reload(server)
for key, value in (
("DEFAULT_MODEL", "auto"),
("GEMINI_API_KEY", "test-gemini"),
("OPENAI_API_KEY", "test-openai"),
("OPENROUTER_API_KEY", "test-openrouter"),
):
monkeypatch.setenv(key, value)
for var in (
"GOOGLE_ALLOWED_MODELS",
"OPENAI_ALLOWED_MODELS",
"OPENROUTER_ALLOWED_MODELS",
"XAI_ALLOWED_MODELS",
"DIAL_ALLOWED_MODELS",
"CUSTOM_API_URL",
"CUSTOM_API_KEY",
):
monkeypatch.delenv(var, raising=False)
ModelProviderRegistry.reset_for_testing()
model_restrictions._restriction_service = None
server.configure_providers()
result = asyncio.run(
server.handle_call_tool(
"chat",
{
"model": "dummymodel",
"prompt": "Hi there",
},
)
)
assert len(result) == 1
payload = json.loads(result[0].text)
assert payload["status"] == "error"
available_models = _extract_available_models(payload["content"])
assert "gemini-2.5-pro" in available_models
assert "gpt-5" in available_models
assert "grok-4" in available_models
assert len(available_models) >= 5

View File

@@ -3,6 +3,7 @@ Tests for dynamic context request and collaboration features
"""
import json
import os
from unittest.mock import Mock, patch
import pytest
@@ -157,95 +158,120 @@ class TestDynamicContextRequests:
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
async def test_clarification_with_suggested_action(self, mock_get_provider, analyze_tool):
"""Test clarification request with suggested next action"""
import importlib
from providers.registry import ModelProviderRegistry
# Ensure deterministic model configuration for this test regardless of previous suites
ModelProviderRegistry.reset_for_testing()
original_default = os.environ.get("DEFAULT_MODEL")
try:
os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
import config
importlib.reload(config)
clarification_json = json.dumps(
{
"status": "files_required_to_continue",
"mandatory_instructions": "I need to see the database configuration to analyze the connection error",
"files_needed": ["config/database.yml", "src/db.py"],
"suggested_next_action": {
"tool": "analyze",
"args": {
"prompt": "Analyze database connection timeout issue",
"relevant_files": [
"/config/database.yml",
"/src/db.py",
"/logs/error.log",
],
},
},
},
ensure_ascii=False,
)
mock_provider = create_mock_provider()
mock_provider.get_provider_type.return_value = Mock(value="google")
mock_provider.generate_content.return_value = Mock(
content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={}
)
mock_get_provider.return_value = mock_provider
result = await analyze_tool.execute(
{
"step": "Analyze database connection timeout issue",
"step_number": 1,
"total_steps": 1,
"next_step_required": False,
"findings": "Initial database timeout analysis",
"relevant_files": ["/absolute/logs/error.log"],
}
)
assert len(result) == 1
response_data = json.loads(result[0].text)
# Workflow tools should either promote clarification status or handle it in expert analysis
if response_data["status"] == "files_required_to_continue":
# Clarification was properly promoted to main status
# Check if mandatory_instructions is at top level or in content
if "mandatory_instructions" in response_data:
assert "database configuration" in response_data["mandatory_instructions"]
assert "files_needed" in response_data
assert "config/database.yml" in response_data["files_needed"]
assert "src/db.py" in response_data["files_needed"]
elif "content" in response_data:
# Parse content JSON for workflow tools
try:
content_json = json.loads(response_data["content"])
assert "mandatory_instructions" in content_json
assert (
"database configuration" in content_json["mandatory_instructions"]
or "database" in content_json["mandatory_instructions"]
)
assert "files_needed" in content_json
files_needed_str = str(content_json["files_needed"])
assert (
"config/database.yml" in files_needed_str
or "config" in files_needed_str
or "database" in files_needed_str
)
except json.JSONDecodeError:
# Content is not JSON, check if it contains required text
content = response_data["content"]
assert "database configuration" in content or "config" in content
elif response_data["status"] == "calling_expert_analysis":
# Clarification may be handled in expert analysis section
if "expert_analysis" in response_data:
expert_analysis = response_data["expert_analysis"]
expert_content = str(expert_analysis)
assert (
"database configuration" in content_json["mandatory_instructions"]
or "database" in content_json["mandatory_instructions"]
"database configuration" in expert_content
or "config/database.yml" in expert_content
or "files_required_to_continue" in expert_content
)
assert "files_needed" in content_json
files_needed_str = str(content_json["files_needed"])
assert (
"config/database.yml" in files_needed_str
or "config" in files_needed_str
or "database" in files_needed_str
)
except json.JSONDecodeError:
# Content is not JSON, check if it contains required text
content = response_data["content"]
assert "database configuration" in content or "config" in content
elif response_data["status"] == "calling_expert_analysis":
# Clarification may be handled in expert analysis section
if "expert_analysis" in response_data:
expert_analysis = response_data["expert_analysis"]
expert_content = str(expert_analysis)
assert (
"database configuration" in expert_content
or "config/database.yml" in expert_content
or "files_required_to_continue" in expert_content
)
else:
# Some other status - ensure it's a valid workflow response
assert "step_number" in response_data
# Check for suggested next action
if "suggested_next_action" in response_data:
action = response_data["suggested_next_action"]
assert action["tool"] == "analyze"
finally:
if original_default is not None:
os.environ["DEFAULT_MODEL"] = original_default
else:
os.environ.pop("DEFAULT_MODEL", None)
import config
importlib.reload(config)
ModelProviderRegistry.reset_for_testing()
def test_tool_output_model_serialization(self):
"""Test ToolOutput model serialization"""

View File

@@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
from providers.base import ModelProvider
from providers.registry import ModelProviderRegistry
from providers.shared import ModelCapabilities, ProviderType
from tools.listmodels import ListModelsTool
@@ -23,10 +23,63 @@ class TestListModelsRestrictions(unittest.TestCase):
self.mock_openrouter = MagicMock(spec=ModelProvider)
self.mock_openrouter.provider_type = ProviderType.OPENROUTER
def make_capabilities(
canonical: str, friendly: str, *, aliases=None, context: int = 200_000
) -> ModelCapabilities:
return ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name=canonical,
friendly_name=friendly,
intelligence_score=20,
description=friendly,
aliases=aliases or [],
context_window=context,
max_output_tokens=context,
supports_extended_thinking=True,
)
opus_caps = make_capabilities(
"anthropic/claude-opus-4-20240229",
"Claude Opus",
aliases=["opus"],
)
sonnet_caps = make_capabilities(
"anthropic/claude-sonnet-4-20240229",
"Claude Sonnet",
aliases=["sonnet"],
)
deepseek_caps = make_capabilities(
"deepseek/deepseek-r1-0528:free",
"DeepSeek R1",
aliases=[],
)
qwen_caps = make_capabilities(
"qwen/qwen3-235b-a22b-04-28:free",
"Qwen3",
aliases=[],
)
self._openrouter_caps_map = {
"anthropic/claude-opus-4": opus_caps,
"opus": opus_caps,
"anthropic/claude-opus-4-20240229": opus_caps,
"anthropic/claude-sonnet-4": sonnet_caps,
"sonnet": sonnet_caps,
"anthropic/claude-sonnet-4-20240229": sonnet_caps,
"deepseek/deepseek-r1-0528:free": deepseek_caps,
"qwen/qwen3-235b-a22b-04-28:free": qwen_caps,
}
self.mock_openrouter.get_capabilities.side_effect = self._openrouter_caps_map.__getitem__
self.mock_openrouter.get_capabilities_by_rank.return_value = []
self.mock_openrouter.list_models.return_value = []
# Create mock Gemini provider for comparison
self.mock_gemini = MagicMock(spec=ModelProvider)
self.mock_gemini.provider_type = ProviderType.GOOGLE
self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"]
self.mock_gemini.get_capabilities_by_rank.return_value = []
def tearDown(self):
"""Clean up after tests."""
@@ -159,7 +212,7 @@ class TestListModelsRestrictions(unittest.TestCase):
for line in lines:
if "OpenRouter" in line and "" in line:
openrouter_section_found = True
elif "Available Models" in line and openrouter_section_found:
elif ("Models (policy restricted)" in line or "Available Models" in line) and openrouter_section_found:
in_openrouter_section = True
elif in_openrouter_section:
# Check for lines with model names in backticks
@@ -179,11 +232,11 @@ class TestListModelsRestrictions(unittest.TestCase):
len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}"
)
# Verify we did not fall back to unrestricted listing
self.mock_openrouter.list_models.assert_not_called()
# Check for restriction note
self.assertIn("Restricted to models matching:", result)
self.assertIn("OpenRouter models restricted by", result)
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}, clear=True)
@patch("providers.openrouter_registry.OpenRouterModelRegistry")

View File

@@ -121,38 +121,59 @@ class TestModelMetadataContinuation:
@pytest.mark.asyncio
async def test_no_previous_assistant_turn_defaults(self):
"""Test behavior when there's no previous assistant turn."""
# Save and set DEFAULT_MODEL for test
import importlib
import os
original_default = os.environ.get("DEFAULT_MODEL", "")
os.environ["DEFAULT_MODEL"] = "auto"
import config
import utils.model_context
importlib.reload(config)
importlib.reload(utils.model_context)
try:
thread_id = create_thread("chat", {"prompt": "test"})
# Only add user turns
add_turn(thread_id, "user", "First question")
add_turn(thread_id, "user", "Second question")
arguments = {"continuation_id": thread_id}
# Mock dependencies
with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc:
mock_calc.return_value = MagicMock(
total_tokens=200000,
content_tokens=160000,
response_tokens=40000,
file_tokens=64000,
history_tokens=64000,
)
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000)
# Call the actual function
enhanced_args = await reconstruct_thread_context(arguments)
# Should not have set a model
assert enhanced_args.get("model") is None
# ModelContext should use DEFAULT_MODEL
model_context = ModelContext.from_arguments(enhanced_args)
from config import DEFAULT_MODEL
assert model_context.model_name == DEFAULT_MODEL
finally:
# Restore original value
if original_default:
os.environ["DEFAULT_MODEL"] = original_default
else:
os.environ.pop("DEFAULT_MODEL", None)
importlib.reload(config)
importlib.reload(utils.model_context)
@pytest.mark.asyncio
async def test_explicit_model_overrides_previous_turn(self):

View File

@@ -49,17 +49,32 @@ class TestModelRestrictionService:
def test_load_multiple_models_restriction(self):
"""Test loading multiple allowed models."""
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
# Instantiate providers so alias resolution for allow-lists is available
openai_provider = OpenAIModelProvider(api_key="test-key")
gemini_provider = GeminiModelProvider(api_key="test-key")
from providers.registry import ModelProviderRegistry
def fake_get_provider(provider_type, force_new=False):
mapping = {
ProviderType.OPENAI: openai_provider,
ProviderType.GOOGLE: gemini_provider,
}
return mapping.get(provider_type)
with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider):
service = ModelRestrictionService()
# Check OpenAI models
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
assert not service.is_allowed(ProviderType.OPENAI, "o3")
# Check Google models
assert service.is_allowed(ProviderType.GOOGLE, "flash")
assert service.is_allowed(ProviderType.GOOGLE, "pro")
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
def test_case_insensitive_and_whitespace_handling(self):
"""Test that model names are case-insensitive and whitespace is trimmed."""
@@ -111,13 +126,17 @@ class TestModelRestrictionService:
def test_shorthand_names_in_restrictions(self):
"""Test that shorthand names work in restrictions."""
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,o3-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
# Instantiate providers so the registry can resolve aliases
OpenAIModelProvider(api_key="test-key")
GeminiModelProvider(api_key="test-key")
service = ModelRestrictionService()
# When providers check models, they pass both resolved and original names
# OpenAI: 'o4mini' shorthand allows o4-mini
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini") # How providers actually call it
assert service.is_allowed(ProviderType.OPENAI, "o4-mini") # Canonical should also be allowed
# OpenAI: o3-mini allowed directly
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
@@ -280,19 +299,25 @@ class TestProviderIntegration:
provider = GeminiModelProvider(api_key="test-key")
from providers.registry import ModelProviderRegistry
with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):
# Test case: Only alias "flash" is allowed, not the full name
# If parameters are in wrong order, this test will catch it
# Should allow "flash" alias
assert provider.validate_model_name("flash")
# Should allow getting capabilities for "flash"
capabilities = provider.get_capabilities("flash")
assert capabilities.model_name == "gemini-2.5-flash"
# Canonical form should also be allowed now that alias is on the allowlist
assert provider.validate_model_name("gemini-2.5-flash")
# Unrelated models remain blocked
assert not provider.validate_model_name("pro")
assert not provider.validate_model_name("gemini-2.5-pro")
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
def test_gemini_parameter_order_edge_case_full_name_only(self):
@@ -570,17 +595,27 @@ class TestShorthandRestrictions:
# Test OpenAI provider
openai_provider = OpenAIModelProvider(api_key="test-key")
assert openai_provider.validate_model_name("mini") # Should work with shorthand
# When restricting to "mini", you can't use "o4-mini" directly - this is correct behavior
assert not openai_provider.validate_model_name("o4-mini") # Not allowed - only shorthand is allowed
assert not openai_provider.validate_model_name("o3-mini") # Not allowed
# Test Gemini provider
gemini_provider = GeminiModelProvider(api_key="test-key")
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
# Same for Gemini - if you restrict to "flash", you can't use the full name
assert not gemini_provider.validate_model_name("gemini-2.5-flash") # Not allowed
assert not gemini_provider.validate_model_name("pro") # Not allowed
from providers.registry import ModelProviderRegistry
def registry_side_effect(provider_type, force_new=False):
mapping = {
ProviderType.OPENAI: openai_provider,
ProviderType.GOOGLE: gemini_provider,
}
return mapping.get(provider_type)
with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect):
assert openai_provider.validate_model_name("mini") # Should work with shorthand
assert openai_provider.validate_model_name("gpt-5-mini") # Canonical resolved from shorthand
assert not openai_provider.validate_model_name("o4-mini") # Unrelated model still blocked
assert not openai_provider.validate_model_name("o3-mini")
# Test Gemini provider
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
assert gemini_provider.validate_model_name("gemini-2.5-flash") # Canonical allowed
assert not gemini_provider.validate_model_name("pro") # Not allowed
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
def test_multiple_shorthands_for_same_model(self):
@@ -596,9 +631,9 @@ class TestShorthandRestrictions:
assert openai_provider.validate_model_name("mini") # mini -> o4-mini
assert openai_provider.validate_model_name("o3mini") # o3mini -> o3-mini
# Resolved names should be allowed when their shorthands are present
assert openai_provider.validate_model_name("o4-mini") # Explicitly allowed
assert openai_provider.validate_model_name("o3-mini") # Allowed via shorthand
# Other models should not work
assert not openai_provider.validate_model_name("o3")

View File

@@ -260,9 +260,10 @@ class TestOpenRouterAutoMode:
os.environ["DEFAULT_MODEL"] = "auto"
mock_provider_class = Mock()
mock_provider_instance = Mock(spec=["get_provider_type", "list_models", "get_all_model_capabilities"])
mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER
mock_provider_instance.list_models.return_value = []
mock_provider_instance.get_all_model_capabilities.return_value = {}
mock_provider_class.return_value = mock_provider_instance
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class)

View File

@@ -293,13 +293,7 @@ class TestOpenRouterAliasRestrictions:
# o3 -> openai/o3
# gpt4.1 -> should not exist (expected to be filtered out)
expected_models = {"o3-mini", "pro", "flash", "o4-mini", "o3"}
available_model_names = set(available_models.keys())
@@ -355,9 +349,11 @@ class TestOpenRouterAliasRestrictions:
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
expected_models = {
"openai/o3-mini", # from alias
"o3-mini", # alias
"openai/o3-mini", # canonical
"anthropic/claude-opus-4.1", # full name
"google/gemini-2.5-flash", # from alias
"flash", # alias
"google/gemini-2.5-flash", # canonical
}
available_model_names = set(available_models.keys())