refactor: cleanup provider base class; cleanup shared responsibilities; cleanup public contract

docs: document provider base class
refactor: cleanup custom provider, it should only deal with `is_custom` model configurations
fix: make sure openrouter provider does not load `is_custom` models
fix: listmodels tool cleanup
This commit is contained in:
Fahad
2025-10-02 12:59:45 +04:00
parent 6ec2033f34
commit 693b84db2b
15 changed files with 509 additions and 751 deletions

View File

@@ -61,108 +61,52 @@ class OpenRouterProvider(OpenAICompatibleProvider):
aliases = self._registry.list_aliases()
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model aliases to OpenRouter model names.
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
Args:
model_name: Input model name or alias
Returns:
Resolved OpenRouter model name
"""
# Try to resolve through registry
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
else:
# If not found in registry, return as-is
# This allows using models not in our config file
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Get capabilities for a model.
Args:
model_name: Name of the model (or alias)
Returns:
ModelCapabilities from registry or generic defaults
"""
# Try to get from registry first
capabilities = self._registry.get_capabilities(model_name)
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Fetch OpenRouter capabilities from the registry or build a generic fallback."""
capabilities = self._registry.get_capabilities(canonical_name)
if capabilities:
return capabilities
else:
# Resolve any potential aliases and create generic capabilities
resolved_name = self._resolve_model_name(model_name)
logging.debug(
f"Using generic capabilities for '{resolved_name}' via OpenRouter. "
"Consider adding to custom_models.json for specific capabilities."
)
logging.debug(
f"Using generic capabilities for '{canonical_name}' via OpenRouter. "
"Consider adding to custom_models.json for specific capabilities."
)
# Create generic capabilities with conservative defaults
capabilities = ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name=resolved_name,
friendly_name=self.FRIENDLY_NAME,
context_window=32_768, # Conservative default context window
max_output_tokens=32_768,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
)
generic = ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name=canonical_name,
friendly_name=self.FRIENDLY_NAME,
context_window=32_768,
max_output_tokens=32_768,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
)
generic._is_generic = True
return generic
# Mark as generic for validation purposes
capabilities._is_generic = True
return capabilities
# ------------------------------------------------------------------
# Provider identity
# ------------------------------------------------------------------
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
"""Identify this provider for restrictions and logging."""
return ProviderType.OPENROUTER
def validate_model_name(self, model_name: str) -> bool:
"""Validate if the model name is allowed.
As the catch-all provider, OpenRouter accepts any model name that wasn't
handled by higher-priority providers. OpenRouter will validate based on
the API key's permissions and local restrictions.
Args:
model_name: Model name to validate
Returns:
True if model is allowed, False if restricted
"""
# Check model restrictions if configured
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if restriction_service:
# Check if model name itself is allowed
if restriction_service.is_allowed(self.get_provider_type(), model_name):
return True
# Also check aliases - model_name might be an alias
model_config = self._registry.resolve(model_name)
if model_config and model_config.aliases:
for alias in model_config.aliases:
if restriction_service.is_allowed(self.get_provider_type(), alias):
return True
# If restrictions are configured and model/alias not in allowed list, reject
return False
# No restrictions configured - accept any model name as the fallback provider
return True
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
def generate_content(
self,
@@ -204,6 +148,10 @@ class OpenRouterProvider(OpenAICompatibleProvider):
**kwargs,
)
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------
def list_models(
self,
*,
@@ -227,6 +175,12 @@ class OpenRouterProvider(OpenAICompatibleProvider):
if not config:
continue
# Custom models belong to CustomProvider; skip them here so the two
# providers don't race over the same registrations (important for tests
# that stub the registry with minimal objects lacking attrs).
if hasattr(config, "is_custom") and config.is_custom is True:
continue
if restriction_service:
allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
@@ -255,24 +209,37 @@ class OpenRouterProvider(OpenAICompatibleProvider):
unique=unique,
)
def get_model_configurations(self) -> dict[str, ModelCapabilities]:
"""Get model configurations from the registry.
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------
For OpenRouter, we convert registry configurations to ModelCapabilities objects.
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve aliases defined in the OpenRouter registry."""
Returns:
Dictionary mapping model names to their ModelCapabilities objects
"""
configs = {}
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
if self._registry:
# Get all models from registry
for model_name in self._registry.list_models():
# Only include models that this provider validates
if self.validate_model_name(model_name):
config = self._registry.resolve(model_name)
if config and not config.is_custom: # Only OpenRouter models, not custom ones
# Use ModelCapabilities directly from registry
configs[model_name] = config
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
return configs
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Expose registry-backed OpenRouter capabilities."""
if not self._registry:
return {}
capabilities: dict[str, ModelCapabilities] = {}
for model_name in self._registry.list_models():
config = self._registry.resolve(model_name)
if not config:
continue
# See note in list_models: respect the CustomProvider boundary.
if hasattr(config, "is_custom") and config.is_custom is True:
continue
capabilities[model_name] = config
return capabilities