refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract

docs: document provider base class
refactor: cleanup custom provider, it should only deal with `is_custom` model configurations
fix: make sure openrouter provider does not load `is_custom` models
fix: listmodels tool cleanup
This commit is contained in:
Fahad
2025-10-02 12:59:45 +04:00
parent 6ec2033f34
commit 693b84db2b
15 changed files with 509 additions and 751 deletions

View File

@@ -54,11 +54,9 @@ class TestCustomProvider:
provider = CustomProvider(api_key="test-key", base_url="http://localhost:11434/v1")
# Test with a model that should be in the registry (OpenRouter model)
capabilities = provider.get_capabilities("o3") # o3 is an OpenRouter model
assert capabilities.provider == ProviderType.OPENROUTER # o3 is an OpenRouter model (is_custom=false)
assert capabilities.context_window > 0
# OpenRouter-backed models should be handled by the OpenRouter provider
with pytest.raises(ValueError):
provider.get_capabilities("o3")
# Test with a custom model (is_custom=true)
capabilities = provider.get_capabilities("local-llama")
@@ -168,7 +166,13 @@ class TestCustomProviderRegistration:
return CustomProvider(api_key="", base_url="http://localhost:11434/v1")
with patch.dict(
os.environ, {"OPENROUTER_API_KEY": "test-openrouter-key", "CUSTOM_API_PLACEHOLDER": "configured"}
os.environ,
{
"OPENROUTER_API_KEY": "test-openrouter-key",
"CUSTOM_API_PLACEHOLDER": "configured",
"OPENROUTER_ALLOWED_MODELS": "llama,anthropic/claude-opus-4.1",
},
clear=True,
):
# Register both providers
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
@@ -195,18 +199,22 @@ class TestCustomProviderRegistration:
return CustomProvider(api_key="", base_url="http://localhost:11434/v1")
with patch.dict(
os.environ, {"OPENROUTER_API_KEY": "test-openrouter-key", "CUSTOM_API_PLACEHOLDER": "configured"}
os.environ,
{
"OPENROUTER_API_KEY": "test-openrouter-key",
"CUSTOM_API_PLACEHOLDER": "configured",
"OPENROUTER_ALLOWED_MODELS": "",
},
clear=True,
):
# Register OpenRouter first (higher priority)
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
import utils.model_restrictions
# Test model resolution - OpenRouter should win for shared aliases
provider_for_model = ModelProviderRegistry.get_provider_for_model("llama")
utils.model_restrictions._restriction_service = None
custom_provider = custom_provider_factory()
openrouter_provider = OpenRouterProvider(api_key="test-openrouter-key")
# OpenRouter should be selected first due to registration order
assert provider_for_model is not None
# The exact provider type depends on which validates the model first
assert not custom_provider.validate_model_name("llama")
assert openrouter_provider.validate_model_name("llama")
class TestConfigureProvidersFunction:

View File

@@ -121,7 +121,7 @@ class TestDIALProvider:
"""Test that get_capabilities raises for invalid models."""
provider = DIALModelProvider("test-key")
with pytest.raises(ValueError, match="Unsupported DIAL model"):
with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider dial"):
provider.get_capabilities("invalid-model")
@patch("utils.model_restrictions.get_restriction_service")

View File

@@ -356,15 +356,13 @@ class TestCustomProviderOpenRouterRestrictions:
provider = CustomProvider(base_url="http://test.com/v1")
# For OpenRouter models, get_capabilities should still work but mark them as OPENROUTER
# This tests the capabilities lookup, not validation
capabilities = provider.get_capabilities("opus")
assert capabilities.provider == ProviderType.OPENROUTER
# For OpenRouter models, CustomProvider should defer by raising
with pytest.raises(ValueError):
provider.get_capabilities("opus")
# Should raise for disallowed OpenRouter model
with pytest.raises(ValueError) as exc_info:
# Should raise for disallowed OpenRouter model (still defers)
with pytest.raises(ValueError):
provider.get_capabilities("haiku")
assert "not allowed by restriction policy" in str(exc_info.value)
# Should still work for custom models (is_custom=true)
capabilities = provider.get_capabilities("local-llama")

View File

@@ -141,7 +141,7 @@ class TestXAIProvider:
"""Test error handling for unsupported models."""
provider = XAIModelProvider("test-key")
with pytest.raises(ValueError, match="Unsupported X.AI model"):
with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider xai"):
provider.get_capabilities("invalid-model")
def test_extended_thinking_flags(self):