fix: listmodels to always honor restricted models
fix: restrictions should resolve canonical names for openrouter fix: tools now correctly return restricted list by presenting model names in schema fix: tests updated to ensure these manage their expected env vars properly perf: cache model alias resolution to avoid repeated checks
This commit is contained in:
@@ -63,27 +63,30 @@ class TestAliasTargetRestrictions:
|
||||
assert provider.validate_model_name("o4mini")
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini"}) # Allow alias only
|
||||
def test_restriction_policy_allows_only_alias_when_alias_specified(self):
|
||||
"""Test that restriction policy allows only the alias when just alias is specified.
|
||||
|
||||
If you restrict to 'mini' (which is an alias for gpt-5-mini),
|
||||
only the alias should work, not other models.
|
||||
This is the correct restrictive behavior.
|
||||
"""
|
||||
# Clear cached restriction service
|
||||
def test_restriction_policy_alias_allows_canonical(self):
|
||||
"""Alias-only allowlists should permit both the alias and its canonical target."""
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
|
||||
# Only the alias should be allowed
|
||||
assert provider.validate_model_name("mini")
|
||||
# Direct target for this alias should NOT be allowed (mini -> gpt-5-mini)
|
||||
assert not provider.validate_model_name("gpt-5-mini")
|
||||
# Other models should NOT be allowed
|
||||
assert provider.validate_model_name("gpt-5-mini")
|
||||
assert not provider.validate_model_name("o4-mini")
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"})
|
||||
def test_restriction_policy_alias_allows_short_name(self):
|
||||
"""Common aliases like 'gpt5' should allow their canonical forms."""
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
|
||||
assert provider.validate_model_name("gpt5")
|
||||
assert provider.validate_model_name("gpt-5")
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}) # Allow target
|
||||
def test_gemini_restriction_policy_allows_alias_when_target_allowed(self):
|
||||
"""Test Gemini restriction policy allows alias when target is allowed."""
|
||||
@@ -99,19 +102,16 @@ class TestAliasTargetRestrictions:
|
||||
assert provider.validate_model_name("flash")
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}) # Allow alias only
|
||||
def test_gemini_restriction_policy_allows_only_alias_when_alias_specified(self):
|
||||
"""Test Gemini restriction policy allows only alias when just alias is specified."""
|
||||
# Clear cached restriction service
|
||||
def test_gemini_restriction_policy_alias_allows_canonical(self):
|
||||
"""Gemini alias allowlists should permit canonical forms."""
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
# Only the alias should be allowed
|
||||
assert provider.validate_model_name("flash")
|
||||
# Direct target should NOT be allowed
|
||||
assert not provider.validate_model_name("gemini-2.5-flash")
|
||||
assert provider.validate_model_name("gemini-2.5-flash")
|
||||
|
||||
def test_restriction_service_validation_includes_all_targets(self):
|
||||
"""Test that restriction service validation knows about all aliases and targets."""
|
||||
@@ -153,6 +153,30 @@ class TestAliasTargetRestrictions:
|
||||
assert provider.validate_model_name("o4-mini") # target
|
||||
assert provider.validate_model_name("o4mini") # alias for o4-mini
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "gpt5"}, clear=True)
|
||||
def test_service_alias_allows_canonical_openai(self):
|
||||
"""ModelRestrictionService should permit canonical names resolved from aliases."""
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
provider = OpenAIModelProvider(api_key="test-key")
|
||||
service = ModelRestrictionService()
|
||||
|
||||
assert service.is_allowed(ProviderType.OPENAI, "gpt-5")
|
||||
assert provider.validate_model_name("gpt-5")
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"}, clear=True)
|
||||
def test_service_alias_allows_canonical_gemini(self):
|
||||
"""Gemini alias allowlists should permit canonical forms."""
|
||||
import utils.model_restrictions
|
||||
|
||||
utils.model_restrictions._restriction_service = None
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
service = ModelRestrictionService()
|
||||
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
|
||||
assert provider.validate_model_name("gemini-2.5-flash")
|
||||
|
||||
def test_alias_target_policy_regression_prevention(self):
|
||||
"""Regression test to ensure aliases and targets are both validated properly.
|
||||
|
||||
|
||||
@@ -106,19 +106,35 @@ class TestAutoMode:
|
||||
|
||||
def test_tool_schema_in_normal_mode(self):
|
||||
"""Test that tool schemas don't require model in normal mode"""
|
||||
# This test uses the default from conftest.py which sets non-auto mode
|
||||
# The conftest.py mock_provider_availability fixture ensures the model is available
|
||||
tool = ChatTool()
|
||||
schema = tool.get_input_schema()
|
||||
# Save original
|
||||
original = os.environ.get("DEFAULT_MODEL", "")
|
||||
|
||||
# Model should not be required when default model is configured
|
||||
assert "model" not in schema["required"]
|
||||
try:
|
||||
# Set to a specific model (not auto mode)
|
||||
os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
|
||||
import config
|
||||
|
||||
# Model field should have simpler description
|
||||
model_schema = schema["properties"]["model"]
|
||||
assert "enum" not in model_schema
|
||||
assert "listmodels" in model_schema["description"]
|
||||
assert "default model" in model_schema["description"].lower()
|
||||
importlib.reload(config)
|
||||
|
||||
tool = ChatTool()
|
||||
schema = tool.get_input_schema()
|
||||
|
||||
# Model should not be required when default model is configured
|
||||
assert "model" not in schema["required"]
|
||||
|
||||
# Model field should have simpler description
|
||||
model_schema = schema["properties"]["model"]
|
||||
assert "enum" not in model_schema
|
||||
assert "listmodels" in model_schema["description"]
|
||||
assert "default model" in model_schema["description"].lower()
|
||||
|
||||
finally:
|
||||
# Restore
|
||||
if original:
|
||||
os.environ["DEFAULT_MODEL"] = original
|
||||
else:
|
||||
os.environ.pop("DEFAULT_MODEL", None)
|
||||
importlib.reload(config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_requires_model_parameter(self):
|
||||
|
||||
203
tests/test_auto_mode_model_listing.py
Normal file
203
tests/test_auto_mode_model_listing.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Tests covering model restriction-aware error messaging in auto mode."""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
import utils.model_restrictions as model_restrictions
|
||||
from providers.gemini import GeminiModelProvider
|
||||
from providers.openai_provider import OpenAIModelProvider
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from providers.xai import XAIModelProvider
|
||||
|
||||
|
||||
def _extract_available_models(message: str) -> list[str]:
|
||||
"""Parse the available model list from the error message."""
|
||||
|
||||
marker = "Available models: "
|
||||
if marker not in message:
|
||||
raise AssertionError(f"Expected '{marker}' in message: {message}")
|
||||
|
||||
start = message.index(marker) + len(marker)
|
||||
end = message.find(". Suggested", start)
|
||||
if end == -1:
|
||||
end = len(message)
|
||||
|
||||
available_segment = message[start:end].strip()
|
||||
if not available_segment:
|
||||
return []
|
||||
|
||||
return [item.strip() for item in available_segment.split(",")]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_registry():
|
||||
"""Ensure registry and restriction service state is isolated."""
|
||||
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
model_restrictions._restriction_service = None
|
||||
yield
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
model_restrictions._restriction_service = None
|
||||
|
||||
|
||||
def _register_core_providers(*, include_xai: bool = False):
|
||||
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||
if include_xai:
|
||||
ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
|
||||
|
||||
|
||||
@pytest.mark.no_mock_provider
|
||||
def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry):
|
||||
"""Error payload should surface only the allowed models for each provider."""
|
||||
|
||||
monkeypatch.setenv("DEFAULT_MODEL", "auto")
|
||||
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
|
||||
monkeypatch.delenv("XAI_API_KEY", raising=False)
|
||||
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
|
||||
try:
|
||||
import dotenv
|
||||
|
||||
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
monkeypatch.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro")
|
||||
monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "gpt-5")
|
||||
monkeypatch.setenv("OPENROUTER_ALLOWED_MODELS", "gpt5nano")
|
||||
monkeypatch.setenv("XAI_ALLOWED_MODELS", "")
|
||||
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
_register_core_providers()
|
||||
|
||||
import server
|
||||
|
||||
importlib.reload(server)
|
||||
|
||||
# Reload may have re-applied .env overrides; enforce our test configuration
|
||||
for key, value in (
|
||||
("DEFAULT_MODEL", "auto"),
|
||||
("GEMINI_API_KEY", "test-gemini"),
|
||||
("OPENAI_API_KEY", "test-openai"),
|
||||
("OPENROUTER_API_KEY", "test-openrouter"),
|
||||
("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro"),
|
||||
("OPENAI_ALLOWED_MODELS", "gpt-5"),
|
||||
("OPENROUTER_ALLOWED_MODELS", "gpt5nano"),
|
||||
("XAI_ALLOWED_MODELS", ""),
|
||||
):
|
||||
monkeypatch.setenv(key, value)
|
||||
|
||||
for var in ("XAI_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
model_restrictions._restriction_service = None
|
||||
server.configure_providers()
|
||||
|
||||
result = asyncio.run(
|
||||
server.handle_call_tool(
|
||||
"chat",
|
||||
{
|
||||
"model": "gpt5mini",
|
||||
"prompt": "Tell me about your strengths",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
payload = json.loads(result[0].text)
|
||||
assert payload["status"] == "error"
|
||||
|
||||
available_models = _extract_available_models(payload["content"])
|
||||
assert set(available_models) == {"gemini-2.5-pro", "gpt-5", "gpt5nano", "openai/gpt-5-nano"}
|
||||
|
||||
|
||||
@pytest.mark.no_mock_provider
|
||||
def test_error_listing_without_restrictions_shows_full_catalog(monkeypatch, reset_registry):
|
||||
"""When no restrictions are set, the full high-capability catalogue should appear."""
|
||||
|
||||
monkeypatch.setenv("DEFAULT_MODEL", "auto")
|
||||
monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
|
||||
monkeypatch.setenv("XAI_API_KEY", "test-xai")
|
||||
monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
|
||||
try:
|
||||
import dotenv
|
||||
|
||||
monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
for var in (
|
||||
"GOOGLE_ALLOWED_MODELS",
|
||||
"OPENAI_ALLOWED_MODELS",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
"XAI_ALLOWED_MODELS",
|
||||
"DIAL_ALLOWED_MODELS",
|
||||
):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
_register_core_providers(include_xai=True)
|
||||
|
||||
import server
|
||||
|
||||
importlib.reload(server)
|
||||
|
||||
for key, value in (
|
||||
("DEFAULT_MODEL", "auto"),
|
||||
("GEMINI_API_KEY", "test-gemini"),
|
||||
("OPENAI_API_KEY", "test-openai"),
|
||||
("OPENROUTER_API_KEY", "test-openrouter"),
|
||||
):
|
||||
monkeypatch.setenv(key, value)
|
||||
|
||||
for var in (
|
||||
"GOOGLE_ALLOWED_MODELS",
|
||||
"OPENAI_ALLOWED_MODELS",
|
||||
"OPENROUTER_ALLOWED_MODELS",
|
||||
"XAI_ALLOWED_MODELS",
|
||||
"DIAL_ALLOWED_MODELS",
|
||||
"CUSTOM_API_URL",
|
||||
"CUSTOM_API_KEY",
|
||||
):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
model_restrictions._restriction_service = None
|
||||
server.configure_providers()
|
||||
|
||||
result = asyncio.run(
|
||||
server.handle_call_tool(
|
||||
"chat",
|
||||
{
|
||||
"model": "dummymodel",
|
||||
"prompt": "Hi there",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
payload = json.loads(result[0].text)
|
||||
assert payload["status"] == "error"
|
||||
|
||||
available_models = _extract_available_models(payload["content"])
|
||||
assert "gemini-2.5-pro" in available_models
|
||||
assert "gpt-5" in available_models
|
||||
assert "grok-4" in available_models
|
||||
assert len(available_models) >= 5
|
||||
@@ -3,6 +3,7 @@ Tests for dynamic context request and collaboration features
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -157,95 +158,120 @@ class TestDynamicContextRequests:
|
||||
@patch("tools.shared.base_tool.BaseTool.get_model_provider")
|
||||
async def test_clarification_with_suggested_action(self, mock_get_provider, analyze_tool):
|
||||
"""Test clarification request with suggested next action"""
|
||||
clarification_json = json.dumps(
|
||||
{
|
||||
"status": "files_required_to_continue",
|
||||
"mandatory_instructions": "I need to see the database configuration to analyze the connection error",
|
||||
"files_needed": ["config/database.yml", "src/db.py"],
|
||||
"suggested_next_action": {
|
||||
"tool": "analyze",
|
||||
"args": {
|
||||
"prompt": "Analyze database connection timeout issue",
|
||||
"relevant_files": [
|
||||
"/config/database.yml",
|
||||
"/src/db.py",
|
||||
"/logs/error.log",
|
||||
],
|
||||
import importlib
|
||||
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Ensure deterministic model configuration for this test regardless of previous suites
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
|
||||
original_default = os.environ.get("DEFAULT_MODEL")
|
||||
|
||||
try:
|
||||
os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
clarification_json = json.dumps(
|
||||
{
|
||||
"status": "files_required_to_continue",
|
||||
"mandatory_instructions": "I need to see the database configuration to analyze the connection error",
|
||||
"files_needed": ["config/database.yml", "src/db.py"],
|
||||
"suggested_next_action": {
|
||||
"tool": "analyze",
|
||||
"args": {
|
||||
"prompt": "Analyze database connection timeout issue",
|
||||
"relevant_files": [
|
||||
"/config/database.yml",
|
||||
"/src/db.py",
|
||||
"/logs/error.log",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
mock_provider = create_mock_provider()
|
||||
mock_provider.get_provider_type.return_value = Mock(value="google")
|
||||
mock_provider.generate_content.return_value = Mock(
|
||||
content=clarification_json, usage={}, model_name="gemini-2.5-flash", metadata={}
|
||||
)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
result = await analyze_tool.execute(
|
||||
{
|
||||
"step": "Analyze database connection timeout issue",
|
||||
"step_number": 1,
|
||||
"total_steps": 1,
|
||||
"next_step_required": False,
|
||||
"findings": "Initial database timeout analysis",
|
||||
"relevant_files": ["/absolute/logs/error.log"],
|
||||
}
|
||||
)
|
||||
result = await analyze_tool.execute(
|
||||
{
|
||||
"step": "Analyze database connection timeout issue",
|
||||
"step_number": 1,
|
||||
"total_steps": 1,
|
||||
"next_step_required": False,
|
||||
"findings": "Initial database timeout analysis",
|
||||
"relevant_files": ["/absolute/logs/error.log"],
|
||||
}
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert len(result) == 1
|
||||
|
||||
response_data = json.loads(result[0].text)
|
||||
response_data = json.loads(result[0].text)
|
||||
|
||||
# Workflow tools should either promote clarification status or handle it in expert analysis
|
||||
if response_data["status"] == "files_required_to_continue":
|
||||
# Clarification was properly promoted to main status
|
||||
# Check if mandatory_instructions is at top level or in content
|
||||
if "mandatory_instructions" in response_data:
|
||||
assert "database configuration" in response_data["mandatory_instructions"]
|
||||
assert "files_needed" in response_data
|
||||
assert "config/database.yml" in response_data["files_needed"]
|
||||
assert "src/db.py" in response_data["files_needed"]
|
||||
elif "content" in response_data:
|
||||
# Parse content JSON for workflow tools
|
||||
try:
|
||||
content_json = json.loads(response_data["content"])
|
||||
assert "mandatory_instructions" in content_json
|
||||
# Workflow tools should either promote clarification status or handle it in expert analysis
|
||||
if response_data["status"] == "files_required_to_continue":
|
||||
# Clarification was properly promoted to main status
|
||||
# Check if mandatory_instructions is at top level or in content
|
||||
if "mandatory_instructions" in response_data:
|
||||
assert "database configuration" in response_data["mandatory_instructions"]
|
||||
assert "files_needed" in response_data
|
||||
assert "config/database.yml" in response_data["files_needed"]
|
||||
assert "src/db.py" in response_data["files_needed"]
|
||||
elif "content" in response_data:
|
||||
# Parse content JSON for workflow tools
|
||||
try:
|
||||
content_json = json.loads(response_data["content"])
|
||||
assert "mandatory_instructions" in content_json
|
||||
assert (
|
||||
"database configuration" in content_json["mandatory_instructions"]
|
||||
or "database" in content_json["mandatory_instructions"]
|
||||
)
|
||||
assert "files_needed" in content_json
|
||||
files_needed_str = str(content_json["files_needed"])
|
||||
assert (
|
||||
"config/database.yml" in files_needed_str
|
||||
or "config" in files_needed_str
|
||||
or "database" in files_needed_str
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Content is not JSON, check if it contains required text
|
||||
content = response_data["content"]
|
||||
assert "database configuration" in content or "config" in content
|
||||
elif response_data["status"] == "calling_expert_analysis":
|
||||
# Clarification may be handled in expert analysis section
|
||||
if "expert_analysis" in response_data:
|
||||
expert_analysis = response_data["expert_analysis"]
|
||||
expert_content = str(expert_analysis)
|
||||
assert (
|
||||
"database configuration" in content_json["mandatory_instructions"]
|
||||
or "database" in content_json["mandatory_instructions"]
|
||||
"database configuration" in expert_content
|
||||
or "config/database.yml" in expert_content
|
||||
or "files_required_to_continue" in expert_content
|
||||
)
|
||||
assert "files_needed" in content_json
|
||||
files_needed_str = str(content_json["files_needed"])
|
||||
assert (
|
||||
"config/database.yml" in files_needed_str
|
||||
or "config" in files_needed_str
|
||||
or "database" in files_needed_str
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Content is not JSON, check if it contains required text
|
||||
content = response_data["content"]
|
||||
assert "database configuration" in content or "config" in content
|
||||
elif response_data["status"] == "calling_expert_analysis":
|
||||
# Clarification may be handled in expert analysis section
|
||||
if "expert_analysis" in response_data:
|
||||
expert_analysis = response_data["expert_analysis"]
|
||||
expert_content = str(expert_analysis)
|
||||
assert (
|
||||
"database configuration" in expert_content
|
||||
or "config/database.yml" in expert_content
|
||||
or "files_required_to_continue" in expert_content
|
||||
)
|
||||
else:
|
||||
# Some other status - ensure it's a valid workflow response
|
||||
assert "step_number" in response_data
|
||||
else:
|
||||
# Some other status - ensure it's a valid workflow response
|
||||
assert "step_number" in response_data
|
||||
|
||||
# Check for suggested next action
|
||||
if "suggested_next_action" in response_data:
|
||||
action = response_data["suggested_next_action"]
|
||||
assert action["tool"] == "analyze"
|
||||
# Check for suggested next action
|
||||
if "suggested_next_action" in response_data:
|
||||
action = response_data["suggested_next_action"]
|
||||
assert action["tool"] == "analyze"
|
||||
finally:
|
||||
if original_default is not None:
|
||||
os.environ["DEFAULT_MODEL"] = original_default
|
||||
else:
|
||||
os.environ.pop("DEFAULT_MODEL", None)
|
||||
|
||||
import config
|
||||
|
||||
importlib.reload(config)
|
||||
ModelProviderRegistry.reset_for_testing()
|
||||
|
||||
def test_tool_output_model_serialization(self):
|
||||
"""Test ToolOutput model serialization"""
|
||||
|
||||
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from providers.base import ModelProvider
|
||||
from providers.registry import ModelProviderRegistry
|
||||
from providers.shared import ProviderType
|
||||
from providers.shared import ModelCapabilities, ProviderType
|
||||
from tools.listmodels import ListModelsTool
|
||||
|
||||
|
||||
@@ -23,10 +23,63 @@ class TestListModelsRestrictions(unittest.TestCase):
|
||||
self.mock_openrouter = MagicMock(spec=ModelProvider)
|
||||
self.mock_openrouter.provider_type = ProviderType.OPENROUTER
|
||||
|
||||
def make_capabilities(
|
||||
canonical: str, friendly: str, *, aliases=None, context: int = 200_000
|
||||
) -> ModelCapabilities:
|
||||
return ModelCapabilities(
|
||||
provider=ProviderType.OPENROUTER,
|
||||
model_name=canonical,
|
||||
friendly_name=friendly,
|
||||
intelligence_score=20,
|
||||
description=friendly,
|
||||
aliases=aliases or [],
|
||||
context_window=context,
|
||||
max_output_tokens=context,
|
||||
supports_extended_thinking=True,
|
||||
)
|
||||
|
||||
opus_caps = make_capabilities(
|
||||
"anthropic/claude-opus-4-20240229",
|
||||
"Claude Opus",
|
||||
aliases=["opus"],
|
||||
)
|
||||
sonnet_caps = make_capabilities(
|
||||
"anthropic/claude-sonnet-4-20240229",
|
||||
"Claude Sonnet",
|
||||
aliases=["sonnet"],
|
||||
)
|
||||
deepseek_caps = make_capabilities(
|
||||
"deepseek/deepseek-r1-0528:free",
|
||||
"DeepSeek R1",
|
||||
aliases=[],
|
||||
)
|
||||
qwen_caps = make_capabilities(
|
||||
"qwen/qwen3-235b-a22b-04-28:free",
|
||||
"Qwen3",
|
||||
aliases=[],
|
||||
)
|
||||
|
||||
self._openrouter_caps_map = {
|
||||
"anthropic/claude-opus-4": opus_caps,
|
||||
"opus": opus_caps,
|
||||
"anthropic/claude-opus-4-20240229": opus_caps,
|
||||
"anthropic/claude-sonnet-4": sonnet_caps,
|
||||
"sonnet": sonnet_caps,
|
||||
"anthropic/claude-sonnet-4-20240229": sonnet_caps,
|
||||
"deepseek/deepseek-r1-0528:free": deepseek_caps,
|
||||
"qwen/qwen3-235b-a22b-04-28:free": qwen_caps,
|
||||
}
|
||||
|
||||
self.mock_openrouter.get_capabilities.side_effect = self._openrouter_caps_map.__getitem__
|
||||
self.mock_openrouter.get_capabilities_by_rank.return_value = []
|
||||
self.mock_openrouter.list_models.return_value = []
|
||||
|
||||
# Create mock Gemini provider for comparison
|
||||
self.mock_gemini = MagicMock(spec=ModelProvider)
|
||||
self.mock_gemini.provider_type = ProviderType.GOOGLE
|
||||
self.mock_gemini.list_models.return_value = ["gemini-2.5-flash", "gemini-2.5-pro"]
|
||||
self.mock_gemini.get_capabilities_by_rank.return_value = []
|
||||
self.mock_gemini.get_capabilities_by_rank.return_value = []
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after tests."""
|
||||
@@ -159,7 +212,7 @@ class TestListModelsRestrictions(unittest.TestCase):
|
||||
for line in lines:
|
||||
if "OpenRouter" in line and "✅" in line:
|
||||
openrouter_section_found = True
|
||||
elif "Available Models" in line and openrouter_section_found:
|
||||
elif ("Models (policy restricted)" in line or "Available Models" in line) and openrouter_section_found:
|
||||
in_openrouter_section = True
|
||||
elif in_openrouter_section:
|
||||
# Check for lines with model names in backticks
|
||||
@@ -179,11 +232,11 @@ class TestListModelsRestrictions(unittest.TestCase):
|
||||
len(openrouter_models), 4, f"Expected 4 models, got {len(openrouter_models)}: {openrouter_models}"
|
||||
)
|
||||
|
||||
# Verify list_models was called with respect_restrictions=True
|
||||
self.mock_openrouter.list_models.assert_called_with(respect_restrictions=True)
|
||||
# Verify we did not fall back to unrestricted listing
|
||||
self.mock_openrouter.list_models.assert_not_called()
|
||||
|
||||
# Check for restriction note
|
||||
self.assertIn("Restricted to models matching:", result)
|
||||
self.assertIn("OpenRouter models restricted by", result)
|
||||
|
||||
@patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key", "GEMINI_API_KEY": "gemini-test-key"}, clear=True)
|
||||
@patch("providers.openrouter_registry.OpenRouterModelRegistry")
|
||||
|
||||
@@ -121,38 +121,59 @@ class TestModelMetadataContinuation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_previous_assistant_turn_defaults(self):
|
||||
"""Test behavior when there's no previous assistant turn."""
|
||||
thread_id = create_thread("chat", {"prompt": "test"})
|
||||
# Save and set DEFAULT_MODEL for test
|
||||
import importlib
|
||||
import os
|
||||
|
||||
# Only add user turns
|
||||
add_turn(thread_id, "user", "First question")
|
||||
add_turn(thread_id, "user", "Second question")
|
||||
original_default = os.environ.get("DEFAULT_MODEL", "")
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
import config
|
||||
import utils.model_context
|
||||
|
||||
arguments = {"continuation_id": thread_id}
|
||||
importlib.reload(config)
|
||||
importlib.reload(utils.model_context)
|
||||
|
||||
# Mock dependencies
|
||||
with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc:
|
||||
mock_calc.return_value = MagicMock(
|
||||
total_tokens=200000,
|
||||
content_tokens=160000,
|
||||
response_tokens=40000,
|
||||
file_tokens=64000,
|
||||
history_tokens=64000,
|
||||
)
|
||||
try:
|
||||
thread_id = create_thread("chat", {"prompt": "test"})
|
||||
|
||||
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
|
||||
mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000)
|
||||
# Only add user turns
|
||||
add_turn(thread_id, "user", "First question")
|
||||
add_turn(thread_id, "user", "Second question")
|
||||
|
||||
# Call the actual function
|
||||
enhanced_args = await reconstruct_thread_context(arguments)
|
||||
arguments = {"continuation_id": thread_id}
|
||||
|
||||
# Should not have set a model
|
||||
assert enhanced_args.get("model") is None
|
||||
# Mock dependencies
|
||||
with patch("utils.model_context.ModelContext.calculate_token_allocation") as mock_calc:
|
||||
mock_calc.return_value = MagicMock(
|
||||
total_tokens=200000,
|
||||
content_tokens=160000,
|
||||
response_tokens=40000,
|
||||
file_tokens=64000,
|
||||
history_tokens=64000,
|
||||
)
|
||||
|
||||
# ModelContext should use DEFAULT_MODEL
|
||||
model_context = ModelContext.from_arguments(enhanced_args)
|
||||
from config import DEFAULT_MODEL
|
||||
with patch("utils.conversation_memory.build_conversation_history") as mock_build:
|
||||
mock_build.return_value = ("=== CONVERSATION HISTORY ===\n", 1000)
|
||||
|
||||
assert model_context.model_name == DEFAULT_MODEL
|
||||
# Call the actual function
|
||||
enhanced_args = await reconstruct_thread_context(arguments)
|
||||
|
||||
# Should not have set a model
|
||||
assert enhanced_args.get("model") is None
|
||||
|
||||
# ModelContext should use DEFAULT_MODEL
|
||||
model_context = ModelContext.from_arguments(enhanced_args)
|
||||
from config import DEFAULT_MODEL
|
||||
|
||||
assert model_context.model_name == DEFAULT_MODEL
|
||||
finally:
|
||||
# Restore original value
|
||||
if original_default:
|
||||
os.environ["DEFAULT_MODEL"] = original_default
|
||||
else:
|
||||
os.environ.pop("DEFAULT_MODEL", None)
|
||||
importlib.reload(config)
|
||||
importlib.reload(utils.model_context)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_model_overrides_previous_turn(self):
|
||||
|
||||
@@ -49,17 +49,32 @@ class TestModelRestrictionService:
|
||||
def test_load_multiple_models_restriction(self):
|
||||
"""Test loading multiple allowed models."""
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||
service = ModelRestrictionService()
|
||||
# Instantiate providers so alias resolution for allow-lists is available
|
||||
openai_provider = OpenAIModelProvider(api_key="test-key")
|
||||
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
# Check OpenAI models
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Check Google models
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||
assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
def fake_get_provider(provider_type, force_new=False):
|
||||
mapping = {
|
||||
ProviderType.OPENAI: openai_provider,
|
||||
ProviderType.GOOGLE: gemini_provider,
|
||||
}
|
||||
return mapping.get(provider_type)
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider):
|
||||
|
||||
service = ModelRestrictionService()
|
||||
|
||||
# Check OpenAI models
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o3")
|
||||
|
||||
# Check Google models
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "flash")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "pro")
|
||||
assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
|
||||
|
||||
def test_case_insensitive_and_whitespace_handling(self):
|
||||
"""Test that model names are case-insensitive and whitespace is trimmed."""
|
||||
@@ -111,13 +126,17 @@ class TestModelRestrictionService:
|
||||
|
||||
def test_shorthand_names_in_restrictions(self):
|
||||
"""Test that shorthand names work in restrictions."""
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini,o3-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||
with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
|
||||
# Instantiate providers so the registry can resolve aliases
|
||||
OpenAIModelProvider(api_key="test-key")
|
||||
GeminiModelProvider(api_key="test-key")
|
||||
|
||||
service = ModelRestrictionService()
|
||||
|
||||
# When providers check models, they pass both resolved and original names
|
||||
# OpenAI: 'mini' shorthand allows o4-mini
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "mini") # How providers actually call it
|
||||
assert not service.is_allowed(ProviderType.OPENAI, "o4-mini") # Direct check without original (for testing)
|
||||
# OpenAI: 'o4mini' shorthand allows o4-mini
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini") # How providers actually call it
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o4-mini") # Canonical should also be allowed
|
||||
|
||||
# OpenAI: o3-mini allowed directly
|
||||
assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
|
||||
@@ -280,19 +299,25 @@ class TestProviderIntegration:
|
||||
|
||||
provider = GeminiModelProvider(api_key="test-key")
|
||||
|
||||
# Test case: Only alias "flash" is allowed, not the full name
|
||||
# If parameters are in wrong order, this test will catch it
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
# Should allow "flash" alias
|
||||
assert provider.validate_model_name("flash")
|
||||
with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):
|
||||
|
||||
# Should allow getting capabilities for "flash"
|
||||
capabilities = provider.get_capabilities("flash")
|
||||
assert capabilities.model_name == "gemini-2.5-flash"
|
||||
# Test case: Only alias "flash" is allowed, not the full name
|
||||
# If parameters are in wrong order, this test will catch it
|
||||
|
||||
# Test the edge case: Try to use full model name when only alias is allowed
|
||||
# This should NOT be allowed - only the alias "flash" is in the restriction list
|
||||
assert not provider.validate_model_name("gemini-2.5-flash")
|
||||
# Should allow "flash" alias
|
||||
assert provider.validate_model_name("flash")
|
||||
|
||||
# Should allow getting capabilities for "flash"
|
||||
capabilities = provider.get_capabilities("flash")
|
||||
assert capabilities.model_name == "gemini-2.5-flash"
|
||||
|
||||
# Canonical form should also be allowed now that alias is on the allowlist
|
||||
assert provider.validate_model_name("gemini-2.5-flash")
|
||||
# Unrelated models remain blocked
|
||||
assert not provider.validate_model_name("pro")
|
||||
assert not provider.validate_model_name("gemini-2.5-pro")
|
||||
|
||||
@patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
|
||||
def test_gemini_parameter_order_edge_case_full_name_only(self):
|
||||
@@ -570,17 +595,27 @@ class TestShorthandRestrictions:
|
||||
|
||||
# Test OpenAI provider
|
||||
openai_provider = OpenAIModelProvider(api_key="test-key")
|
||||
assert openai_provider.validate_model_name("mini") # Should work with shorthand
|
||||
# When restricting to "mini", you can't use "o4-mini" directly - this is correct behavior
|
||||
assert not openai_provider.validate_model_name("o4-mini") # Not allowed - only shorthand is allowed
|
||||
assert not openai_provider.validate_model_name("o3-mini") # Not allowed
|
||||
|
||||
# Test Gemini provider
|
||||
gemini_provider = GeminiModelProvider(api_key="test-key")
|
||||
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
|
||||
# Same for Gemini - if you restrict to "flash", you can't use the full name
|
||||
assert not gemini_provider.validate_model_name("gemini-2.5-flash") # Not allowed
|
||||
assert not gemini_provider.validate_model_name("pro") # Not allowed
|
||||
|
||||
from providers.registry import ModelProviderRegistry
|
||||
|
||||
def registry_side_effect(provider_type, force_new=False):
|
||||
mapping = {
|
||||
ProviderType.OPENAI: openai_provider,
|
||||
ProviderType.GOOGLE: gemini_provider,
|
||||
}
|
||||
return mapping.get(provider_type)
|
||||
|
||||
with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect):
|
||||
assert openai_provider.validate_model_name("mini") # Should work with shorthand
|
||||
assert openai_provider.validate_model_name("gpt-5-mini") # Canonical resolved from shorthand
|
||||
assert not openai_provider.validate_model_name("o4-mini") # Unrelated model still blocked
|
||||
assert not openai_provider.validate_model_name("o3-mini")
|
||||
|
||||
# Test Gemini provider
|
||||
assert gemini_provider.validate_model_name("flash") # Should work with shorthand
|
||||
assert gemini_provider.validate_model_name("gemini-2.5-flash") # Canonical allowed
|
||||
assert not gemini_provider.validate_model_name("pro") # Not allowed
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
|
||||
def test_multiple_shorthands_for_same_model(self):
|
||||
@@ -596,9 +631,9 @@ class TestShorthandRestrictions:
|
||||
assert openai_provider.validate_model_name("mini") # mini -> o4-mini
|
||||
assert openai_provider.validate_model_name("o3mini") # o3mini -> o3-mini
|
||||
|
||||
# Resolved names work only if explicitly allowed
|
||||
# Resolved names should be allowed when their shorthands are present
|
||||
assert openai_provider.validate_model_name("o4-mini") # Explicitly allowed
|
||||
assert not openai_provider.validate_model_name("o3-mini") # Not explicitly allowed, only shorthand
|
||||
assert openai_provider.validate_model_name("o3-mini") # Allowed via shorthand
|
||||
|
||||
# Other models should not work
|
||||
assert not openai_provider.validate_model_name("o3")
|
||||
|
||||
@@ -260,9 +260,10 @@ class TestOpenRouterAutoMode:
|
||||
os.environ["DEFAULT_MODEL"] = "auto"
|
||||
|
||||
mock_provider_class = Mock()
|
||||
mock_provider_instance = Mock(spec=["get_provider_type", "list_models"])
|
||||
mock_provider_instance = Mock(spec=["get_provider_type", "list_models", "get_all_model_capabilities"])
|
||||
mock_provider_instance.get_provider_type.return_value = ProviderType.OPENROUTER
|
||||
mock_provider_instance.list_models.return_value = []
|
||||
mock_provider_instance.get_all_model_capabilities.return_value = {}
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, mock_provider_class)
|
||||
|
||||
@@ -293,13 +293,7 @@ class TestOpenRouterAliasRestrictions:
|
||||
# o3 -> openai/o3
|
||||
# gpt4.1 -> should not exist (expected to be filtered out)
|
||||
|
||||
expected_models = {
|
||||
"openai/o3-mini",
|
||||
"google/gemini-2.5-pro",
|
||||
"google/gemini-2.5-flash",
|
||||
"openai/o4-mini",
|
||||
"openai/o3",
|
||||
}
|
||||
expected_models = {"o3-mini", "pro", "flash", "o4-mini", "o3"}
|
||||
|
||||
available_model_names = set(available_models.keys())
|
||||
|
||||
@@ -355,9 +349,11 @@ class TestOpenRouterAliasRestrictions:
|
||||
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||
|
||||
expected_models = {
|
||||
"openai/o3-mini", # from alias
|
||||
"o3-mini", # alias
|
||||
"openai/o3-mini", # canonical
|
||||
"anthropic/claude-opus-4.1", # full name
|
||||
"google/gemini-2.5-flash", # from alias
|
||||
"flash", # alias
|
||||
"google/gemini-2.5-flash", # canonical
|
||||
}
|
||||
|
||||
available_model_names = set(available_models.keys())
|
||||
|
||||
Reference in New Issue
Block a user