WIP
- OpenRouter model configuration registry - Model definition file for users to be able to control - Additional tests - Update instructions
This commit is contained in:
@@ -97,7 +97,8 @@ class TestAutoMode:
|
||||
# Model field should have simpler description
|
||||
model_schema = schema["properties"]["model"]
|
||||
assert "enum" not in model_schema
|
||||
assert "Available:" in model_schema["description"]
|
||||
assert "Native models:" in model_schema["description"]
|
||||
assert "Defaults to" in model_schema["description"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_requires_model_parameter(self):
|
||||
@@ -180,8 +181,9 @@ class TestAutoMode:
|
||||
|
||||
schema = tool.get_model_field_schema()
|
||||
assert "enum" not in schema
|
||||
assert "Available:" in schema["description"]
|
||||
assert "Native models:" in schema["description"]
|
||||
assert "'pro'" in schema["description"]
|
||||
assert "Defaults to" in schema["description"]
|
||||
|
||||
finally:
|
||||
# Restore
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Tests for OpenRouter provider."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.openrouter import OpenRouterProvider
|
||||
@@ -11,65 +10,64 @@ from providers.registry import ModelProviderRegistry
|
||||
|
||||
class TestOpenRouterProvider:
|
||||
"""Test cases for OpenRouter provider."""
|
||||
|
||||
|
||||
def test_provider_initialization(self):
|
||||
"""Test OpenRouter provider initialization."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.base_url == "https://openrouter.ai/api/v1"
|
||||
assert provider.FRIENDLY_NAME == "OpenRouter"
|
||||
|
||||
|
||||
def test_custom_headers(self):
|
||||
"""Test OpenRouter custom headers."""
|
||||
# Test default headers
|
||||
assert "HTTP-Referer" in OpenRouterProvider.DEFAULT_HEADERS
|
||||
assert "X-Title" in OpenRouterProvider.DEFAULT_HEADERS
|
||||
|
||||
|
||||
# Test with environment variables
|
||||
with patch.dict(os.environ, {
|
||||
"OPENROUTER_REFERER": "https://myapp.com",
|
||||
"OPENROUTER_TITLE": "My App"
|
||||
}):
|
||||
with patch.dict(os.environ, {"OPENROUTER_REFERER": "https://myapp.com", "OPENROUTER_TITLE": "My App"}):
|
||||
from importlib import reload
|
||||
|
||||
import providers.openrouter
|
||||
|
||||
reload(providers.openrouter)
|
||||
|
||||
|
||||
provider = providers.openrouter.OpenRouterProvider(api_key="test-key")
|
||||
assert provider.DEFAULT_HEADERS["HTTP-Referer"] == "https://myapp.com"
|
||||
assert provider.DEFAULT_HEADERS["X-Title"] == "My App"
|
||||
|
||||
|
||||
def test_model_validation(self):
|
||||
"""Test model validation."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
|
||||
# Should accept any model - OpenRouter handles validation
|
||||
assert provider.validate_model_name("gpt-4") is True
|
||||
assert provider.validate_model_name("claude-3-opus") is True
|
||||
assert provider.validate_model_name("any-model-name") is True
|
||||
assert provider.validate_model_name("GPT-4") is True
|
||||
assert provider.validate_model_name("unknown-model") is True
|
||||
|
||||
|
||||
def test_get_capabilities(self):
|
||||
"""Test capability generation."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
|
||||
# Test with a model in the registry (using alias)
|
||||
caps = provider.get_capabilities("gpt4o")
|
||||
assert caps.provider == ProviderType.OPENROUTER
|
||||
assert caps.model_name == "openai/gpt-4o" # Resolved name
|
||||
assert caps.friendly_name == "OpenRouter"
|
||||
|
||||
|
||||
# Test with a model not in registry - should get generic capabilities
|
||||
caps = provider.get_capabilities("unknown-model")
|
||||
assert caps.provider == ProviderType.OPENROUTER
|
||||
assert caps.model_name == "unknown-model"
|
||||
assert caps.max_tokens == 32_768 # Safe default
|
||||
assert hasattr(caps, '_is_generic') and caps._is_generic is True
|
||||
|
||||
assert hasattr(caps, "_is_generic") and caps._is_generic is True
|
||||
|
||||
def test_model_alias_resolution(self):
|
||||
"""Test model alias resolution."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
|
||||
# Test alias resolution
|
||||
assert provider._resolve_model_name("opus") == "anthropic/claude-3-opus"
|
||||
assert provider._resolve_model_name("sonnet") == "anthropic/claude-3-sonnet"
|
||||
@@ -79,30 +77,30 @@ class TestOpenRouterProvider:
|
||||
assert provider._resolve_model_name("mistral") == "mistral/mistral-large"
|
||||
assert provider._resolve_model_name("deepseek") == "deepseek/deepseek-coder"
|
||||
assert provider._resolve_model_name("coder") == "deepseek/deepseek-coder"
|
||||
|
||||
|
||||
# Test case-insensitive
|
||||
assert provider._resolve_model_name("OPUS") == "anthropic/claude-3-opus"
|
||||
assert provider._resolve_model_name("GPT4O") == "openai/gpt-4o"
|
||||
assert provider._resolve_model_name("Mistral") == "mistral/mistral-large"
|
||||
assert provider._resolve_model_name("CLAUDE") == "anthropic/claude-3-sonnet"
|
||||
|
||||
|
||||
# Test direct model names (should pass through unchanged)
|
||||
assert provider._resolve_model_name("anthropic/claude-3-opus") == "anthropic/claude-3-opus"
|
||||
assert provider._resolve_model_name("openai/gpt-4o") == "openai/gpt-4o"
|
||||
|
||||
|
||||
# Test unknown models pass through
|
||||
assert provider._resolve_model_name("unknown-model") == "unknown-model"
|
||||
assert provider._resolve_model_name("custom/model-v2") == "custom/model-v2"
|
||||
|
||||
|
||||
def test_openrouter_registration(self):
|
||||
"""Test OpenRouter can be registered and retrieved."""
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
# Clean up any existing registration
|
||||
ModelProviderRegistry.unregister_provider(ProviderType.OPENROUTER)
|
||||
|
||||
|
||||
# Register the provider
|
||||
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||
|
||||
|
||||
# Retrieve and verify
|
||||
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||
assert provider is not None
|
||||
@@ -111,53 +109,53 @@ class TestOpenRouterProvider:
|
||||
|
||||
class TestOpenRouterRegistry:
|
||||
"""Test cases for OpenRouter model registry."""
|
||||
|
||||
|
||||
def test_registry_loading(self):
|
||||
"""Test registry loads models from config."""
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||
|
||||
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
|
||||
# Should have loaded models
|
||||
models = registry.list_models()
|
||||
assert len(models) > 0
|
||||
assert "anthropic/claude-3-opus" in models
|
||||
assert "openai/gpt-4o" in models
|
||||
|
||||
|
||||
# Should have loaded aliases
|
||||
aliases = registry.list_aliases()
|
||||
assert len(aliases) > 0
|
||||
assert "opus" in aliases
|
||||
assert "gpt4o" in aliases
|
||||
assert "claude" in aliases
|
||||
|
||||
|
||||
def test_registry_capabilities(self):
|
||||
"""Test registry provides correct capabilities."""
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||
|
||||
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
|
||||
# Test known model
|
||||
caps = registry.get_capabilities("opus")
|
||||
assert caps is not None
|
||||
assert caps.model_name == "anthropic/claude-3-opus"
|
||||
assert caps.max_tokens == 200000 # Claude's context window
|
||||
|
||||
|
||||
# Test using full model name
|
||||
caps = registry.get_capabilities("anthropic/claude-3-opus")
|
||||
assert caps is not None
|
||||
assert caps.model_name == "anthropic/claude-3-opus"
|
||||
|
||||
|
||||
# Test unknown model
|
||||
caps = registry.get_capabilities("non-existent-model")
|
||||
assert caps is None
|
||||
|
||||
|
||||
def test_multiple_aliases_same_model(self):
|
||||
"""Test multiple aliases pointing to same model."""
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||
|
||||
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
|
||||
# All these should resolve to Claude Sonnet
|
||||
sonnet_aliases = ["sonnet", "claude", "claude-sonnet", "claude3-sonnet"]
|
||||
for alias in sonnet_aliases:
|
||||
@@ -166,48 +164,34 @@ class TestOpenRouterRegistry:
|
||||
assert config.model_name == "anthropic/claude-3-sonnet"
|
||||
|
||||
|
||||
class TestOpenRouterSSRFProtection:
|
||||
"""Test SSRF protection for OpenRouter."""
|
||||
|
||||
def test_url_validation_rejects_private_ips(self):
|
||||
"""Test that private IPs are rejected."""
|
||||
class TestOpenRouterFunctionality:
|
||||
"""Test OpenRouter-specific functionality."""
|
||||
|
||||
def test_openrouter_always_uses_correct_url(self):
|
||||
"""Test that OpenRouter always uses the correct base URL."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
# List of private/dangerous IPs to test
|
||||
dangerous_urls = [
|
||||
"http://192.168.1.1/api/v1",
|
||||
"http://10.0.0.1/api/v1",
|
||||
"http://172.16.0.1/api/v1",
|
||||
"http://169.254.169.254/api/v1", # AWS metadata
|
||||
"http://[::1]/api/v1", # IPv6 localhost
|
||||
"http://0.0.0.0/api/v1",
|
||||
]
|
||||
|
||||
for url in dangerous_urls:
|
||||
with pytest.raises(ValueError, match="restricted IP|Invalid"):
|
||||
provider.base_url = url
|
||||
provider._validate_base_url()
|
||||
|
||||
def test_url_validation_allows_public_domains(self):
|
||||
"""Test that legitimate public domains are allowed."""
|
||||
assert provider.base_url == "https://openrouter.ai/api/v1"
|
||||
|
||||
# Even if we try to change it, it should remain the OpenRouter URL
|
||||
# (This is a characteristic of the OpenRouter provider)
|
||||
provider.base_url = "http://example.com" # Try to change it
|
||||
# But new instances should always use the correct URL
|
||||
provider2 = OpenRouterProvider(api_key="test-key")
|
||||
assert provider2.base_url == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_openrouter_headers_set_correctly(self):
|
||||
"""Test that OpenRouter specific headers are set."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
# OpenRouter's actual domain should always be allowed
|
||||
provider.base_url = "https://openrouter.ai/api/v1"
|
||||
provider._validate_base_url() # Should not raise
|
||||
|
||||
def test_invalid_url_schemes_rejected(self):
|
||||
"""Test that non-HTTP(S) schemes are rejected."""
|
||||
|
||||
# Check default headers
|
||||
assert "HTTP-Referer" in provider.DEFAULT_HEADERS
|
||||
assert "X-Title" in provider.DEFAULT_HEADERS
|
||||
assert provider.DEFAULT_HEADERS["X-Title"] == "Zen MCP Server"
|
||||
|
||||
def test_openrouter_model_registry_initialized(self):
|
||||
"""Test that model registry is properly initialized."""
|
||||
provider = OpenRouterProvider(api_key="test-key")
|
||||
|
||||
invalid_urls = [
|
||||
"ftp://example.com/api",
|
||||
"file:///etc/passwd",
|
||||
"gopher://example.com",
|
||||
"javascript:alert(1)",
|
||||
]
|
||||
|
||||
for url in invalid_urls:
|
||||
with pytest.raises(ValueError, match="Invalid URL scheme"):
|
||||
provider.base_url = url
|
||||
provider._validate_base_url()
|
||||
|
||||
# Registry should be initialized
|
||||
assert hasattr(provider, '_registry')
|
||||
assert provider._registry is not None
|
||||
|
||||
@@ -2,42 +2,34 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from providers.openrouter_registry import OpenRouterModelRegistry, OpenRouterModelConfig
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderType
|
||||
from providers.openrouter_registry import OpenRouterModelConfig, OpenRouterModelRegistry
|
||||
|
||||
|
||||
class TestOpenRouterModelRegistry:
|
||||
"""Test cases for OpenRouter model registry."""
|
||||
|
||||
|
||||
def test_registry_initialization(self):
|
||||
"""Test registry initializes with default config."""
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
|
||||
# Should load models from default location
|
||||
assert len(registry.list_models()) > 0
|
||||
assert len(registry.list_aliases()) > 0
|
||||
|
||||
|
||||
def test_custom_config_path(self):
|
||||
"""Test registry with custom config path."""
|
||||
# Create temporary config
|
||||
config_data = {
|
||||
"models": [
|
||||
{
|
||||
"model_name": "test/model-1",
|
||||
"aliases": ["test1", "t1"],
|
||||
"context_window": 4096
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
config_data = {"models": [{"model_name": "test/model-1", "aliases": ["test1", "t1"], "context_window": 4096}]}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
temp_path = f.name
|
||||
|
||||
|
||||
try:
|
||||
registry = OpenRouterModelRegistry(config_path=temp_path)
|
||||
assert len(registry.list_models()) == 1
|
||||
@@ -46,48 +38,40 @@ class TestOpenRouterModelRegistry:
|
||||
assert "t1" in registry.list_aliases()
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
|
||||
def test_environment_variable_override(self):
|
||||
"""Test OPENROUTER_MODELS_PATH environment variable."""
|
||||
# Create custom config
|
||||
config_data = {
|
||||
"models": [
|
||||
{
|
||||
"model_name": "env/model",
|
||||
"aliases": ["envtest"],
|
||||
"context_window": 8192
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
config_data = {"models": [{"model_name": "env/model", "aliases": ["envtest"], "context_window": 8192}]}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
temp_path = f.name
|
||||
|
||||
|
||||
try:
|
||||
# Set environment variable
|
||||
original_env = os.environ.get('OPENROUTER_MODELS_PATH')
|
||||
os.environ['OPENROUTER_MODELS_PATH'] = temp_path
|
||||
|
||||
original_env = os.environ.get("OPENROUTER_MODELS_PATH")
|
||||
os.environ["OPENROUTER_MODELS_PATH"] = temp_path
|
||||
|
||||
# Create registry without explicit path
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
|
||||
# Should load from environment path
|
||||
assert "env/model" in registry.list_models()
|
||||
assert "envtest" in registry.list_aliases()
|
||||
|
||||
|
||||
finally:
|
||||
# Restore environment
|
||||
if original_env is not None:
|
||||
os.environ['OPENROUTER_MODELS_PATH'] = original_env
|
||||
os.environ["OPENROUTER_MODELS_PATH"] = original_env
|
||||
else:
|
||||
del os.environ['OPENROUTER_MODELS_PATH']
|
||||
del os.environ["OPENROUTER_MODELS_PATH"]
|
||||
os.unlink(temp_path)
|
||||
|
||||
|
||||
def test_alias_resolution(self):
|
||||
"""Test alias resolution functionality."""
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
|
||||
# Test various aliases
|
||||
test_cases = [
|
||||
("opus", "anthropic/claude-3-opus"),
|
||||
@@ -97,75 +81,71 @@ class TestOpenRouterModelRegistry:
|
||||
("4o", "openai/gpt-4o"),
|
||||
("mistral", "mistral/mistral-large"),
|
||||
]
|
||||
|
||||
|
||||
for alias, expected_model in test_cases:
|
||||
config = registry.resolve(alias)
|
||||
assert config is not None, f"Failed to resolve alias '{alias}'"
|
||||
assert config.model_name == expected_model
|
||||
|
||||
|
||||
def test_direct_model_name_lookup(self):
|
||||
"""Test looking up models by their full name."""
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
|
||||
# Should be able to look up by full model name
|
||||
config = registry.resolve("anthropic/claude-3-opus")
|
||||
assert config is not None
|
||||
assert config.model_name == "anthropic/claude-3-opus"
|
||||
|
||||
|
||||
config = registry.resolve("openai/gpt-4o")
|
||||
assert config is not None
|
||||
assert config.model_name == "openai/gpt-4o"
|
||||
|
||||
|
||||
def test_unknown_model_resolution(self):
|
||||
"""Test resolution of unknown models."""
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
|
||||
# Unknown aliases should return None
|
||||
assert registry.resolve("unknown-alias") is None
|
||||
assert registry.resolve("") is None
|
||||
assert registry.resolve("non-existent") is None
|
||||
|
||||
|
||||
def test_model_capabilities_conversion(self):
|
||||
"""Test conversion to ModelCapabilities."""
|
||||
registry = OpenRouterModelRegistry()
|
||||
|
||||
|
||||
config = registry.resolve("opus")
|
||||
assert config is not None
|
||||
|
||||
|
||||
caps = config.to_capabilities()
|
||||
assert caps.provider == ProviderType.OPENROUTER
|
||||
assert caps.model_name == "anthropic/claude-3-opus"
|
||||
assert caps.friendly_name == "OpenRouter"
|
||||
assert caps.max_tokens == 200000
|
||||
assert not caps.supports_extended_thinking
|
||||
|
||||
|
||||
def test_duplicate_alias_detection(self):
|
||||
"""Test that duplicate aliases are detected."""
|
||||
config_data = {
|
||||
"models": [
|
||||
{
|
||||
"model_name": "test/model-1",
|
||||
"aliases": ["dupe"],
|
||||
"context_window": 4096
|
||||
},
|
||||
{"model_name": "test/model-1", "aliases": ["dupe"], "context_window": 4096},
|
||||
{
|
||||
"model_name": "test/model-2",
|
||||
"aliases": ["DUPE"], # Same alias, different case
|
||||
"context_window": 8192
|
||||
}
|
||||
"context_window": 8192,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
temp_path = f.name
|
||||
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError, match="Duplicate alias"):
|
||||
OpenRouterModelRegistry(config_path=temp_path)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
|
||||
def test_backwards_compatibility_max_tokens(self):
|
||||
"""Test backwards compatibility with old max_tokens field."""
|
||||
config_data = {
|
||||
@@ -174,44 +154,44 @@ class TestOpenRouterModelRegistry:
|
||||
"model_name": "test/old-model",
|
||||
"aliases": ["old"],
|
||||
"max_tokens": 16384, # Old field name
|
||||
"supports_extended_thinking": False
|
||||
"supports_extended_thinking": False,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
temp_path = f.name
|
||||
|
||||
|
||||
try:
|
||||
registry = OpenRouterModelRegistry(config_path=temp_path)
|
||||
config = registry.resolve("old")
|
||||
|
||||
|
||||
assert config is not None
|
||||
assert config.context_window == 16384 # Should be converted
|
||||
|
||||
|
||||
# Check capabilities still work
|
||||
caps = config.to_capabilities()
|
||||
assert caps.max_tokens == 16384
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
|
||||
def test_missing_config_file(self):
|
||||
"""Test behavior with missing config file."""
|
||||
# Use a non-existent path
|
||||
registry = OpenRouterModelRegistry(config_path="/non/existent/path.json")
|
||||
|
||||
|
||||
# Should initialize with empty maps
|
||||
assert len(registry.list_models()) == 0
|
||||
assert len(registry.list_aliases()) == 0
|
||||
assert registry.resolve("anything") is None
|
||||
|
||||
|
||||
def test_invalid_json_config(self):
|
||||
"""Test handling of invalid JSON."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
f.write("{ invalid json }")
|
||||
temp_path = f.name
|
||||
|
||||
|
||||
try:
|
||||
registry = OpenRouterModelRegistry(config_path=temp_path)
|
||||
# Should handle gracefully and initialize empty
|
||||
@@ -219,7 +199,7 @@ class TestOpenRouterModelRegistry:
|
||||
assert len(registry.list_aliases()) == 0
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
|
||||
def test_model_with_all_capabilities(self):
|
||||
"""Test model with all capability flags."""
|
||||
config = OpenRouterModelConfig(
|
||||
@@ -231,13 +211,13 @@ class TestOpenRouterModelRegistry:
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
supports_json_mode=True,
|
||||
description="Fully featured test model"
|
||||
description="Fully featured test model",
|
||||
)
|
||||
|
||||
|
||||
caps = config.to_capabilities()
|
||||
assert caps.max_tokens == 128000
|
||||
assert caps.supports_extended_thinking
|
||||
assert caps.supports_system_prompts
|
||||
assert caps.supports_streaming
|
||||
assert caps.supports_function_calling
|
||||
# Note: supports_json_mode is not in ModelCapabilities yet
|
||||
# Note: supports_json_mode is not in ModelCapabilities yet
|
||||
|
||||
Reference in New Issue
Block a user