feat!: breaking change - OpenRouter models are now read from conf/openrouter_models.json while Custom / Self-hosted models are read from conf/custom_models.json
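For illustration, a minimal model entry in the shape both files accept, taken directly from the fixture data used by the tests in this commit (only model_name, aliases, context_window and max_output_tokens are shown; the full schema supports additional capability fields):

    {
      "models": [
        {
          "model_name": "test/model",
          "aliases": ["testalias"],
          "context_window": 1024,
          "max_output_tokens": 512
        }
      ]
    }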
feat: Azure OpenAI / Azure AI Foundry support. Models should be defined in conf/azure_models.json (or a custom path). See .env.example for environment variables, or see the README. https://github.com/BeehiveInnovations/zen-mcp-server/issues/265
feat: OpenRouter / Custom Models / Azure can each use a custom config path now (see .env.example)
refactor: the model registry class is now abstract; the OpenRouter / Custom Provider / Azure OpenAI registries subclass it
refactor: breaking change: the `is_custom` property has been removed from model_capabilities.py (and thus custom_models.json), since each provider's models are now read from a separate configuration file
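A sketch of what a conf/azure_models.json entry might look like, inferred from the registry test added in this commit (test_registry_configuration_merges_capabilities); the exact field set here is illustrative, so defer to .env.example and the README for the authoritative schema:

    {
      "models": [
        {
          "model_name": "gpt-4o",
          "deployment": "prod-gpt4o",
          "friendly_name": "Azure GPT-4o",
          "context_window": 500000,
          "max_output_tokens": 128000
        }
      ]
    }

The environment variables exercised by the tests below are AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_ALLOWED_MODELS and AZURE_MODELS_CONFIG_PATH, alongside OPENROUTER_MODELS_CONFIG_PATH and CUSTOM_MODELS_CONFIG_PATH for the other registries.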
@@ -64,6 +64,14 @@ def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry):
     monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
     monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
     monkeypatch.delenv("XAI_API_KEY", raising=False)
+    # Ensure Azure provider stays disabled regardless of developer workstation env
+    for azure_var in (
+        "AZURE_OPENAI_API_KEY",
+        "AZURE_OPENAI_ENDPOINT",
+        "AZURE_OPENAI_ALLOWED_MODELS",
+        "AZURE_MODELS_CONFIG_PATH",
+    ):
+        monkeypatch.delenv(azure_var, raising=False)
     monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
     env_config.reload_env({"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
     try:
@@ -103,6 +111,13 @@ def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry):
 
     for var in ("XAI_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"):
         monkeypatch.delenv(var, raising=False)
+    for azure_var in (
+        "AZURE_OPENAI_API_KEY",
+        "AZURE_OPENAI_ENDPOINT",
+        "AZURE_OPENAI_ALLOWED_MODELS",
+        "AZURE_MODELS_CONFIG_PATH",
+    ):
+        monkeypatch.delenv(azure_var, raising=False)
 
     ModelProviderRegistry.reset_for_testing()
     model_restrictions._restriction_service = None
@@ -136,6 +151,13 @@ def test_error_listing_without_restrictions_shows_full_catalog(monkeypatch, reset_registry):
     monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
     monkeypatch.setenv("XAI_API_KEY", "test-xai")
     monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
+    for azure_var in (
+        "AZURE_OPENAI_API_KEY",
+        "AZURE_OPENAI_ENDPOINT",
+        "AZURE_OPENAI_ALLOWED_MODELS",
+        "AZURE_MODELS_CONFIG_PATH",
+    ):
+        monkeypatch.delenv(azure_var, raising=False)
     env_config.reload_env({"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
     try:
         import dotenv
tests/test_azure_openai_provider.py (new file, 145 lines)
@@ -0,0 +1,145 @@
import sys
import types

import pytest

if "openai" not in sys.modules:  # pragma: no cover - test shim for optional dependency
    stub = types.ModuleType("openai")
    stub.AzureOpenAI = object  # Replaced with a mock inside tests
    sys.modules["openai"] = stub

from providers.azure_openai import AzureOpenAIProvider
from providers.shared import ModelCapabilities, ProviderType


class _DummyResponse:
    def __init__(self):
        self.choices = [
            types.SimpleNamespace(
                message=types.SimpleNamespace(content="hello"),
                finish_reason="stop",
            )
        ]
        self.model = "prod-gpt4o"
        self.id = "resp-123"
        self.created = 0
        self.usage = types.SimpleNamespace(
            prompt_tokens=5,
            completion_tokens=3,
            total_tokens=8,
        )


@pytest.fixture
def dummy_azure_client(monkeypatch):
    captured = {}

    class _DummyAzureClient:
        def __init__(self, **kwargs):
            captured["client_kwargs"] = kwargs
            self.chat = types.SimpleNamespace(completions=types.SimpleNamespace(create=self._create_completion))
            self.responses = types.SimpleNamespace(create=self._create_response)

        def _create_completion(self, **kwargs):
            captured["request_kwargs"] = kwargs
            return _DummyResponse()

        def _create_response(self, **kwargs):
            captured["responses_kwargs"] = kwargs
            return _DummyResponse()

    monkeypatch.delenv("AZURE_OPENAI_ALLOWED_MODELS", raising=False)
    monkeypatch.setattr("providers.azure_openai.AzureOpenAI", _DummyAzureClient)
    return captured


def test_generate_content_uses_deployment_mapping(dummy_azure_client):
    provider = AzureOpenAIProvider(
        api_key="key",
        azure_endpoint="https://example.openai.azure.com/",
        deployments={"gpt-4o": "prod-gpt4o"},
    )

    result = provider.generate_content("hello", "gpt-4o")

    assert dummy_azure_client["request_kwargs"]["model"] == "prod-gpt4o"
    assert result.model_name == "gpt-4o"
    assert result.provider == ProviderType.AZURE
    assert provider.validate_model_name("prod-gpt4o")


def test_generate_content_accepts_deployment_alias(dummy_azure_client):
    provider = AzureOpenAIProvider(
        api_key="key",
        azure_endpoint="https://example.openai.azure.com/",
        deployments={"gpt-4o-mini": "mini-deployment"},
    )

    # Calling with the deployment alias should still resolve properly.
    result = provider.generate_content("hi", "mini-deployment")

    assert dummy_azure_client["request_kwargs"]["model"] == "mini-deployment"
    assert result.model_name == "gpt-4o-mini"


def test_client_initialization_uses_endpoint_and_version(dummy_azure_client):
    provider = AzureOpenAIProvider(
        api_key="key",
        azure_endpoint="https://example.openai.azure.com/",
        api_version="2024-03-15-preview",
        deployments={"gpt-4o": "prod"},
    )

    _ = provider.client

    assert dummy_azure_client["client_kwargs"]["azure_endpoint"] == "https://example.openai.azure.com"
    assert dummy_azure_client["client_kwargs"]["api_version"] == "2024-03-15-preview"


def test_deployment_overrides_capabilities(dummy_azure_client):
    provider = AzureOpenAIProvider(
        api_key="key",
        azure_endpoint="https://example.openai.azure.com/",
        deployments={
            "gpt-4o": {
                "deployment": "prod-gpt4o",
                "friendly_name": "Azure GPT-4o EU",
                "intelligence_score": 19,
                "supports_temperature": False,
                "temperature_constraint": "fixed",
            }
        },
    )

    caps = provider.get_capabilities("gpt-4o")
    assert caps.friendly_name == "Azure GPT-4o EU"
    assert caps.intelligence_score == 19
    assert not caps.supports_temperature


def test_registry_configuration_merges_capabilities(dummy_azure_client, monkeypatch):
    def fake_registry_entries(self):
        capability = ModelCapabilities(
            provider=ProviderType.AZURE,
            model_name="gpt-4o",
            friendly_name="Azure GPT-4o Registry",
            context_window=500_000,
            max_output_tokens=128_000,
        )
        return {"gpt-4o": {"deployment": "registry-deployment", "capability": capability}}

    monkeypatch.setattr(AzureOpenAIProvider, "_load_registry_entries", fake_registry_entries)

    provider = AzureOpenAIProvider(
        api_key="key",
        azure_endpoint="https://example.openai.azure.com/",
    )

    # Capability should come from registry
    caps = provider.get_capabilities("gpt-4o")
    assert caps.friendly_name == "Azure GPT-4o Registry"
    assert caps.context_window == 500_000

    # API call should use deployment defined in registry
    provider.generate_content("hello", "gpt-4o")
    assert dummy_azure_client["request_kwargs"]["model"] == "registry-deployment"
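As a usage note: the tests above exercise two shapes for the provider's deployments argument, a plain model-to-deployment name mapping and a dict form that also overrides capability fields. A minimal sketch, with the API key, endpoint and deployment names as placeholders:

    from providers.azure_openai import AzureOpenAIProvider

    # Simple form: map the canonical model name to the Azure deployment name.
    provider = AzureOpenAIProvider(
        api_key="...",
        azure_endpoint="https://example.openai.azure.com/",
        deployments={"gpt-4o": "prod-gpt4o"},
    )

    # Dict form: the same mapping plus capability overrides, as in
    # test_deployment_overrides_capabilities above.
    provider = AzureOpenAIProvider(
        api_key="...",
        azure_endpoint="https://example.openai.azure.com/",
        deployments={
            "gpt-4o": {
                "deployment": "prod-gpt4o",
                "friendly_name": "Azure GPT-4o EU",
                "supports_temperature": False,
                "temperature_constraint": "fixed",
            }
        },
    )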
@@ -34,8 +34,7 @@ class TestCustomOpenAITemperatureParameterFix:
         config_models = [
             {
                 "model_name": "gpt-5-2025-08-07",
-                "provider": "ProviderType.OPENAI",
-                "is_custom": True,
+                "provider": "openai",
                 "context_window": 400000,
                 "max_output_tokens": 128000,
                 "supports_extended_thinking": True,
@@ -62,9 +62,9 @@ class TestCustomProvider:
             with pytest.raises(ValueError):
                 provider.get_capabilities("o3")
 
-            # Test with a custom model (is_custom=true)
+            # Test with a custom model from the local registry
             capabilities = provider.get_capabilities("local-llama")
-            assert capabilities.provider == ProviderType.CUSTOM  # local-llama has is_custom=true
+            assert capabilities.provider == ProviderType.CUSTOM
             assert capabilities.context_window > 0
 
         finally:
@@ -181,7 +181,7 @@ class TestModelEnumeration:
         # Configure environment with OpenRouter access only
         self._setup_environment({"OPENROUTER_API_KEY": "test-openrouter-key"})
 
-        # Create a temporary custom model config with a free variant
+        # Create a temporary OpenRouter model config with a free variant
         custom_config = {
             "models": [
                 {
@@ -199,9 +199,9 @@ class TestModelEnumeration:
             ]
         }
 
-        config_path = tmp_path / "custom_models.json"
+        config_path = tmp_path / "openrouter_models.json"
         config_path.write_text(json.dumps(custom_config), encoding="utf-8")
-        monkeypatch.setenv("CUSTOM_MODELS_CONFIG_PATH", str(config_path))
+        monkeypatch.setenv("OPENROUTER_MODELS_CONFIG_PATH", str(config_path))
 
         # Reset cached registries so the temporary config is loaded
         from tools.shared.base_tool import BaseTool
@@ -366,8 +366,8 @@ class TestCustomProviderOpenRouterRestrictions:
         assert not provider.validate_model_name("sonnet")
         assert not provider.validate_model_name("haiku")
 
-        # Should still validate custom models (is_custom=true) regardless of restrictions
-        assert provider.validate_model_name("local-llama")  # This has is_custom=true
+        # Should still validate custom models defined in conf/custom_models.json
+        assert provider.validate_model_name("local-llama")
 
     @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus", "OPENROUTER_API_KEY": "test-key"})
     def test_custom_provider_openrouter_capabilities_restrictions(self):
@@ -389,7 +389,7 @@ class TestCustomProviderOpenRouterRestrictions:
         with pytest.raises(ValueError):
             provider.get_capabilities("haiku")
 
-        # Should still work for custom models (is_custom=true)
+        # Should still work for custom models
         capabilities = provider.get_capabilities("local-llama")
         assert capabilities.provider == ProviderType.CUSTOM
@@ -172,7 +172,7 @@ class TestOpenRouterAutoMode:
         def mock_resolve(model_name):
             if model_name in model_names:
                 mock_config = Mock()
-                mock_config.is_custom = False
                 mock_config.provider = ProviderType.OPENROUTER
                 mock_config.aliases = []  # Empty list of aliases
+                mock_config.get_effective_capability_rank = Mock(return_value=50)  # Add ranking method
                 return mock_config
@@ -3,6 +3,7 @@
 import json
 import os
 import tempfile
+from unittest.mock import patch
 
 import pytest
@@ -49,7 +50,7 @@ class TestOpenRouterModelRegistry:
             os.unlink(temp_path)
 
     def test_environment_variable_override(self):
-        """Test OPENROUTER_MODELS_PATH environment variable."""
+        """Test OPENROUTER_MODELS_CONFIG_PATH environment variable."""
         # Create custom config
         config_data = {
             "models": [
@@ -63,8 +64,8 @@ class TestOpenRouterModelRegistry:
 
         try:
             # Set environment variable
-            original_env = os.environ.get("CUSTOM_MODELS_CONFIG_PATH")
-            os.environ["CUSTOM_MODELS_CONFIG_PATH"] = temp_path
+            original_env = os.environ.get("OPENROUTER_MODELS_CONFIG_PATH")
+            os.environ["OPENROUTER_MODELS_CONFIG_PATH"] = temp_path
 
             # Create registry without explicit path
             registry = OpenRouterModelRegistry()
@@ -76,9 +77,9 @@ class TestOpenRouterModelRegistry:
         finally:
             # Restore environment
             if original_env is not None:
-                os.environ["CUSTOM_MODELS_CONFIG_PATH"] = original_env
+                os.environ["OPENROUTER_MODELS_CONFIG_PATH"] = original_env
             else:
-                del os.environ["CUSTOM_MODELS_CONFIG_PATH"]
+                del os.environ["OPENROUTER_MODELS_CONFIG_PATH"]
             os.unlink(temp_path)
 
     def test_alias_resolution(self):
@@ -161,7 +162,7 @@ class TestOpenRouterModelRegistry:
             os.unlink(temp_path)
 
     def test_backwards_compatibility_max_tokens(self):
-        """Test that old max_tokens field is no longer supported (should result in empty registry)."""
+        """Test that legacy max_tokens field maps to max_output_tokens."""
         config_data = {
             "models": [
                 {
@@ -178,19 +179,17 @@ class TestOpenRouterModelRegistry:
             temp_path = f.name
 
         try:
-            # Should gracefully handle the error and result in empty registry
-            registry = OpenRouterModelRegistry(config_path=temp_path)
-            # Registry should be empty due to config error
-            assert len(registry.list_models()) == 0
-            assert len(registry.list_aliases()) == 0
-            assert registry.resolve("old") is None
+            with patch.dict("os.environ", {}, clear=True):
+                with pytest.raises(ValueError, match="max_output_tokens"):
+                    OpenRouterModelRegistry(config_path=temp_path)
         finally:
             os.unlink(temp_path)
 
     def test_missing_config_file(self):
         """Test behavior with missing config file."""
         # Use a non-existent path
-        registry = OpenRouterModelRegistry(config_path="/non/existent/path.json")
+        with patch.dict("os.environ", {}, clear=True):
+            registry = OpenRouterModelRegistry(config_path="/non/existent/path.json")
 
         # Should initialize with empty maps
         assert len(registry.list_models()) == 0
@@ -1,5 +1,7 @@
 """Tests for uvx path resolution functionality."""
 
+import json
+import tempfile
 from pathlib import Path
 from unittest.mock import patch
@@ -18,8 +20,8 @@ class TestUvxPathResolution:
     def test_config_path_resolution(self):
         """Test that the config path resolution finds the config file in multiple locations."""
         # Check that the config file exists in the development location
-        config_file = Path(__file__).parent.parent / "conf" / "custom_models.json"
-        assert config_file.exists(), "Config file should exist in conf/custom_models.json"
+        config_file = Path(__file__).parent.parent / "conf" / "openrouter_models.json"
+        assert config_file.exists(), "Config file should exist in conf/openrouter_models.json"
 
         # Test that a registry can find and use the config
         registry = OpenRouterModelRegistry()
@@ -34,7 +36,7 @@ class TestUvxPathResolution:
 
     def test_explicit_config_path_override(self):
         """Test that explicit config path works correctly."""
-        config_path = Path(__file__).parent.parent / "conf" / "custom_models.json"
+        config_path = Path(__file__).parent.parent / "conf" / "openrouter_models.json"
 
         registry = OpenRouterModelRegistry(config_path=str(config_path))
@@ -44,41 +46,62 @@ class TestUvxPathResolution:
 
     def test_environment_variable_override(self):
         """Test that CUSTOM_MODELS_CONFIG_PATH environment variable works."""
-        config_path = Path(__file__).parent.parent / "conf" / "custom_models.json"
+        config_path = Path(__file__).parent.parent / "conf" / "openrouter_models.json"
 
-        with patch.dict("os.environ", {"CUSTOM_MODELS_CONFIG_PATH": str(config_path)}):
+        with patch.dict("os.environ", {"OPENROUTER_MODELS_CONFIG_PATH": str(config_path)}):
             registry = OpenRouterModelRegistry()
 
             # Should use environment path
             assert registry.config_path == config_path
             assert len(registry.list_models()) > 0
 
-    @patch("providers.openrouter_registry.importlib.resources.files")
-    @patch("pathlib.Path.exists")
-    def test_multiple_path_fallback(self, mock_exists, mock_files):
-        """Test that multiple path resolution works for different deployment scenarios."""
-        # Make resources loading fail to trigger file system fallback
+    @patch("providers.model_registry_base.importlib.resources.files")
+    def test_multiple_path_fallback(self, mock_files):
+        """Test that file-system fallback works when resource loading fails."""
         mock_files.side_effect = Exception("Resource loading failed")
 
-        # Simulate dev path failing, and working directory path succeeding
-        # The third `True` is for the check within `reload()`
-        mock_exists.side_effect = [False, True, True]
+        with tempfile.TemporaryDirectory() as tmpdir:
+            temp_dir = Path(tmpdir)
+            conf_dir = temp_dir / "conf"
+            conf_dir.mkdir(parents=True, exist_ok=True)
+            config_path = conf_dir / "openrouter_models.json"
+            config_path.write_text(
+                json.dumps(
+                    {
+                        "models": [
+                            {
+                                "model_name": "test/model",
+                                "aliases": ["testalias"],
+                                "context_window": 1024,
+                                "max_output_tokens": 512,
+                            }
+                        ]
+                    },
+                    indent=2,
+                )
+            )
 
-        registry = OpenRouterModelRegistry()
+            original_exists = Path.exists
 
-        # Should have fallen back to file system mode
-        assert not registry.use_resources, "Should fall back to file system when resources fail"
+            def fake_exists(path_self):
+                if str(path_self).endswith("conf/openrouter_models.json") and path_self != config_path:
+                    return False
+                if path_self == config_path:
+                    return True
+                return original_exists(path_self)
 
-        # Assert that the registry fell back to the second potential path
-        assert registry.config_path == Path.cwd() / "conf" / "custom_models.json"
+            with patch("pathlib.Path.cwd", return_value=temp_dir), patch("pathlib.Path.exists", fake_exists):
+                registry = OpenRouterModelRegistry()
 
-        # Should load models successfully
-        assert len(registry.list_models()) > 0
+                assert not registry.use_resources
+                assert registry.config_path == config_path
+                assert "test/model" in registry.list_models()
 
     def test_missing_config_handling(self):
         """Test behavior when config file is missing."""
         # Use a non-existent path
-        registry = OpenRouterModelRegistry(config_path="/nonexistent/path/config.json")
+        with patch.dict("os.environ", {}, clear=True):
+            registry = OpenRouterModelRegistry(config_path="/nonexistent/path/config.json")
 
         # Should gracefully handle missing config
         assert len(registry.list_models()) == 0
@@ -166,8 +166,10 @@ class TestXAIProvider:
         """Test model restrictions functionality."""
         # Clear cached restriction service
         import utils.model_restrictions
+        from providers.registry import ModelProviderRegistry
 
         utils.model_restrictions._restriction_service = None
+        ModelProviderRegistry.reset_for_testing()
 
         provider = XAIModelProvider("test-key")
@@ -187,8 +189,10 @@ class TestXAIProvider:
         """Test multiple models in restrictions."""
         # Clear cached restriction service
         import utils.model_restrictions
+        from providers.registry import ModelProviderRegistry
 
         utils.model_restrictions._restriction_service = None
+        ModelProviderRegistry.reset_for_testing()
 
         provider = XAIModelProvider("test-key")